From ff3b627ca612bcdc9bcf1ed0ff5b50275f4c2020 Mon Sep 17 00:00:00 2001 From: Anders Hejlsberg Date: Tue, 6 Sep 2016 17:25:02 -0700 Subject: [PATCH] Less widening of literal types in type inference --- src/compiler/checker.ts | 101 ++++++++++++++++++++++++++++------------ src/compiler/types.ts | 2 + 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 937dc5e76fc..3ba32ebfd0a 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -7131,15 +7131,36 @@ namespace ts { return true; } + function literalTypesWithSameBaseType(types: Type[]): boolean { + let commonBaseType: Type; + for (const t of types) { + const baseType = getBaseTypeOfLiteralType(t); + if (!commonBaseType) { + commonBaseType = baseType; + } + if (baseType === t || baseType !== commonBaseType) { + return false; + } + } + return true; + } + + // When the candidate types are all literal types with the same base type, the common + // supertype is a union of those literal types. Otherwise, the common supertype is the + // first type that is a supertype of each of the other types. + function getSupertypeOrUnion(types: Type[]): Type { + return literalTypesWithSameBaseType(types) ? getUnionType(types) : forEach(types, t => isSupertypeOfEach(t, types) ? t : undefined); + } + function getCommonSupertype(types: Type[]): Type { if (!strictNullChecks) { - return forEach(types, t => isSupertypeOfEach(t, types) ? t : undefined); + return getSupertypeOrUnion(types); } const primaryTypes = filter(types, t => !(t.flags & TypeFlags.Nullable)); if (!primaryTypes.length) { return getUnionType(types, /*subtypeReduction*/ true); } - const supertype = forEach(primaryTypes, t => isSupertypeOfEach(t, primaryTypes) ? t : undefined); + const supertype = getSupertypeOrUnion(primaryTypes); return supertype && includeFalsyTypes(supertype, getFalsyFlagsOfTypes(types) & TypeFlags.Nullable); } @@ -7468,11 +7489,13 @@ namespace ts { } } - function createInferenceContext(typeParameters: TypeParameter[], inferUnionTypes: boolean): InferenceContext { - const inferences = map(typeParameters, createTypeInferencesObject); - + function createInferenceContext(signature: Signature, inferUnionTypes: boolean): InferenceContext { + const typeParameters = signature.typeParameters; + const returnType = getReturnTypeOfSignature(signature); + const inferences = map(signature.typeParameters, createTypeInferencesObject); return { typeParameters, + returnType, inferUnionTypes, inferences, inferredTypes: new Array(typeParameters.length), @@ -7483,6 +7506,7 @@ namespace ts { return { primary: undefined, secondary: undefined, + shallow: true, isFixed: false, }; } @@ -7504,21 +7528,13 @@ namespace ts { return type.couldContainTypeParameters; } - function hasPrimitiveConstraint(type: TypeParameter): boolean { - const constraint = getConstraintOfTypeParameter(type); - return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive); - } - function inferTypes(context: InferenceContext, source: Type, target: Type) { let sourceStack: Type[]; let targetStack: Type[]; let depth = 0; let inferiority = 0; const visited = createMap(); - // We widen a literal source type only if we're inferring directly to a type parameter - // that has no primitive or literal constraint. - const shouldWiden = isLiteralType(source) && target.flags & TypeFlags.TypeParameter && !hasPrimitiveConstraint(target); - inferFromTypes(shouldWiden ? getBaseTypeOfLiteralType(source) : source, target); + inferFromTypes(source, target, /*nested*/ false); function isInProcess(source: Type, target: Type) { for (let i = 0; i < depth; i++) { @@ -7529,7 +7545,7 @@ namespace ts { return false; } - function inferFromTypes(source: Type, target: Type) { + function inferFromTypes(source: Type, target: Type, nested: boolean) { if (!couldContainTypeParameters(target)) { return; } @@ -7539,7 +7555,7 @@ namespace ts { // are the same type, just relate each constituent type to itself. if (source === target) { for (const t of (source).types) { - inferFromTypes(t, t); + inferFromTypes(t, t, /*nested*/ false); } return; } @@ -7551,7 +7567,7 @@ namespace ts { for (const t of (target).types) { if (typeIdenticalToSomeType(t, (source).types)) { (matchingTypes || (matchingTypes = [])).push(t); - inferFromTypes(t, t); + inferFromTypes(t, t, /*nested*/ false); } } // Next, to improve the quality of inferences, reduce the source and target types by @@ -7589,6 +7605,9 @@ namespace ts { if (!contains(candidates, source)) { candidates.push(source); } + if (nested) { + inferences.shallow = false; + } } return; } @@ -7600,7 +7619,7 @@ namespace ts { const targetTypes = (target).typeArguments || emptyArray; const count = sourceTypes.length < targetTypes.length ? sourceTypes.length : targetTypes.length; for (let i = 0; i < count; i++) { - inferFromTypes(sourceTypes[i], targetTypes[i]); + inferFromTypes(sourceTypes[i], targetTypes[i], /*nested*/ true); } } else if (target.flags & TypeFlags.UnionOrIntersection) { @@ -7614,7 +7633,7 @@ namespace ts { typeParameterCount++; } else { - inferFromTypes(source, t); + inferFromTypes(source, t, /*nested*/ false); } } // Next, if target containings a single naked type parameter, make a secondary inference to that type @@ -7622,7 +7641,7 @@ namespace ts { // types in contra-variant positions (such as callback parameters). if (typeParameterCount === 1) { inferiority++; - inferFromTypes(source, typeParameter); + inferFromTypes(source, typeParameter, /*nested*/ false); inferiority--; } } @@ -7630,7 +7649,7 @@ namespace ts { // Source is a union or intersection type, infer from each constituent type const sourceTypes = (source).types; for (const sourceType of sourceTypes) { - inferFromTypes(sourceType, target); + inferFromTypes(sourceType, target, /*nested*/ false); } } else { @@ -7668,7 +7687,7 @@ namespace ts { for (const targetProp of properties) { const sourceProp = getPropertyOfObjectType(source, targetProp.name); if (sourceProp) { - inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp)); + inferFromTypes(getTypeOfSymbol(sourceProp), getTypeOfSymbol(targetProp), /*nested*/ true); } } } @@ -7684,14 +7703,18 @@ namespace ts { } } + function inferFromParameterTypes(source: Type, target: Type) { + return inferFromTypes(source, target, /*nested*/ true); + } + function inferFromSignature(source: Signature, target: Signature) { - forEachMatchingParameterType(source, target, inferFromTypes); + forEachMatchingParameterType(source, target, inferFromParameterTypes); if (source.typePredicate && target.typePredicate && source.typePredicate.kind === target.typePredicate.kind) { - inferFromTypes(source.typePredicate.type, target.typePredicate.type); + inferFromTypes(source.typePredicate.type, target.typePredicate.type, /*nested*/ true); } else { - inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target)); + inferFromTypes(getReturnTypeOfSignature(source), getReturnTypeOfSignature(target), /*nested*/ true); } } @@ -7701,7 +7724,7 @@ namespace ts { const sourceIndexType = getIndexTypeOfType(source, IndexKind.String) || getImplicitIndexTypeOfType(source, IndexKind.String); if (sourceIndexType) { - inferFromTypes(sourceIndexType, targetStringIndexType); + inferFromTypes(sourceIndexType, targetStringIndexType, /*nested*/ true); } } const targetNumberIndexType = getIndexTypeOfType(target, IndexKind.Number); @@ -7710,7 +7733,7 @@ namespace ts { getIndexTypeOfType(source, IndexKind.String) || getImplicitIndexTypeOfType(source, IndexKind.Number); if (sourceIndexType) { - inferFromTypes(sourceIndexType, targetNumberIndexType); + inferFromTypes(sourceIndexType, targetNumberIndexType, /*nested*/ true); } } } @@ -7744,14 +7767,32 @@ namespace ts { return inferences.primary || inferences.secondary || emptyArray; } + function hasPrimitiveConstraint(type: TypeParameter): boolean { + const constraint = getConstraintOfTypeParameter(type); + return constraint && maybeTypeOfKind(constraint, TypeFlags.Primitive); + } + + function hasTypeParameterAtTopLevel(type: Type, typeParameter: TypeParameter): boolean { + return type === typeParameter || type.flags & TypeFlags.UnionOrIntersection && forEach((type).types, t => hasTypeParameterAtTopLevel(t, typeParameter)); + } + + function getInferredType(context: InferenceContext, index: number): Type { let inferredType = context.inferredTypes[index]; let inferenceSucceeded: boolean; if (!inferredType) { const inferences = getInferenceCandidates(context, index); if (inferences.length) { + // We keep inferences of literal types if + // we made at least one inference that wasn't shallow, or + // the type parameter has a primitive type constraint, or + // the type parameter wasn't fixed and is referenced at top level in the return type. + const keepLiteralTypes = !context.inferences[index].shallow || + hasPrimitiveConstraint(context.typeParameters[index]) || + !context.inferences[index].isFixed && hasTypeParameterAtTopLevel(context.returnType, context.typeParameters[index]); + const baseInferences = keepLiteralTypes ? inferences : map(inferences, getBaseTypeOfLiteralType); // Infer widened union or supertype, or the unknown type for no common supertype - const unionOrSuperType = context.inferUnionTypes ? getUnionType(inferences, /*subtypeReduction*/ true) : getCommonSupertype(inferences); + const unionOrSuperType = context.inferUnionTypes ? getUnionType(baseInferences, /*subtypeReduction*/ true) : getCommonSupertype(baseInferences); inferredType = unionOrSuperType ? getWidenedType(unionOrSuperType) : unknownType; inferenceSucceeded = !!unionOrSuperType; } @@ -11203,7 +11244,7 @@ namespace ts { // Instantiate a generic signature in the context of a non-generic signature (section 3.8.5 in TypeScript spec) function instantiateSignatureInContextOf(signature: Signature, contextualSignature: Signature, contextualMapper: TypeMapper): Signature { - const context = createInferenceContext(signature.typeParameters, /*inferUnionTypes*/ true); + const context = createInferenceContext(signature, /*inferUnionTypes*/ true); forEachMatchingParameterType(contextualSignature, signature, (source, target) => { // Type parameters from outer context referenced by source type are fixed by instantiation of the source type inferTypes(context, instantiateType(source, contextualMapper), target); @@ -11859,7 +11900,7 @@ namespace ts { let candidate: Signature; let typeArgumentsAreValid: boolean; const inferenceContext = originalCandidate.typeParameters - ? createInferenceContext(originalCandidate.typeParameters, /*inferUnionTypes*/ false) + ? createInferenceContext(originalCandidate, /*inferUnionTypes*/ false) : undefined; while (true) { diff --git a/src/compiler/types.ts b/src/compiler/types.ts index 50beef52bf8..1cbd85d9830 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -2503,6 +2503,7 @@ namespace ts { export interface TypeInferences { primary: Type[]; // Inferences made directly to a type parameter secondary: Type[]; // Inferences made to a type parameter in a union type + shallow: boolean; // True if all inferences were made from shallow (not nested in object type) locations isFixed: boolean; // Whether the type parameter is fixed, as defined in section 4.12.2 of the TypeScript spec // If a type parameter is fixed, no more inferences can be made for the type parameter } @@ -2510,6 +2511,7 @@ namespace ts { /* @internal */ export interface InferenceContext { typeParameters: TypeParameter[]; // Type parameters for which inferences are made + returnType: Type; // Return type used when determining whether to widen literal types inferUnionTypes: boolean; // Infer union types for disjoint candidates (otherwise undefinedType) inferences: TypeInferences[]; // Inferences made for each type parameter inferredTypes: Type[]; // Inferred type for each type parameter