From 7dd64d3ea26c9642ffa8e1f82bd0568b664403ae Mon Sep 17 00:00:00 2001 From: Anders Hejlsberg Date: Thu, 13 Oct 2016 06:29:34 -0700 Subject: [PATCH] Properly narrow union types containing string and number --- src/compiler/checker.ts | 32 ++++++++++++++++++++++++-------- src/compiler/types.ts | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 6555710137a..dffe4244e39 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -8461,6 +8461,28 @@ namespace ts { return f(type) ? type : neverType; } + function mapType(type: Type, f: (t: Type) => Type): Type { + return type.flags & TypeFlags.Union ? getUnionType(map((type).types, f)) : f(type); + } + + function extractTypesOfKind(type: Type, kind: TypeFlags) { + return filterType(type, t => (t.flags & kind) !== 0); + } + + // Return a new type in which occurrences of the string and number primitive types in + // typeWithPrimitives have been replaced with occurrences of string literals and numeric + // literals in typeWithLiterals, respectively. + function replacePrimitivesWithLiterals(typeWithPrimitives: Type, typeWithLiterals: Type) { + if (isTypeSubsetOf(stringType, typeWithPrimitives) && maybeTypeOfKind(typeWithLiterals, TypeFlags.StringLiteral) || + isTypeSubsetOf(numberType, typeWithPrimitives) && maybeTypeOfKind(typeWithLiterals, TypeFlags.NumberLiteral)) { + return mapType(typeWithPrimitives, t => + t.flags & TypeFlags.String ? extractTypesOfKind(typeWithLiterals, TypeFlags.String | TypeFlags.StringLiteral) : + t.flags & TypeFlags.Number ? extractTypesOfKind(typeWithLiterals, TypeFlags.Number | TypeFlags.NumberLiteral) : + t); + } + return typeWithPrimitives; + } + function isIncomplete(flowType: FlowType) { return flowType.flags === 0; } @@ -8791,16 +8813,12 @@ namespace ts { assumeTrue ? TypeFacts.EQUndefined : TypeFacts.NEUndefined; return getTypeWithFacts(type, facts); } - if (type.flags & TypeFlags.String && isTypeOfKind(valueType, TypeFlags.StringLiteral) || - type.flags & TypeFlags.Number && isTypeOfKind(valueType, TypeFlags.NumberLiteral)) { - return assumeTrue? valueType : type; - } if (type.flags & TypeFlags.NotUnionOrUnit) { return type; } if (assumeTrue) { const narrowedType = filterType(type, t => areTypesComparable(t, valueType)); - return narrowedType.flags & TypeFlags.Never ? type : narrowedType; + return narrowedType.flags & TypeFlags.Never ? type : replacePrimitivesWithLiterals(narrowedType, valueType); } if (isUnitType(valueType)) { const regularType = getRegularTypeOfLiteralType(valueType); @@ -8849,9 +8867,7 @@ namespace ts { const discriminantType = getUnionType(clauseTypes); const caseType = discriminantType.flags & TypeFlags.Never ? neverType : - type.flags & TypeFlags.String && isTypeOfKind(discriminantType, TypeFlags.StringLiteral) ? discriminantType : - type.flags & TypeFlags.Number && isTypeOfKind(discriminantType, TypeFlags.NumberLiteral) ? discriminantType : - filterType(type, t => isTypeComparableTo(discriminantType, t)); + replacePrimitivesWithLiterals(filterType(type, t => isTypeComparableTo(discriminantType, t)), discriminantType); if (!hasDefaultClause) { return caseType; } diff --git a/src/compiler/types.ts b/src/compiler/types.ts index 638504e613f..ba11b35459b 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -2644,7 +2644,7 @@ namespace ts { // 'Narrowable' types are types where narrowing actually narrows. // This *should* be every type other than null, undefined, void, and never Narrowable = Any | StructuredType | TypeParameter | StringLike | NumberLike | BooleanLike | ESSymbol, - NotUnionOrUnit = Any | String | Number | ESSymbol | ObjectType, + NotUnionOrUnit = Any | ESSymbol | ObjectType, /* @internal */ RequiresWidening = ContainsWideningType | ContainsObjectLiteral, /* @internal */