diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index cdcf1448d2d..0a8c244cd53 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -8533,6 +8533,15 @@ namespace ts { return binarySearch(types, type, getTypeId, compareValues) >= 0; } + function insertType(types: Type[], type: Type): boolean { + const index = binarySearch(types, type, getTypeId, compareValues); + if (index < 0) { + types.splice(~index, 0, type); + return true; + } + return false; + } + // Return true if the given intersection type contains // more than one unit type or, // an object type and a nullable type (null or undefined), or @@ -8700,7 +8709,7 @@ namespace ts { includes & TypeFlags.Undefined ? includes & TypeFlags.NonWideningType ? undefinedType : undefinedWideningType : neverType; } - return getUnionTypeFromSortedList(typeSet, includes & TypeFlags.NotUnit ? 0 : TypeFlags.UnionOfUnitTypes, aliasSymbol, aliasTypeArguments); + return getUnionTypeFromSortedList(typeSet, includes & TypeFlags.NotPrimitiveUnion ? 0 : TypeFlags.UnionOfPrimitiveTypes, aliasSymbol, aliasTypeArguments); } function getUnionTypePredicate(signatures: ReadonlyArray): TypePredicate | undefined { @@ -8823,26 +8832,62 @@ namespace ts { } } - // When intersecting unions of unit types we can simply intersect based on type identity. - // Here we remove all unions of unit types from the given list and replace them with a - // a single union containing an intersection of the unit types. - function intersectUnionsOfUnitTypes(types: Type[]) { - const unionIndex = findIndex(types, t => (t.flags & TypeFlags.UnionOfUnitTypes) !== 0); - const unionType = types[unionIndex]; - let intersection = unionType.types; - let i = types.length - 1; - while (i > unionIndex) { + // Check that the given type has a match in every union. A given type is matched by + // an identical type, and a literal type is additionally matched by its corresponding + // primitive type. + function eachUnionContains(unionTypes: UnionType[], type: Type) { + for (const u of unionTypes) { + if (!containsType(u.types, type)) { + const primitive = type.flags & TypeFlags.StringLiteral ? stringType : + type.flags & TypeFlags.NumberLiteral ? numberType : + type.flags & TypeFlags.UniqueESSymbol ? esSymbolType : + undefined; + if (!primitive || !containsType(u.types, primitive)) { + return false; + } + } + } + return true; + } + + // Remove all unions of primitive types from the given list and replace them with a + // single union containing an intersection of those primitive types. + function intersectUnionsOfPrimitiveTypes(types: Type[]) { + let unionTypes: UnionType[] | undefined; + const index = findIndex(types, t => (t.flags & TypeFlags.UnionOfPrimitiveTypes) !== 0); + let i = index + 1; + // Remove all but the first union of primitive types and collect them in + // the unionTypes array. + while (i < types.length) { const t = types[i]; - if (t.flags & TypeFlags.UnionOfUnitTypes) { - intersection = filter(intersection, u => containsType((t).types, u)); + if (t.flags & TypeFlags.UnionOfPrimitiveTypes) { + (unionTypes || (unionTypes = [types[index]])).push(t); orderedRemoveItemAt(types, i); } - i--; + else { + i++; + } } - if (intersection === unionType.types) { + // Return false if there was only one union of primitive types + if (!unionTypes) { return false; } - types[unionIndex] = getUnionTypeFromSortedList(intersection, unionType.flags & TypeFlags.UnionOfUnitTypes); + // We have more than one union of primitive types, now intersect them. For each + // type in each union we check if the type is matched in every union and if so + // we include it in the result. + const checked: Type[] = []; + const result: Type[] = []; + for (const u of unionTypes) { + for (const t of u.types) { + if (insertType(checked, t)) { + if (eachUnionContains(unionTypes, t)) { + insertType(result, t); + } + } + } + } + // Finally replace the first union with the result + types[index] = getUnionTypeFromSortedList(result, TypeFlags.UnionOfPrimitiveTypes); return true; } @@ -8883,7 +8928,7 @@ namespace ts { return typeSet[0]; } if (includes & TypeFlags.Union) { - if (includes & TypeFlags.UnionOfUnitTypes && intersectUnionsOfUnitTypes(typeSet)) { + if (includes & TypeFlags.UnionOfPrimitiveTypes && intersectUnionsOfPrimitiveTypes(typeSet)) { // When the intersection creates a reduced set (which might mean that *all* union types have // disappeared), we restart the operation to get a new set of combined flags. Once we have // reduced we'll never reduce again, so this occurs at most once. @@ -13980,7 +14025,7 @@ namespace ts { if (type.flags & TypeFlags.Union) { const types = (type).types; const filtered = filter(types, f); - return filtered === types ? type : getUnionTypeFromSortedList(filtered, type.flags & TypeFlags.UnionOfUnitTypes); + return filtered === types ? type : getUnionTypeFromSortedList(filtered, type.flags & TypeFlags.UnionOfPrimitiveTypes); } return f(type) ? type : neverType; } diff --git a/src/compiler/types.ts b/src/compiler/types.ts index 54f85349bf1..e993dee90a3 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -3675,7 +3675,7 @@ namespace ts { /* @internal */ FreshLiteral = 1 << 25, // Fresh literal or unique type /* @internal */ - UnionOfUnitTypes = 1 << 26, // Type is union of unit types + UnionOfPrimitiveTypes = 1 << 26, // Type is union of primitive types /* @internal */ ContainsWideningType = 1 << 27, // Type is or contains undefined or null widening type /* @internal */ @@ -3720,7 +3720,7 @@ namespace ts { Narrowable = Any | Unknown | StructuredOrInstantiable | StringLike | NumberLike | BooleanLike | ESSymbol | UniqueESSymbol | NonPrimitive, NotUnionOrUnit = Any | Unknown | ESSymbol | Object | NonPrimitive, /* @internal */ - NotUnit = Any | String | Number | Boolean | Enum | ESSymbol | Void | Never | StructuredOrInstantiable, + NotPrimitiveUnion = Any | Unknown | Enum | Void | Never | StructuredOrInstantiable, /* @internal */ RequiresWidening = ContainsWideningType | ContainsObjectLiteral, /* @internal */