diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index e49ab907d71..9eef1e81d33 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -3085,7 +3085,9 @@ namespace ts { // Use the type of the initializer expression if one is present if (declaration.initializer) { - return addOptionality(checkExpressionCached(declaration.initializer), /*optional*/ declaration.questionToken && includeOptionality); + const exprType = checkExpressionCached(declaration.initializer); + const type = getCombinedNodeFlags(declaration) & NodeFlags.Const ? exprType : getBaseTypeOfLiteralType(exprType); + return addOptionality(type, /*optional*/ declaration.questionToken && includeOptionality); } // If it is a short-hand property assignment, use the type of the identifier @@ -3384,7 +3386,7 @@ namespace ts { function getTypeOfEnumMember(symbol: Symbol): Type { const links = getSymbolLinks(symbol); if (!links.type) { - links.type = getDeclaredTypeOfEnum(getParentOfSymbol(symbol)); + links.type = getDeclaredTypeOfEnumMember(symbol); } return links.type; } @@ -7201,18 +7203,18 @@ namespace ts { return (type.flags & (TypeFlags.Literal | TypeFlags.Undefined | TypeFlags.Null)) !== 0; } - function isUnitUnionType(type: Type): boolean { + function isLiteralType(type: Type): boolean { return type.flags & TypeFlags.Boolean ? true : type.flags & TypeFlags.Union ? type.flags & TypeFlags.Enum ? true : !forEach((type).types, t => !isUnitType(t)) : isUnitType(type); } - function getBaseTypeOfUnitType(type: Type): Type { + function getBaseTypeOfLiteralType(type: Type): Type { return type.flags & TypeFlags.StringLiteral ? stringType : type.flags & TypeFlags.NumberLiteral ? numberType : type.flags & TypeFlags.BooleanLiteral ? booleanType : type.flags & TypeFlags.EnumLiteral ? (type).baseType : - type.flags & TypeFlags.Union && !(type.flags & TypeFlags.Enum) ? getUnionType(map((type).types, getBaseTypeOfUnitType)) : + type.flags & TypeFlags.Union && !(type.flags & TypeFlags.Enum) ? getUnionType(map((type).types, getBaseTypeOfLiteralType)) : type; } @@ -7576,8 +7578,9 @@ namespace ts { const candidates = inferiority ? inferences.secondary || (inferences.secondary = []) : inferences.primary || (inferences.primary = []); - if (!contains(candidates, source)) { - candidates.push(source); + const widened = isUnitType(source) ? getBaseTypeOfLiteralType(source): source; + if (!contains(candidates, widened)) { + candidates.push(widened); } } return; @@ -7904,7 +7907,7 @@ namespace ts { if (prop && prop.flags & SymbolFlags.SyntheticProperty) { if ((prop).isDiscriminantProperty === undefined) { (prop).isDiscriminantProperty = !(prop).hasCommonType && - isUnitUnionType(getTypeOfSymbol(prop)); + isLiteralType(getTypeOfSymbol(prop)); } return (prop).isDiscriminantProperty; } @@ -9757,6 +9760,7 @@ namespace ts { case SyntaxKind.BinaryExpression: return getContextualTypeForBinaryOperand(node); case SyntaxKind.PropertyAssignment: + case SyntaxKind.ShorthandPropertyAssignment: return getContextualTypeForObjectLiteralElement(parent); case SyntaxKind.ArrayLiteralExpression: return getContextualTypeForElementExpression(node); @@ -9776,31 +9780,6 @@ namespace ts { return undefined; } - function isLiteralTypeLocation(node: Node): boolean { - const parent = node.parent; - switch (parent.kind) { - case SyntaxKind.BinaryExpression: - switch ((parent).operatorToken.kind) { - case SyntaxKind.EqualsEqualsEqualsToken: - case SyntaxKind.ExclamationEqualsEqualsToken: - case SyntaxKind.EqualsEqualsToken: - case SyntaxKind.ExclamationEqualsToken: - return true; - } - break; - case SyntaxKind.ConditionalExpression: - return (node === (parent).whenTrue || - node === (parent).whenFalse) && - isLiteralTypeLocation(parent); - case SyntaxKind.ParenthesizedExpression: - return isLiteralTypeLocation(parent); - case SyntaxKind.CaseClause: - case SyntaxKind.LiteralType: - return true; - } - return false; - } - // If the given type is an object or union type, if that type has a single signature, and if // that signature is non-generic, return the signature. Otherwise return undefined. function getNonGenericSignature(type: Type): Signature { @@ -9937,7 +9916,7 @@ namespace ts { } } else { - const type = checkExpression(e, contextualMapper); + const type = checkExpressionForMutableLocation(e, contextualMapper); elementTypes.push(type); } hasSpreadElement = hasSpreadElement || e.kind === SyntaxKind.SpreadElementExpression; @@ -10077,7 +10056,7 @@ namespace ts { } else { Debug.assert(memberDecl.kind === SyntaxKind.ShorthandPropertyAssignment); - type = checkExpression((memberDecl).name, contextualMapper); + type = checkExpressionForMutableLocation((memberDecl).name, contextualMapper); } typeFlags |= type.flags; const prop = createSymbol(SymbolFlags.Property | SymbolFlags.Transient | member.flags, member.name); @@ -10799,9 +10778,6 @@ namespace ts { } let propType = getTypeOfSymbol(prop); - if (prop.flags & SymbolFlags.EnumMember && isLiteralContextForType(node, propType)) { - propType = getDeclaredTypeOfSymbol(prop); - } // Only compute control flow type if this is a property access expression that isn't an // assignment target, and the referenced property was declared as a variable, property, @@ -12479,6 +12455,9 @@ namespace ts { } // Return a union of the return expression types. type = getUnionType(types, /*subtypeReduction*/ true); + if (isUnitType(type)) { + type = getBaseTypeOfLiteralType(type); + } if (funcIsGenerator) { type = createIterableIteratorType(type); @@ -12522,7 +12501,7 @@ namespace ts { return false; } const type = checkExpression(node.expression); - if (!isUnitUnionType(type)) { + if (!isLiteralType(type)) { return false; } const switchTypes = getSwitchClauseTypes(node); @@ -12865,7 +12844,7 @@ namespace ts { function checkPrefixUnaryExpression(node: PrefixUnaryExpression): Type { const operandType = checkExpression(node.operand); - if (node.operator === SyntaxKind.MinusToken && node.operand.kind === SyntaxKind.NumericLiteral && isLiteralContextForType(node, numberType)) { + if (node.operator === SyntaxKind.MinusToken && node.operand.kind === SyntaxKind.NumericLiteral) { return getLiteralTypeForText(TypeFlags.NumberLiteral, "" + -(node.operand).text); } switch (node.operator) { @@ -13265,11 +13244,11 @@ namespace ts { case SyntaxKind.ExclamationEqualsToken: case SyntaxKind.EqualsEqualsEqualsToken: case SyntaxKind.ExclamationEqualsEqualsToken: - const leftIsUnit = isUnitUnionType(leftType); - const rightIsUnit = isUnitUnionType(rightType); + const leftIsUnit = isLiteralType(leftType); + const rightIsUnit = isLiteralType(rightType); if (!leftIsUnit || !rightIsUnit) { - leftType = leftIsUnit ? getBaseTypeOfUnitType(leftType) : leftType; - rightType = rightIsUnit ? getBaseTypeOfUnitType(rightType) : rightType; + leftType = leftIsUnit ? getBaseTypeOfLiteralType(leftType) : leftType; + rightType = rightIsUnit ? getBaseTypeOfLiteralType(rightType) : rightType; } if (!isTypeEqualityComparableTo(leftType, rightType) && !isTypeEqualityComparableTo(rightType, leftType)) { reportOperatorError(); @@ -13281,7 +13260,7 @@ namespace ts { return checkInExpression(left, right, leftType, rightType); case SyntaxKind.AmpersandAmpersandToken: return getTypeFacts(leftType) & TypeFacts.Truthy ? - includeFalsyTypes(rightType, getFalsyFlags(strictNullChecks ? leftType : getBaseTypeOfUnitType(rightType))) : + includeFalsyTypes(rightType, getFalsyFlags(strictNullChecks ? leftType : getBaseTypeOfLiteralType(rightType))) : leftType; case SyntaxKind.BarBarToken: return getTypeFacts(leftType) & TypeFacts.Falsy ? @@ -13429,50 +13408,18 @@ namespace ts { return false; } - function isLiteralContextForType(node: Expression, type: Type) { - if (isLiteralTypeLocation(node)) { - return true; - } - let contextualType = getContextualType(node); - if (contextualType) { - if (contextualType.flags & TypeFlags.TypeParameter) { - const apparentType = getApparentTypeOfTypeParameter(contextualType); - // If the type parameter is constrained to the base primitive type we're checking for, - // consider this a literal context. For example, given a type parameter 'T extends string', - // this causes us to infer string literal types for T. - if (type === apparentType) { - return true; - } - contextualType = apparentType; - } - if (type.flags & TypeFlags.String) { - return maybeTypeOfKind(contextualType, TypeFlags.StringLiteral); - } - if (type.flags & TypeFlags.Number) { - return maybeTypeOfKind(contextualType, (TypeFlags.NumberLiteral | TypeFlags.EnumLiteral)); - } - if (type.flags & TypeFlags.Boolean) { - return maybeTypeOfKind(contextualType, TypeFlags.BooleanLiteral); - } - if (type.flags & TypeFlags.Enum) { - return typeContainsLiteralFromEnum(contextualType, type); - } - } - return false; - } - function checkLiteralExpression(node: Expression): Type { if (node.kind === SyntaxKind.NumericLiteral) { checkGrammarNumericLiteral(node); } switch (node.kind) { case SyntaxKind.StringLiteral: - return isLiteralContextForType(node, stringType) ? getLiteralTypeForText(TypeFlags.StringLiteral, (node).text) : stringType; + return getLiteralTypeForText(TypeFlags.StringLiteral, (node).text); case SyntaxKind.NumericLiteral: - return isLiteralContextForType(node, numberType) ? getLiteralTypeForText(TypeFlags.NumberLiteral, (node).text) : numberType; + return getLiteralTypeForText(TypeFlags.NumberLiteral, (node).text); case SyntaxKind.TrueKeyword: case SyntaxKind.FalseKeyword: - return isLiteralContextForType(node, booleanType) ? node.kind === SyntaxKind.TrueKeyword ? trueType : falseType : booleanType; + return node.kind === SyntaxKind.TrueKeyword ? trueType : falseType; } } @@ -13511,6 +13458,29 @@ namespace ts { return links.resolvedType; } + function hasLiteralContextualType(node: Expression) { + let contextualType = getContextualType(node); + if (contextualType) { + if (contextualType.flags & TypeFlags.TypeParameter) { + const apparentType = getApparentTypeOfTypeParameter(contextualType); + // If the type parameter is constrained to the base primitive type we're checking for, + // consider this a literal context. For example, given a type parameter 'T extends string', + // this causes us to infer string literal types for T. + if (apparentType.flags & (TypeFlags.String | TypeFlags.Number | TypeFlags.Boolean | TypeFlags.Enum)) { + return true; + } + contextualType = apparentType; + } + return maybeTypeOfKind(contextualType, TypeFlags.Literal); + } + return false; + } + + function checkExpressionForMutableLocation(node: Expression, contextualMapper?: TypeMapper): Type { + const type = checkExpression(node, contextualMapper); + return hasLiteralContextualType(node) ? type : getBaseTypeOfLiteralType(type); + } + function checkPropertyAssignment(node: PropertyAssignment, contextualMapper?: TypeMapper): Type { // Do not use hasDynamicName here, because that returns false for well known symbols. // We want to perform checkComputedPropertyName for all computed properties, including @@ -13519,7 +13489,7 @@ namespace ts { checkComputedPropertyName(node.name); } - return checkExpression((node).initializer, contextualMapper); + return checkExpressionForMutableLocation((node).initializer, contextualMapper); } function checkObjectLiteralMethod(node: MethodDeclaration, contextualMapper?: TypeMapper): Type {