mirror of
https://github.com/microsoft/TypeScript.git
synced 2026-06-11 20:37:46 -05:00
Merge pull request #32460 from microsoft/fix32434
Improve type inference for types like 'T | Promise<T>'
This commit is contained in:
@@ -13540,6 +13540,9 @@ namespace ts {
|
||||
if (relation !== identityRelation) {
|
||||
source = getApparentType(source);
|
||||
}
|
||||
else if (isGenericMappedType(source)) {
|
||||
return Ternary.False;
|
||||
}
|
||||
if (getObjectFlags(source) & ObjectFlags.Reference && getObjectFlags(target) & ObjectFlags.Reference && (<TypeReference>source).target === (<TypeReference>target).target &&
|
||||
!(getObjectFlags(source) & ObjectFlags.MarkerType || getObjectFlags(target) & ObjectFlags.MarkerType)) {
|
||||
// We have type references to the same generic type, and the type references are not marker
|
||||
@@ -15456,9 +15459,11 @@ namespace ts {
|
||||
|
||||
function inferTypes(inferences: InferenceInfo[], originalSource: Type, originalTarget: Type, priority: InferencePriority = 0, contravariant = false) {
|
||||
let symbolStack: Symbol[];
|
||||
let visited: Map<boolean>;
|
||||
let visited: Map<number>;
|
||||
let bivariant = false;
|
||||
let propagationType: Type;
|
||||
let inferenceCount = 0;
|
||||
let inferenceIncomplete = false;
|
||||
let allowComplexConstraintInference = true;
|
||||
inferFromTypes(originalSource, originalTarget);
|
||||
|
||||
@@ -15500,23 +15505,28 @@ namespace ts {
|
||||
// of all their possible values.
|
||||
let matchingTypes: Type[] | undefined;
|
||||
for (const t of (<UnionOrIntersectionType>source).types) {
|
||||
if (typeIdenticalToSomeType(t, (<UnionOrIntersectionType>target).types)) {
|
||||
(matchingTypes || (matchingTypes = [])).push(t);
|
||||
inferFromTypes(t, t);
|
||||
}
|
||||
else if (t.flags & (TypeFlags.NumberLiteral | TypeFlags.StringLiteral)) {
|
||||
const b = getBaseTypeOfLiteralType(t);
|
||||
if (typeIdenticalToSomeType(b, (<UnionOrIntersectionType>target).types)) {
|
||||
(matchingTypes || (matchingTypes = [])).push(t, b);
|
||||
}
|
||||
const matched = findMatchedType(t, <UnionOrIntersectionType>target);
|
||||
if (matched) {
|
||||
(matchingTypes || (matchingTypes = [])).push(matched);
|
||||
inferFromTypes(matched, matched);
|
||||
}
|
||||
}
|
||||
// Next, to improve the quality of inferences, reduce the source and target types by
|
||||
// removing the identically matched constituents. For example, when inferring from
|
||||
// 'string | string[]' to 'string | T' we reduce the types to 'string[]' and 'T'.
|
||||
if (matchingTypes) {
|
||||
source = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>source, matchingTypes);
|
||||
target = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>target, matchingTypes);
|
||||
const s = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>source, matchingTypes);
|
||||
const t = removeTypesFromUnionOrIntersection(<UnionOrIntersectionType>target, matchingTypes);
|
||||
if (!(s && t)) return;
|
||||
source = s;
|
||||
target = t;
|
||||
}
|
||||
}
|
||||
else if (target.flags & TypeFlags.Union && !(target.flags & TypeFlags.EnumLiteral) || target.flags & TypeFlags.Intersection) {
|
||||
const matched = findMatchedType(source, <UnionOrIntersectionType>target);
|
||||
if (matched) {
|
||||
inferFromTypes(matched, matched);
|
||||
return;
|
||||
}
|
||||
}
|
||||
else if (target.flags & (TypeFlags.IndexedAccess | TypeFlags.Substitution)) {
|
||||
@@ -15562,13 +15572,14 @@ namespace ts {
|
||||
clearCachedInferences(inferences);
|
||||
}
|
||||
}
|
||||
inferenceCount++;
|
||||
return;
|
||||
}
|
||||
else {
|
||||
// Infer to the simplified version of an indexed access, if possible, to (hopefully) expose more bare type parameters to the inference engine
|
||||
const simplified = getSimplifiedType(target, /*writing*/ false);
|
||||
if (simplified !== target) {
|
||||
inferFromTypesOnce(source, simplified);
|
||||
invokeOnce(source, simplified, inferFromTypes);
|
||||
}
|
||||
else if (target.flags & TypeFlags.IndexedAccess) {
|
||||
const indexType = getSimplifiedType((target as IndexedAccessType).indexType, /*writing*/ false);
|
||||
@@ -15577,13 +15588,14 @@ namespace ts {
|
||||
if (indexType.flags & TypeFlags.Instantiable) {
|
||||
const simplified = distributeIndexOverObjectType(getSimplifiedType((target as IndexedAccessType).objectType, /*writing*/ false), indexType, /*writing*/ false);
|
||||
if (simplified && simplified !== target) {
|
||||
inferFromTypesOnce(source, simplified);
|
||||
invokeOnce(source, simplified, inferFromTypes);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (getObjectFlags(source) & ObjectFlags.Reference && getObjectFlags(target) & ObjectFlags.Reference && (<TypeReference>source).target === (<TypeReference>target).target) {
|
||||
if (getObjectFlags(source) & ObjectFlags.Reference && getObjectFlags(target) & ObjectFlags.Reference && (
|
||||
(<TypeReference>source).target === (<TypeReference>target).target || isArrayType(source) && isArrayType(target))) {
|
||||
// If source and target are references to the same generic type, infer from type arguments
|
||||
inferFromTypeArguments((<TypeReference>source).typeArguments || emptyArray, (<TypeReference>target).typeArguments || emptyArray, getVariances((<TypeReference>source).target));
|
||||
}
|
||||
@@ -15613,10 +15625,10 @@ namespace ts {
|
||||
}
|
||||
else if (target.flags & TypeFlags.Conditional && !contravariant) {
|
||||
const targetTypes = [getTrueTypeFromConditionalType(<ConditionalType>target), getFalseTypeFromConditionalType(<ConditionalType>target)];
|
||||
inferToMultipleTypes(source, targetTypes, /*isIntersection*/ false);
|
||||
inferToMultipleTypes(source, targetTypes, target.flags);
|
||||
}
|
||||
else if (target.flags & TypeFlags.UnionOrIntersection) {
|
||||
inferToMultipleTypes(source, (<UnionOrIntersectionType>target).types, !!(target.flags & TypeFlags.Intersection));
|
||||
inferToMultipleTypes(source, (<UnionOrIntersectionType>target).types, target.flags);
|
||||
}
|
||||
else if (source.flags & TypeFlags.Union) {
|
||||
// Source is a union or intersection type, infer from each constituent type
|
||||
@@ -15645,39 +15657,22 @@ namespace ts {
|
||||
source = apparentSource;
|
||||
}
|
||||
if (source.flags & (TypeFlags.Object | TypeFlags.Intersection)) {
|
||||
const key = source.id + "," + target.id;
|
||||
if (visited && visited.get(key)) {
|
||||
return;
|
||||
}
|
||||
(visited || (visited = createMap<boolean>())).set(key, true);
|
||||
// If we are already processing another target type with the same associated symbol (such as
|
||||
// an instantiation of the same generic type), we do not explore this target as it would yield
|
||||
// no further inferences. We exclude the static side of classes from this check since it shares
|
||||
// its symbol with the instance side which would lead to false positives.
|
||||
const isNonConstructorObject = target.flags & TypeFlags.Object &&
|
||||
!(getObjectFlags(target) & ObjectFlags.Anonymous && target.symbol && target.symbol.flags & SymbolFlags.Class);
|
||||
const symbol = isNonConstructorObject ? target.symbol : undefined;
|
||||
if (symbol) {
|
||||
if (contains(symbolStack, symbol)) {
|
||||
return;
|
||||
}
|
||||
(symbolStack || (symbolStack = [])).push(symbol);
|
||||
inferFromObjectTypes(source, target);
|
||||
symbolStack.pop();
|
||||
}
|
||||
else {
|
||||
inferFromObjectTypes(source, target);
|
||||
}
|
||||
invokeOnce(source, target, inferFromObjectTypes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function inferFromTypesOnce(source: Type, target: Type) {
|
||||
const key = source.id + "," + target.id;
|
||||
if (!visited || !visited.get(key)) {
|
||||
(visited || (visited = createMap<boolean>())).set(key, true);
|
||||
inferFromTypes(source, target);
|
||||
}
|
||||
function invokeOnce(source: Type, target: Type, action: (source: Type, target: Type) => void) {
|
||||
const key = source.id + "," + target.id;
|
||||
const count = visited && visited.get(key);
|
||||
if (count !== undefined) {
|
||||
inferenceCount += count;
|
||||
return;
|
||||
}
|
||||
(visited || (visited = createMap<number>())).set(key, 0);
|
||||
const startCount = inferenceCount;
|
||||
action(source, target);
|
||||
visited.set(key, inferenceCount - startCount);
|
||||
}
|
||||
|
||||
function inferFromTypeArguments(sourceTypes: readonly Type[], targetTypes: readonly Type[], variances: readonly VarianceFlags[]) {
|
||||
@@ -15714,24 +15709,60 @@ namespace ts {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function inferToMultipleTypes(source: Type, targets: Type[], isIntersection: boolean) {
|
||||
// We infer from types that are not naked type variables first so that inferences we
|
||||
// make from nested naked type variables and given slightly higher priority by virtue
|
||||
// of being first in the candidates array.
|
||||
function inferToMultipleTypes(source: Type, targets: Type[], targetFlags: TypeFlags) {
|
||||
let typeVariableCount = 0;
|
||||
for (const t of targets) {
|
||||
if (getInferenceInfoForType(t)) {
|
||||
typeVariableCount++;
|
||||
if (targetFlags & TypeFlags.Union) {
|
||||
let nakedTypeVariable: Type | undefined;
|
||||
const sources = source.flags & TypeFlags.Union ? (<UnionType>source).types : [source];
|
||||
const matched = new Array<boolean>(sources.length);
|
||||
const saveInferenceIncomplete = inferenceIncomplete;
|
||||
inferenceIncomplete = false;
|
||||
// First infer to types that are not naked type variables. For each source type we
|
||||
// track whether inferences were made from that particular type to some target.
|
||||
for (const t of targets) {
|
||||
if (getInferenceInfoForType(t)) {
|
||||
nakedTypeVariable = t;
|
||||
typeVariableCount++;
|
||||
}
|
||||
else {
|
||||
for (let i = 0; i < sources.length; i++) {
|
||||
const count = inferenceCount;
|
||||
inferFromTypes(sources[i], t);
|
||||
if (count !== inferenceCount) matched[i] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
inferFromTypes(source, t);
|
||||
const inferenceComplete = !inferenceIncomplete;
|
||||
inferenceIncomplete = inferenceIncomplete || saveInferenceIncomplete;
|
||||
// If the target has a single naked type variable and inference completed (meaning we
|
||||
// explored the types fully), create a union of the source types from which no inferences
|
||||
// have been made so far and infer from that union to the naked type variable.
|
||||
if (typeVariableCount === 1 && inferenceComplete) {
|
||||
const unmatched = flatMap(sources, (s, i) => matched[i] ? undefined : s);
|
||||
if (unmatched.length) {
|
||||
inferFromTypes(getUnionType(unmatched), nakedTypeVariable!);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// We infer from types that are not naked type variables first so that inferences we
|
||||
// make from nested naked type variables and given slightly higher priority by virtue
|
||||
// of being first in the candidates array.
|
||||
for (const t of targets) {
|
||||
if (getInferenceInfoForType(t)) {
|
||||
typeVariableCount++;
|
||||
}
|
||||
else {
|
||||
inferFromTypes(source, t);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Inferences directly to naked type variables are given lower priority as they are
|
||||
// less specific. For example, when inferring from Promise<string> to T | Promise<T>,
|
||||
// we want to infer string for T, not Promise<string> | string. For intersection types
|
||||
// we only infer to single naked type variables.
|
||||
if (isIntersection ? typeVariableCount === 1 : typeVariableCount !== 0) {
|
||||
if (targetFlags & TypeFlags.Intersection ? typeVariableCount === 1 : typeVariableCount > 0) {
|
||||
const savePriority = priority;
|
||||
priority |= InferencePriority.NakedTypeVariable;
|
||||
for (const t of targets) {
|
||||
@@ -15800,6 +15831,28 @@ namespace ts {
|
||||
}
|
||||
|
||||
function inferFromObjectTypes(source: Type, target: Type) {
|
||||
// If we are already processing another target type with the same associated symbol (such as
|
||||
// an instantiation of the same generic type), we do not explore this target as it would yield
|
||||
// no further inferences. We exclude the static side of classes from this check since it shares
|
||||
// its symbol with the instance side which would lead to false positives.
|
||||
const isNonConstructorObject = target.flags & TypeFlags.Object &&
|
||||
!(getObjectFlags(target) & ObjectFlags.Anonymous && target.symbol && target.symbol.flags & SymbolFlags.Class);
|
||||
const symbol = isNonConstructorObject ? target.symbol : undefined;
|
||||
if (symbol) {
|
||||
if (contains(symbolStack, symbol)) {
|
||||
inferenceIncomplete = true;
|
||||
return;
|
||||
}
|
||||
(symbolStack || (symbolStack = [])).push(symbol);
|
||||
inferFromObjectTypesWorker(source, target);
|
||||
symbolStack.pop();
|
||||
}
|
||||
else {
|
||||
inferFromObjectTypesWorker(source, target);
|
||||
}
|
||||
}
|
||||
|
||||
function inferFromObjectTypesWorker(source: Type, target: Type) {
|
||||
if (isGenericMappedType(source) && isGenericMappedType(target)) {
|
||||
// The source and target types are generic types { [P in S]: X } and { [P in T]: Y }, so we infer
|
||||
// from S to T and from X to Y.
|
||||
@@ -15902,15 +15955,35 @@ namespace ts {
|
||||
}
|
||||
}
|
||||
|
||||
function typeIdenticalToSomeType(type: Type, types: Type[]): boolean {
|
||||
function isMatchableType(type: Type) {
|
||||
// We exclude non-anonymous object types because some frameworks (e.g. Ember) rely on the ability to
|
||||
// infer between types that don't witness their type variables. Such types would otherwise be eliminated
|
||||
// because they appear identical.
|
||||
return !(type.flags & TypeFlags.Object) || !!(getObjectFlags(type) & ObjectFlags.Anonymous);
|
||||
}
|
||||
|
||||
function typeMatchedBySomeType(type: Type, types: Type[]): boolean {
|
||||
for (const t of types) {
|
||||
if (isTypeIdenticalTo(t, type)) {
|
||||
if (t === type || isMatchableType(t) && isMatchableType(type) && isTypeIdenticalTo(t, type)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function findMatchedType(type: Type, target: UnionOrIntersectionType) {
|
||||
if (typeMatchedBySomeType(type, target.types)) {
|
||||
return type;
|
||||
}
|
||||
if (type.flags & (TypeFlags.NumberLiteral | TypeFlags.StringLiteral) && target.flags & TypeFlags.Union) {
|
||||
const base = getBaseTypeOfLiteralType(type);
|
||||
if (typeMatchedBySomeType(base, target.types)) {
|
||||
return base;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new union or intersection type computed by removing a given set of types
|
||||
* from a given union or intersection type.
|
||||
@@ -15918,11 +15991,11 @@ namespace ts {
|
||||
function removeTypesFromUnionOrIntersection(type: UnionOrIntersectionType, typesToRemove: Type[]) {
|
||||
const reducedTypes: Type[] = [];
|
||||
for (const t of type.types) {
|
||||
if (!typeIdenticalToSomeType(t, typesToRemove)) {
|
||||
if (!typeMatchedBySomeType(t, typesToRemove)) {
|
||||
reducedTypes.push(t);
|
||||
}
|
||||
}
|
||||
return type.flags & TypeFlags.Union ? getUnionType(reducedTypes) : getIntersectionType(reducedTypes);
|
||||
return reducedTypes.length ? type.flags & TypeFlags.Union ? getUnionType(reducedTypes) : getIntersectionType(reducedTypes) : undefined;
|
||||
}
|
||||
|
||||
function hasPrimitiveConstraint(type: TypeParameter): boolean {
|
||||
|
||||
Reference in New Issue
Block a user