diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 6d5a8b74cb3..f7d01d5480d 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -8425,7 +8425,7 @@ namespace ts { // results for union and intersection types for performance reasons. function couldContainTypeParameters(type: Type): boolean { const objectFlags = getObjectFlags(type); - return !!(type.flags & TypeFlags.TypeParameter || + return !!(type.flags & (TypeFlags.TypeParameter | TypeFlags.IndexedAccess) || objectFlags & ObjectFlags.Reference && forEach((type).typeArguments, couldContainTypeParameters) || objectFlags & ObjectFlags.Anonymous && type.symbol && type.symbol.flags & (SymbolFlags.Method | SymbolFlags.TypeLiteral | SymbolFlags.Class) || objectFlags & ObjectFlags.Mapped || @@ -8443,8 +8443,57 @@ namespace ts { return type === typeParameter || type.flags & TypeFlags.UnionOrIntersection && forEach((type).types, t => isTypeParameterAtTopLevel(t, typeParameter)); } - function inferTypes(context: InferenceContext, originalSource: Type, originalTarget: Type) { - const typeParameters = context.signature.typeParameters; + // Infer a suitable input type for an isomorphic mapped type { [P in keyof T]: X }. We construct + // an object type with the same set of properties as the source type, where the type of each + // property is computed by inferring from the source property type to X for a synthetic type + // parameter T[P] (i.e. we treat the type T[P] as the type parameter we're inferring for). + function inferTypeForIsomorphicMappedType(source: Type, target: MappedType): Type { + if (!isMappableType(source)) { + return source; + } + const typeParameter = getIndexedAccessType((getConstraintTypeFromMappedType(target)).type, getTypeParameterFromMappedType(target)); + const typeParameterArray = [typeParameter]; + const typeInferences = createTypeInferencesObject(); + const typeInferencesArray = [typeInferences]; + const templateType = getTemplateTypeFromMappedType(target); + const properties = getPropertiesOfType(source); + const members = createSymbolTable(properties); + let hasInferredTypes = false; + for (const prop of properties) { + const inferredPropType = inferTargetType(getTypeOfSymbol(prop)); + if (inferredPropType) { + const inferredProp = createSymbol(SymbolFlags.Property | SymbolFlags.Transient | prop.flags & SymbolFlags.Optional, prop.name); + inferredProp.declarations = prop.declarations; + inferredProp.type = inferredPropType; + inferredProp.isReadonly = isReadonlySymbol(prop); + members[prop.name] = inferredProp; + hasInferredTypes = true; + } + } + let indexInfo = getIndexInfoOfType(source, IndexKind.String); + if (indexInfo) { + const inferredIndexType = inferTargetType(indexInfo.type); + if (inferredIndexType) { + indexInfo = createIndexInfo(inferredIndexType, indexInfo.isReadonly); + hasInferredTypes = true; + } + } + return hasInferredTypes ? createAnonymousType(undefined, members, emptyArray, emptyArray, indexInfo, undefined) : source; + + function inferTargetType(sourceType: Type): Type { + typeInferences.primary = undefined; + typeInferences.secondary = undefined; + inferTypes(typeParameterArray, typeInferencesArray, sourceType, templateType); + const inferences = typeInferences.primary || typeInferences.secondary; + return inferences && getUnionType(inferences, /*subtypeReduction*/ true); + } + } + + function inferTypesWithContext(context: InferenceContext, originalSource: Type, originalTarget: Type) { + inferTypes(context.signature.typeParameters, context.inferences, originalSource, originalTarget); + } + + function inferTypes(typeParameters: Type[], typeInferences: TypeInferences[], originalSource: Type, originalTarget: Type) { let sourceStack: Type[]; let targetStack: Type[]; let depth = 0; @@ -8512,7 +8561,7 @@ namespace ts { target = removeTypesFromUnionOrIntersection(target, matchingTypes); } } - if (target.flags & TypeFlags.TypeParameter) { + if (target.flags & (TypeFlags.TypeParameter | TypeFlags.IndexedAccess)) { // If target is a type parameter, make an inference, unless the source type contains // the anyFunctionType (the wildcard type that's used to avoid contextually typing functions). // Because the anyFunctionType is internal, it should not be exposed to the user by adding @@ -8524,7 +8573,7 @@ namespace ts { } for (let i = 0; i < typeParameters.length; i++) { if (target === typeParameters[i]) { - const inferences = context.inferences[i]; + const inferences = typeInferences[i]; if (!inferences.isFixed) { // Any inferences that are made to a type parameter in a union type are inferior // to inferences made to a flat (non-union) type. This is because if we infer to @@ -8538,7 +8587,7 @@ namespace ts { if (!contains(candidates, source)) { candidates.push(source); } - if (!isTypeParameterAtTopLevel(originalTarget, target)) { + if (target.flags & TypeFlags.TypeParameter && !isTypeParameterAtTopLevel(originalTarget, target)) { inferences.topLevel = false; } } @@ -8589,15 +8638,29 @@ namespace ts { if (getObjectFlags(target) & ObjectFlags.Mapped) { const constraintType = getConstraintTypeFromMappedType(target); if (getObjectFlags(source) & ObjectFlags.Mapped) { + // We're inferring from a mapped type to a mapped type, so simply infer from constraint type to + // constraint type and from template type to template type. inferFromTypes(getConstraintTypeFromMappedType(source), constraintType); inferFromTypes(getTemplateTypeFromMappedType(source), getTemplateTypeFromMappedType(target)); return; } if (constraintType.flags & TypeFlags.TypeParameter) { + // We're inferring from some source type S to a mapped type { [P in T]: X }, where T is a type + // parameter. Infer from 'keyof S' to T and infer from a union of each property type in S to X. inferFromTypes(getIndexType(source), constraintType); inferFromTypes(getUnionType(map(getPropertiesOfType(source), getTypeOfSymbol)), getTemplateTypeFromMappedType(target)); return; } + if (constraintType.flags & TypeFlags.Index) { + // We're inferring from some source type S to an isomorphic mapped type { [P in keyof T]: X }, + // where T is a type parameter. Use inferTypeForIsomorphicMappedType to infer a suitable source + // type and then infer from that type to T. + const index = indexOf(typeParameters, (constraintType).type); + if (index >= 0 && !typeInferences[index].isFixed) { + inferFromTypes(inferTypeForIsomorphicMappedType(source, target), typeParameters[index]); + } + return; + } } source = getApparentType(source); if (source.flags & TypeFlags.Object) { @@ -12458,7 +12521,7 @@ namespace ts { const context = createInferenceContext(signature, /*inferUnionTypes*/ true); forEachMatchingParameterType(contextualSignature, signature, (source, target) => { // Type parameters from outer context referenced by source type are fixed by instantiation of the source type - inferTypes(context, instantiateType(source, contextualMapper), target); + inferTypesWithContext(context, instantiateType(source, contextualMapper), target); }); return getSignatureInstantiation(signature, getInferredTypes(context)); } @@ -12493,7 +12556,7 @@ namespace ts { if (thisType) { const thisArgumentNode = getThisArgumentOfCall(node); const thisArgumentType = thisArgumentNode ? checkExpression(thisArgumentNode) : voidType; - inferTypes(context, thisArgumentType, thisType); + inferTypesWithContext(context, thisArgumentType, thisType); } // We perform two passes over the arguments. In the first pass we infer from all arguments, but use @@ -12515,7 +12578,7 @@ namespace ts { argType = checkExpressionWithContextualType(arg, paramType, mapper); } - inferTypes(context, argType, paramType); + inferTypesWithContext(context, argType, paramType); } } @@ -12530,7 +12593,7 @@ namespace ts { if (excludeArgument[i] === false) { const arg = args[i]; const paramType = getTypeAtPosition(signature, i); - inferTypes(context, checkExpressionWithContextualType(arg, paramType, inferenceMapper), paramType); + inferTypesWithContext(context, checkExpressionWithContextualType(arg, paramType, inferenceMapper), paramType); } } } @@ -13617,7 +13680,7 @@ namespace ts { for (let i = 0; i < len; i++) { const declaration = signature.parameters[i].valueDeclaration; if (declaration.type) { - inferTypes(mapper.context, getTypeFromTypeNode(declaration.type), getTypeAtPosition(context, i)); + inferTypesWithContext(mapper.context, getTypeFromTypeNode(declaration.type), getTypeAtPosition(context, i)); } } } @@ -13703,7 +13766,7 @@ namespace ts { // T in the second overload so that we do not infer Base as a candidate for T // (inferring Base would make type argument inference inconsistent between the two // overloads). - inferTypes(mapper.context, links.type, instantiateType(contextualType, mapper)); + inferTypesWithContext(mapper.context, links.type, instantiateType(contextualType, mapper)); } }