diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 33c7962b647..2c670397fe2 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -18787,6 +18787,14 @@ namespace ts { signature.declaration && (getReturnTypeFromAnnotation(signature.declaration) || unknownType).flags & TypeFlags.Never); } + function getTypePredicateArgument(predicate: TypePredicate, callExpression: CallExpression) { + if (predicate.kind === TypePredicateKind.Identifier || predicate.kind === TypePredicateKind.AssertsIdentifier) { + return callExpression.arguments[predicate.parameterIndex]; + } + const invokedExpression = skipParentheses(callExpression.expression); + return isAccessExpression(invokedExpression) ? skipParentheses(invokedExpression.expression) : undefined; + } + function reportFlowControlError(node: Node) { const block = findAncestor(node, isFunctionOrModuleBlock); const sourceFile = getSourceFileOfNode(node); @@ -19338,6 +19346,9 @@ namespace ts { if (isMatchingReference(reference, expr)) { return getTypeWithFacts(type, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy); } + if (strictNullChecks && assumeTrue && optionalChainContainsReference(expr, reference)) { + type = getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull); + } if (isMatchingReferenceDiscriminant(expr, declaredType)) { return narrowTypeByDiscriminant(type, expr, t => getTypeWithFacts(t, assumeTrue ? TypeFacts.Truthy : TypeFacts.Falsy)); } @@ -19422,21 +19433,13 @@ namespace ts { } function narrowTypeByOptionalChainContainment(type: Type, operator: SyntaxKind, value: Expression, assumeTrue: boolean): Type { - const op = assumeTrue ? operator : - operator === SyntaxKind.EqualsEqualsToken ? SyntaxKind.ExclamationEqualsToken : - operator === SyntaxKind.EqualsEqualsEqualsToken ? SyntaxKind.ExclamationEqualsEqualsToken : - operator === SyntaxKind.ExclamationEqualsToken ? SyntaxKind.EqualsEqualsToken : - operator === SyntaxKind.ExclamationEqualsEqualsToken ? SyntaxKind.EqualsEqualsEqualsToken : - operator; // We are in a branch of obj?.foo === value or obj?.foo !== value. We remove undefined and null from // the type of obj if (a) the operator is === and the type of value doesn't include undefined or (b) the // operator is !== and the type of value is undefined. - const valueType = getTypeOfExpression(value); - return op === SyntaxKind.EqualsEqualsToken && !(getTypeFacts(valueType) & TypeFacts.EQUndefinedOrNull) || - op === SyntaxKind.EqualsEqualsEqualsToken && !(getTypeFacts(valueType) & TypeFacts.EQUndefined) || - op === SyntaxKind.ExclamationEqualsToken && valueType.flags & TypeFlags.Nullable || - op === SyntaxKind.ExclamationEqualsEqualsToken && valueType.flags & TypeFlags.Undefined ? - getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull) : type; + const effectiveTrue = operator === SyntaxKind.EqualsEqualsToken || operator === SyntaxKind.EqualsEqualsEqualsToken ? assumeTrue : !assumeTrue; + const doubleEquals = operator === SyntaxKind.EqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsToken; + const valueNonNullish = !(getTypeFacts(getTypeOfExpression(value)) & (doubleEquals ? TypeFacts.EQUndefinedOrNull : TypeFacts.EQUndefined)); + return effectiveTrue === valueNonNullish ? getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull) : type; } function narrowTypeByEquality(type: Type, operator: SyntaxKind, value: Expression, assumeTrue: boolean): Type { @@ -19487,10 +19490,12 @@ namespace ts { function narrowTypeByTypeof(type: Type, typeOfExpr: TypeOfExpression, operator: SyntaxKind, literal: LiteralExpression, assumeTrue: boolean): Type { // We have '==', '!=', '===', or !==' operator with 'typeof xxx' and string literal operands + if (operator === SyntaxKind.ExclamationEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) { + assumeTrue = !assumeTrue; + } const target = getReferenceCandidate(typeOfExpr.expression); if (!isMatchingReference(reference, target)) { - if (assumeTrue && (operator === SyntaxKind.EqualsEqualsToken || operator === SyntaxKind.EqualsEqualsEqualsToken) && - strictNullChecks && optionalChainContainsReference(target, reference)) { + if (strictNullChecks && optionalChainContainsReference(target, reference) && assumeTrue === (literal.text !== "undefined")) { return getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull); } // For a reference of the form 'x.y', a 'typeof x === ...' type guard resets the @@ -19500,9 +19505,6 @@ namespace ts { } return type; } - if (operator === SyntaxKind.ExclamationEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) { - assumeTrue = !assumeTrue; - } if (type.flags & TypeFlags.Any && literal.text === "function") { return type; } @@ -19763,32 +19765,20 @@ namespace ts { function narrowTypeByTypePredicate(type: Type, predicate: TypePredicate, callExpression: CallExpression, assumeTrue: boolean): Type { // Don't narrow from 'any' if the predicate type is exactly 'Object' or 'Function' - if (isTypeAny(type) && (predicate.type === globalObjectType || predicate.type === globalFunctionType)) { - return type; - } - if (predicate.kind === TypePredicateKind.Identifier || predicate.kind === TypePredicateKind.AssertsIdentifier) { - const predicateArgument = callExpression.arguments[predicate.parameterIndex]; - if (predicateArgument && predicate.type) { + if (predicate.type && !(isTypeAny(type) && (predicate.type === globalObjectType || predicate.type === globalFunctionType))) { + const predicateArgument = getTypePredicateArgument(predicate, callExpression); + if (predicateArgument) { if (isMatchingReference(reference, predicateArgument)) { return getNarrowedType(type, predicate.type, assumeTrue, isTypeSubtypeOf); } + if (strictNullChecks && assumeTrue && !(getTypeFacts(predicate.type) & TypeFacts.EQUndefined)) { + type = getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull); + } if (containsMatchingReference(reference, predicateArgument)) { return declaredType; } } } - else { - const invokedExpression = skipParentheses(callExpression.expression); - if (isAccessExpression(invokedExpression) && predicate.type) { - const possibleReference = skipParentheses(invokedExpression.expression); - if (isMatchingReference(reference, possibleReference)) { - return getNarrowedType(type, predicate.type, assumeTrue, isTypeSubtypeOf); - } - if (containsMatchingReference(reference, possibleReference)) { - return declaredType; - } - } - } return type; }