diff --git a/src/compiler/binder.ts b/src/compiler/binder.ts index b45af24d4d1..27bd3b71c5e 100644 --- a/src/compiler/binder.ts +++ b/src/compiler/binder.ts @@ -637,9 +637,8 @@ namespace ts { case SyntaxKind.ExclamationEqualsToken: case SyntaxKind.EqualsEqualsEqualsToken: case SyntaxKind.ExclamationEqualsEqualsToken: - return isNarrowingNullCheckOperands(expr.right, expr.left) || isNarrowingNullCheckOperands(expr.left, expr.right) || - isNarrowingTypeofOperands(expr.right, expr.left) || isNarrowingTypeofOperands(expr.left, expr.right) || - isNarrowingDiscriminant(expr.left) || isNarrowingDiscriminant(expr.right); + return isNarrowableOperand(expr.left) || isNarrowableOperand(expr.right) || + isNarrowingTypeofOperands(expr.right, expr.left) || isNarrowingTypeofOperands(expr.left, expr.right); case SyntaxKind.InstanceOfKeyword: return isNarrowableOperand(expr.left); case SyntaxKind.CommaToken: @@ -663,11 +662,6 @@ namespace ts { return isNarrowableReference(expr); } - function isNarrowingSwitchStatement(switchStatement: SwitchStatement) { - const expr = switchStatement.expression; - return expr.kind === SyntaxKind.PropertyAccessExpression && isNarrowableReference((expr).expression); - } - function createBranchLabel(): FlowLabel { return { flags: FlowFlags.BranchLabel, @@ -717,7 +711,7 @@ namespace ts { } function createFlowSwitchClause(antecedent: FlowNode, switchStatement: SwitchStatement, clauseStart: number, clauseEnd: number): FlowNode { - if (!isNarrowingSwitchStatement(switchStatement)) { + if (!isNarrowingExpression(switchStatement.expression)) { return antecedent; } setFlowNodeReferenced(antecedent); diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index ab5d4e9a5c4..82c37b159ec 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -6860,6 +6860,11 @@ namespace ts { return !!getPropertyOfType(type, "0"); } + function isUnitType(type: Type): boolean { + return type.flags & (TypeFlags.Literal | TypeFlags.Void | TypeFlags.Undefined | TypeFlags.Null) || + type.flags & TypeFlags.Enum && type.symbol.flags & SymbolFlags.EnumMember ? true : false; + } + function isLiteralUnionType(type: Type): boolean { return type.flags & TypeFlags.Literal ? true : type.flags & TypeFlags.Enum ? (type.symbol.flags & SymbolFlags.EnumMember) !== 0 : @@ -7481,11 +7486,6 @@ namespace ts { return undefined; } - function isNullOrUndefinedLiteral(node: Expression) { - return node.kind === SyntaxKind.NullKeyword || - node.kind === SyntaxKind.Identifier && getResolvedSymbol(node) === undefinedSymbol; - } - function getLeftmostIdentifierOrThis(node: Node): Node { switch (node.kind) { case SyntaxKind.Identifier: @@ -7584,13 +7584,16 @@ namespace ts { type === emptyStringType ? TypeFacts.EmptyStringStrictFacts : TypeFacts.NonEmptyStringStrictFacts : type === emptyStringType ? TypeFacts.EmptyStringFacts : TypeFacts.NonEmptyStringFacts; } - if (flags & TypeFlags.Number) { + if (flags & TypeFlags.Number || type.flags & TypeFlags.Enum && !(type.symbol.flags & SymbolFlags.EnumMember)) { return strictNullChecks ? TypeFacts.NumberStrictFacts : TypeFacts.NumberFacts; } if (flags & TypeFlags.NumberLike) { + const isZero = type === zeroType || + type.flags & TypeFlags.Enum && type.symbol.flags & SymbolFlags.EnumMember && + getEnumMemberValue(type.symbol.valueDeclaration) === 0; return strictNullChecks ? - type === zeroType ? TypeFacts.ZeroStrictFacts : TypeFacts.NonZeroStrictFacts : - type === zeroType ? TypeFacts.ZeroFacts : TypeFacts.NonZeroFacts; + isZero ? TypeFacts.ZeroStrictFacts : TypeFacts.NonZeroStrictFacts : + isZero ? TypeFacts.ZeroFacts : TypeFacts.NonZeroFacts; } if (flags & TypeFlags.Boolean) { return strictNullChecks ? TypeFacts.BooleanStrictFacts : TypeFacts.BooleanFacts; @@ -7756,23 +7759,27 @@ namespace ts { getInitialTypeOfBindingElement(node); } - function getReferenceFromExpression(node: Expression): Expression { + function getReferenceCandidate(node: Expression): Expression { switch (node.kind) { case SyntaxKind.ParenthesizedExpression: - return getReferenceFromExpression((node).expression); + return getReferenceCandidate((node).expression); case SyntaxKind.BinaryExpression: switch ((node).operatorToken.kind) { case SyntaxKind.EqualsToken: - return getReferenceFromExpression((node).left); + return getReferenceCandidate((node).left); case SyntaxKind.CommaToken: - return getReferenceFromExpression((node).right); + return getReferenceCandidate((node).right); } } return node; } function getTypeOfSwitchClause(clause: CaseClause | DefaultClause) { - return clause.kind === SyntaxKind.CaseClause ? checkExpression((clause).expression) : undefined; + if (clause.kind === SyntaxKind.CaseClause) { + const caseType = checkExpression((clause).expression); + return isUnitType(caseType) ? caseType : undefined; + } + return neverType; } function getSwitchClauseTypes(switchStatement: SwitchStatement): Type[] { @@ -7781,7 +7788,7 @@ namespace ts { // If all case clauses specify expressions that have unit types, we return an array // of those unit types. Otherwise we return an empty array. const types = map(switchStatement.caseBlock.clauses, getTypeOfSwitchClause); - links.switchTypes = forEach(types, t => !t || isLiteralUnionType(t)) ? types : emptyArray; + links.switchTypes = !contains(types, undefined) ? types : emptyArray; } return links.switchTypes; } @@ -7921,7 +7928,14 @@ namespace ts { function getTypeAtSwitchClause(flow: FlowSwitchClause) { const type = getTypeAtFlowNode(flow.antecedent); - return narrowTypeBySwitchOnDiscriminant(type, flow.switchStatement, flow.clauseStart, flow.clauseEnd); + const expr = flow.switchStatement.expression; + if (isMatchingReference(reference, expr)) { + return narrowTypeBySwitchOnDiscriminant(type, flow.switchStatement, flow.clauseStart, flow.clauseEnd); + } + if (isMatchingPropertyAccess(expr)) { + return narrowTypeByDiscriminant(type, expr, t => narrowTypeBySwitchOnDiscriminant(t, flow.switchStatement, flow.clauseStart, flow.clauseEnd)); + } + return type; } function getTypeAtFlowBranchLabel(flow: FlowLabel) { @@ -7991,8 +8005,27 @@ namespace ts { return cache[key] = getUnionType(antecedentTypes); } + function isMatchingPropertyAccess(expr: Expression) { + return expr.kind === SyntaxKind.PropertyAccessExpression && + isMatchingReference(reference, (expr).expression) && + (declaredType.flags & TypeFlags.Union) !== 0; + } + + function narrowTypeByDiscriminant(type: Type, propAccess: PropertyAccessExpression, narrowType: (t: Type) => Type): Type { + const propName = propAccess.name.text; + const propType = getTypeOfPropertyOfType(type, propName); + const narrowedPropType = propType && narrowType(propType); + return propType === narrowedPropType ? type : filterType(type, t => isTypeComparableTo(getTypeOfPropertyOfType(t, propName), narrowedPropType)); + } + function narrowTypeByTruthiness(type: Type, expr: Expression, assumeTrue: boolean): Type { - return isMatchingReference(reference, expr) ? getTypeWithFacts(type, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy) : type; + if (isMatchingReference(reference, expr)) { + return getTypeWithFacts(type, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy); + } + if (isMatchingPropertyAccess(expr)) { + return narrowTypeByDiscriminant(type, expr, t => getTypeWithFacts(t, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy)); + } + return type; } function narrowTypeByBinaryExpression(type: Type, expr: BinaryExpression, assumeTrue: boolean): Type { @@ -8003,26 +8036,26 @@ namespace ts { case SyntaxKind.ExclamationEqualsToken: case SyntaxKind.EqualsEqualsEqualsToken: case SyntaxKind.ExclamationEqualsEqualsToken: - const left = expr.left; const operator = expr.operatorToken.kind; - const right = expr.right; - if (isNullOrUndefinedLiteral(right)) { - return narrowTypeByNullCheck(type, left, operator, right, assumeTrue); - } - if (isNullOrUndefinedLiteral(left)) { - return narrowTypeByNullCheck(type, right, operator, left, assumeTrue); - } + const left = getReferenceCandidate(expr.left); + const right = getReferenceCandidate(expr.right); if (left.kind === SyntaxKind.TypeOfExpression && right.kind === SyntaxKind.StringLiteral) { return narrowTypeByTypeof(type, left, operator, right, assumeTrue); } if (right.kind === SyntaxKind.TypeOfExpression && left.kind === SyntaxKind.StringLiteral) { return narrowTypeByTypeof(type, right, operator, left, assumeTrue); } - if (left.kind === SyntaxKind.PropertyAccessExpression) { - return narrowTypeByDiscriminant(type, left, operator, right, assumeTrue); + if (isMatchingReference(reference, left)) { + return narrowTypeByEquality(type, operator, right, assumeTrue); } - if (right.kind === SyntaxKind.PropertyAccessExpression) { - return narrowTypeByDiscriminant(type, right, operator, left, assumeTrue); + if (isMatchingReference(reference, right)) { + return narrowTypeByEquality(type, operator, left, assumeTrue); + } + if (isMatchingPropertyAccess(left)) { + return narrowTypeByDiscriminant(type, left, t => narrowTypeByEquality(t, operator, right, assumeTrue)); + } + if (isMatchingPropertyAccess(right)) { + return narrowTypeByDiscriminant(type, right, t => narrowTypeByEquality(t, operator, left, assumeTrue)); } break; case SyntaxKind.InstanceOfKeyword: @@ -8033,26 +8066,36 @@ namespace ts { return type; } - function narrowTypeByNullCheck(type: Type, target: Expression, operator: SyntaxKind, literal: Expression, assumeTrue: boolean): Type { - // We have '==', '!=', '===', or '!==' operator with 'null' or 'undefined' as value + function narrowTypeByEquality(type: Type, operator: SyntaxKind, value: Expression, assumeTrue: boolean): Type { if (operator === SyntaxKind.ExclamationEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) { assumeTrue = !assumeTrue; } - if (!strictNullChecks || !isMatchingReference(reference, getReferenceFromExpression(target))) { + const valueType = checkExpression(value); + if (valueType.flags & TypeFlags.Nullable) { + if (!strictNullChecks) { + return type; + } + const doubleEquals = operator === SyntaxKind.EqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsToken; + const facts = doubleEquals ? + assumeTrue ? TypeFacts.EQUndefinedOrNull : TypeFacts.NEUndefinedOrNull : + value.kind === SyntaxKind.NullKeyword ? + assumeTrue ? TypeFacts.EQNull : TypeFacts.NENull : + assumeTrue ? TypeFacts.EQUndefined : TypeFacts.NEUndefined; + return getTypeWithFacts(type, facts); + } + if (type.flags & TypeFlags.NotUnionOrUnit) { return type; } - const doubleEquals = operator === SyntaxKind.EqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsToken; - const facts = doubleEquals ? - assumeTrue ? TypeFacts.EQUndefinedOrNull : TypeFacts.NEUndefinedOrNull : - literal.kind === SyntaxKind.NullKeyword ? - assumeTrue ? TypeFacts.EQNull : TypeFacts.NENull : - assumeTrue ? TypeFacts.EQUndefined : TypeFacts.NEUndefined; - return getTypeWithFacts(type, facts); + if (assumeTrue) { + const narrowedType = filterType(type, t => areTypesComparable(t, valueType)); + return narrowedType !== neverType ? narrowedType : type; + } + return isUnitType(valueType) ? filterType(type, t => t !== valueType) : type; } function narrowTypeByTypeof(type: Type, typeOfExpr: TypeOfExpression, operator: SyntaxKind, literal: LiteralExpression, assumeTrue: boolean): Type { // We have '==', '!=', '====', or !==' operator with 'typeof xxx' and string literal operands - const target = getReferenceFromExpression(typeOfExpr.expression); + const target = getReferenceCandidate(typeOfExpr.expression); if (!isMatchingReference(reference, target)) { // For a reference of the form 'x.y', a 'typeof x === ...' type guard resets the // narrowed type of 'y' to its declared type. @@ -8079,40 +8122,8 @@ namespace ts { return getTypeWithFacts(type, facts); } - function narrowTypeByDiscriminant(type: Type, propAccess: PropertyAccessExpression, operator: SyntaxKind, value: Expression, assumeTrue: boolean): Type { - // We have '==', '!=', '===', or '!==' operator with property access as target - if (!isMatchingReference(reference, propAccess.expression)) { - return type; - } - const propName = propAccess.name.text; - const propType = getTypeOfPropertyOfType(type, propName); - if (!propType || !isLiteralUnionType(propType)) { - return type; - } - const discriminantType = checkExpression(value); - if (!isLiteralUnionType(discriminantType)) { - return type; - } - if (operator === SyntaxKind.ExclamationEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) { - assumeTrue = !assumeTrue; - } - if (assumeTrue) { - return filterType(type, t => areTypesComparable(getTypeOfPropertyOfType(t, propName), discriminantType)); - } - if (isLiteralUnionType(discriminantType) && !(discriminantType.flags & TypeFlags.Union)) { - return filterType(type, t => getTypeOfPropertyOfType(t, propName) !== discriminantType); - } - return type; - } - function narrowTypeBySwitchOnDiscriminant(type: Type, switchStatement: SwitchStatement, clauseStart: number, clauseEnd: number) { - // We have switch statement with property access expression - if (!isMatchingReference(reference, (switchStatement.expression).expression)) { - return type; - } - const propName = (switchStatement.expression).name.text; - const propType = getTypeOfPropertyOfType(type, propName); - if (!propType || !isLiteralUnionType(propType)) { + if (!isLiteralUnionType(type)) { return type; } const switchTypes = getSwitchClauseTypes(switchStatement); @@ -8120,19 +8131,18 @@ namespace ts { return type; } const clauseTypes = switchTypes.slice(clauseStart, clauseEnd); - const hasDefaultClause = clauseStart === clauseEnd || contains(clauseTypes, undefined); - const caseTypes = hasDefaultClause ? filter(clauseTypes, t => !!t) : clauseTypes; - const discriminantType = caseTypes.length ? getUnionType(caseTypes) : undefined; - const caseType = discriminantType && filterType(type, t => isTypeComparableTo(discriminantType, getTypeOfPropertyOfType(t, propName))); + const hasDefaultClause = clauseStart === clauseEnd || contains(clauseTypes, neverType); + const discriminantType = getUnionType(clauseTypes); + const caseType = discriminantType === neverType ? neverType : filterType(type, t => isTypeComparableTo(discriminantType, t)); if (!hasDefaultClause) { return caseType; } - const defaultType = filterType(type, t => !eachTypeContainedIn(getTypeOfPropertyOfType(t, propName), switchTypes)); - return caseType ? getUnionType([caseType, defaultType]) : defaultType; + const defaultType = filterType(type, t => !eachTypeContainedIn(t, switchTypes)); + return caseType === neverType ? defaultType : getUnionType([caseType, defaultType]); } function narrowTypeByInstanceof(type: Type, expr: BinaryExpression, assumeTrue: boolean): Type { - const left = getReferenceFromExpression(expr.left); + const left = getReferenceCandidate(expr.left); if (!isMatchingReference(reference, left)) { // For a reference of the form 'x.y', an 'x instanceof T' type guard resets the // narrowed type of 'y' to its declared type. @@ -11956,24 +11966,18 @@ namespace ts { } function isExhaustiveSwitchStatement(node: SwitchStatement): boolean { - const expr = node.expression; - if (!node.possiblyExhaustive || expr.kind !== SyntaxKind.PropertyAccessExpression) { + if (!node.possiblyExhaustive) { return false; } - const type = checkExpression((expr).expression); - if (!(type.flags & TypeFlags.Union)) { - return false; - } - const propName = (expr).name.text; - const propType = getTypeOfPropertyOfType(type, propName); - if (!propType || !isLiteralUnionType(propType)) { + const type = checkExpression(node.expression); + if (!isLiteralUnionType(type)) { return false; } const switchTypes = getSwitchClauseTypes(node); if (!switchTypes.length) { return false; } - return eachTypeContainedIn(propType, switchTypes); + return eachTypeContainedIn(type, switchTypes); } function functionHasImplicitReturn(func: FunctionLikeDeclaration) { diff --git a/src/compiler/types.ts b/src/compiler/types.ts index c49b5578c80..486f1e82a97 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -2243,7 +2243,7 @@ namespace ts { /* @internal */ Intrinsic = Any | String | Number | Boolean | BooleanLiteral | ESSymbol | Void | Undefined | Null | Never, /* @internal */ - Primitive = String | Number | Boolean | ESSymbol | Void | Undefined | Null | StringLiteral | Enum, + Primitive = String | Number | Boolean | ESSymbol | Void | Undefined | Null | Literal | Enum, StringLike = String | StringLiteral, NumberLike = Number | NumberLiteral | Enum, BooleanLike = Boolean | BooleanLiteral, @@ -2253,7 +2253,8 @@ namespace ts { // 'Narrowable' types are types where narrowing actually narrows. // This *should* be every type other than null, undefined, void, and never - Narrowable = Any | StructuredType | TypeParameter | StringLike | NumberLike | Boolean | ESSymbol, + Narrowable = Any | StructuredType | TypeParameter | StringLike | NumberLike | BooleanLike | ESSymbol, + NotUnionOrUnit = Any | String | Number | Boolean | ESSymbol | ObjectType, /* @internal */ RequiresWidening = ContainsWideningType | ContainsObjectLiteral, /* @internal */