diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 6ece23f022a..5f0f8241a5e 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -18124,6 +18124,16 @@ namespace ts { return false; } + function optionalChainContainsReference(source: Node, target: Node) { + while (isOptionalChain(source)) { + source = source.expression; + if (isMatchingReference(source, target)) { + return true; + } + } + return false; + } + // Return true if target is a property access xxx.yyy, source is a property access xxx.zzz, the declared // type of xxx is a union type, and yyy is a property that is possibly a discriminant. We consider a property // a possible discriminant if its type differs in the constituents of containing union type, and if every @@ -19350,6 +19360,14 @@ namespace ts { if (isMatchingReference(reference, right)) { return narrowTypeByEquality(type, operator, left, assumeTrue); } + if (assumeTrue && (operator === SyntaxKind.EqualsEqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken)) { + if (optionalChainContainsReference(left, reference)) { + type = narrowTypeByOptionalChainContainment(type, operator, right); + } + else if (optionalChainContainsReference(right, reference)) { + type = narrowTypeByOptionalChainContainment(type, operator, left); + } + } if (isMatchingReferenceDiscriminant(left, declaredType)) { return narrowTypeByDiscriminant(type, left, t => narrowTypeByEquality(t, operator, right, assumeTrue)); } @@ -19374,6 +19392,16 @@ namespace ts { return type; } + function narrowTypeByOptionalChainContainment(type: Type, operator: SyntaxKind, value: Expression): Type { + // We are in the true branch of obj?.foo === value or obj?.foo !== value. We remove undefined and null from + // the type of obj if (a) the operator is === and the type of value doesn't include undefined or (b) the + // operator is !== and the type of value is undefined. + const valueType = getTypeOfExpression(value); + return operator === SyntaxKind.EqualsEqualsEqualsToken && !(getTypeFacts(valueType) & TypeFacts.EQUndefined) || + operator === SyntaxKind.ExclamationEqualsEqualsToken && valueType.flags & TypeFlags.Undefined ? + getTypeWithFacts(type, TypeFacts.NEUndefinedOrNull) : type; + } + function narrowTypeByEquality(type: Type, operator: SyntaxKind, value: Expression, assumeTrue: boolean): Type { if (type.flags & TypeFlags.Any) { return type;