diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 65c95344c67..ae461119e8b 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -10129,6 +10129,10 @@ namespace ts { getUnionType(types, /*subtypeReduction*/ true); } + function getCommonSubtype(types: Type[]) { + return reduceLeft(types, (s, t) => isTypeSubtypeOf(t, s) ? t : s); + } + function isArrayType(type: Type): boolean { return getObjectFlags(type) & ObjectFlags.Reference && (type).target === globalArrayType; } @@ -10655,8 +10659,14 @@ namespace ts { const sourceTypes = (source).typeArguments || emptyArray; const targetTypes = (target).typeArguments || emptyArray; const count = sourceTypes.length < targetTypes.length ? sourceTypes.length : targetTypes.length; + const variances = strictFunctionTypes ? getVariances((source).target) : undefined; for (let i = 0; i < count; i++) { - inferFromTypes(sourceTypes[i], targetTypes[i]); + if (variances && i < variances.length && variances[i] === Variance.Contravariant) { + inferFromContravariantTypes(sourceTypes[i], targetTypes[i]); + } + else { + inferFromTypes(sourceTypes[i], targetTypes[i]); + } } } else if (source.flags & TypeFlags.Index && target.flags & TypeFlags.Index) { @@ -10727,6 +10737,17 @@ namespace ts { } } + function inferFromContravariantTypes(source: Type, target: Type) { + if (strictFunctionTypes) { + priority ^= InferencePriority.Contravariant; + inferFromTypes(source, target); + priority ^= InferencePriority.Contravariant; + } + else { + inferFromTypes(source, target); + } + } + function getInferenceInfoForType(type: Type) { if (type.flags & TypeFlags.TypeVariable) { for (const inference of inferences) { @@ -10804,7 +10825,7 @@ namespace ts { } function inferFromSignature(source: Signature, target: Signature) { - forEachMatchingParameterType(source, target, inferFromTypes); + forEachMatchingParameterType(source, target, inferFromContravariantTypes); if (source.typePredicate && target.typePredicate && source.typePredicate.kind === target.typePredicate.kind) { inferFromTypes(source.typePredicate.type, target.typePredicate.type); @@ -10879,8 +10900,9 @@ namespace ts { const baseCandidates = widenLiteralTypes ? sameMap(inference.candidates, getWidenedLiteralType) : inference.candidates; // Infer widened union or supertype, or the unknown type for no common supertype. We infer union types // for inferences coming from return types in order to avoid common supertype failures. - const unionOrSuperType = context.flags & InferenceFlags.InferUnionTypes || inference.priority & InferencePriority.ReturnType ? - getUnionType(baseCandidates, /*subtypeReduction*/ true) : getCommonSupertype(baseCandidates); + const unionOrSuperType = inference.priority & InferencePriority.Contravariant ? getCommonSubtype(baseCandidates) : + context.flags & InferenceFlags.InferUnionTypes || inference.priority & InferencePriority.ReturnType ? getUnionType(baseCandidates, /*subtypeReduction*/ true) : + getCommonSupertype(baseCandidates); inferredType = getWidenedType(unionOrSuperType); } else if (context.flags & InferenceFlags.NoDefault) { diff --git a/src/compiler/types.ts b/src/compiler/types.ts index a837722cc89..67974774ab5 100644 --- a/src/compiler/types.ts +++ b/src/compiler/types.ts @@ -3532,9 +3532,10 @@ namespace ts { } export const enum InferencePriority { - NakedTypeVariable = 1 << 0, // Naked type variable in union or intersection type - MappedType = 1 << 1, // Reverse inference for mapped type - ReturnType = 1 << 2, // Inference made from return type of generic function + Contravariant = 1 << 0, // Contravariant inference + NakedTypeVariable = 1 << 1, // Naked type variable in union or intersection type + MappedType = 1 << 2, // Reverse inference for mapped type + ReturnType = 1 << 3, // Inference made from return type of generic function } export interface InferenceInfo {