diff --git a/src/compiler/checker.ts b/src/compiler/checker.ts index 9d8ea29a381..93d9c19a71b 100644 --- a/src/compiler/checker.ts +++ b/src/compiler/checker.ts @@ -5715,19 +5715,37 @@ namespace ts { let type: ObjectType; if (isSpread) { let members: Map; + let stringIndexInfo: IndexInfo; + let numberIndexInfo: IndexInfo; const spreads: Type[] = []; for (const member of (node as TypeLiteralNode).members) { if (member.kind === SyntaxKind.SpreadTypeElement) { if (members) { - spreads.push(createAnonymousType(node.symbol, members, emptyArray, emptyArray, undefined, undefined)); + spreads.push(createAnonymousType(node.symbol, members, emptyArray, emptyArray, stringIndexInfo, numberIndexInfo)); members = undefined; + stringIndexInfo = undefined; + numberIndexInfo = undefined; } spreads.push(getTypeFromTypeNode((member as SpreadTypeElement).type)); } - else if (member.kind !== SyntaxKind.CallSignature && - member.kind !== SyntaxKind.ConstructSignature && - member.kind !== SyntaxKind.IndexSignature) { - // note that spread types don't include call and construct signatures, and index signatures are resolved later + else if (member.kind === SyntaxKind.IndexSignature) { + const index = member as IndexSignatureDeclaration; + if (index.parameters.length === 1) { + const parameter = index.parameters[0]; + if (parameter && parameter.type) { + const indexInfo = createIndexInfo(index.type ? getTypeFromTypeNode(index.type) : anyType, + (getModifierFlags(index) & ModifierFlags.Readonly) !== 0, index); + if (parameter.type.kind === SyntaxKind.StringKeyword) { + stringIndexInfo = indexInfo; + } + else { + numberIndexInfo = indexInfo; + } + } + } + } + else if (member.kind !== SyntaxKind.CallSignature && member.kind !== SyntaxKind.ConstructSignature) { + // note that spread types don't include call and construct signatures const flags = SymbolFlags.Property | SymbolFlags.Transient | (member.questionToken ? SymbolFlags.Optional : 0); const text = getTextOfPropertyName(member.name); const symbol = createSymbol(flags, text); @@ -5740,8 +5758,8 @@ namespace ts { members[symbol.name] = symbol; } } - if (members) { - spreads.push(createAnonymousType(node.symbol, members, emptyArray, emptyArray, undefined, undefined)); + if (members || stringIndexInfo || numberIndexInfo) { + spreads.push(createAnonymousType(node.symbol, members || emptySymbols, emptyArray, emptyArray, stringIndexInfo, numberIndexInfo)); } return getSpreadType(spreads, node.symbol, aliasSymbol, aliasTypeArguments); } @@ -5802,6 +5820,15 @@ namespace ts { // for types like T ... T, just return ... T return left; } + + if (right.flags & TypeFlags.ObjectType && + left.flags & TypeFlags.Spread && + (left as SpreadType).right.flags & TypeFlags.ObjectType) { + // simplify two adjacent object types: T ... { x } ... { y } becomes T ... { x, y } + // Note: left.left is always a spread type. Can we use this fact to avoid calling getSpreadType again? + return getSpreadType([getSpreadType([right, (left as SpreadType).right], symbol, aliasSymbol, aliasTypeArguments), + (left as SpreadType).left], symbol, aliasSymbol, aliasTypeArguments); + } if (left.flags & TypeFlags.Intersection) { const spreads = map((left as IntersectionType).types, t => getSpreadType(types.slice().concat([t, right]), symbol, aliasSymbol, aliasTypeArguments)); @@ -7991,7 +8018,8 @@ namespace ts { return !!(type.flags & TypeFlags.TypeParameter || type.flags & TypeFlags.Reference && forEach((type).typeArguments, couldContainTypeParameters) || type.flags & TypeFlags.Anonymous && type.symbol && type.symbol.flags & (SymbolFlags.Method | SymbolFlags.TypeLiteral | SymbolFlags.Class) || - type.flags & TypeFlags.UnionOrIntersection && couldUnionOrIntersectionContainTypeParameters(type)); + type.flags & TypeFlags.UnionOrIntersection && couldUnionOrIntersectionContainTypeParameters(type) || + type.flags & TypeFlags.Spread && couldSpreadContainTypeParameters(type as SpreadType)); } function couldUnionOrIntersectionContainTypeParameters(type: UnionOrIntersectionType): boolean { @@ -8001,6 +8029,11 @@ namespace ts { return type.couldContainTypeParameters; } + function couldSpreadContainTypeParameters(type: SpreadType): boolean { + return !!(type.right.flags & TypeFlags.TypeParameter || + type.left.flags & TypeFlags.Spread && (type.left as SpreadType).right.flags & TypeFlags.TypeParameter); + } + function isTypeParameterAtTopLevel(type: Type, typeParameter: TypeParameter): boolean { return type === typeParameter || type.flags & TypeFlags.UnionOrIntersection && forEach((type).types, t => isTypeParameterAtTopLevel(t, typeParameter)); } @@ -8064,6 +8097,16 @@ namespace ts { target = removeTypesFromUnionOrIntersection(target, matchingTypes); } } + if (source.flags & TypeFlags.Spread && target.flags & TypeFlags.Spread) { + // only the last type parameter is a valid inference site, + // and only if not followed by object literal properties. + if((source as SpreadType).right.flags & TypeFlags.TypeParameter && + (target as SpreadType).right.flags & TypeFlags.TypeParameter) { + inferFromTypes((source as SpreadType).right, (target as SpreadType).right); + } + + return; + } if (target.flags & TypeFlags.TypeParameter) { // If target is a type parameter, make an inference, unless the source type contains // the anyFunctionType (the wildcard type that's used to avoid contextually typing functions). @@ -8140,33 +8183,59 @@ namespace ts { else { source = getApparentType(source); if (source.flags & TypeFlags.ObjectType) { - if (isInProcess(source, target)) { - return; + if (target.flags & TypeFlags.Spread) { + // with an object type as source, a spread target infers to its last type parameter it + // contains, after removing any properties from a object type that precedes the type parameter + // Note that the call to `typeDifference` creates a new anonymous type. + const spread = target as SpreadType; + const parameter = spread.right.flags & TypeFlags.TypeParameter ? spread.right : (spread.left as SpreadType).right; + const object = spread.right.flags & TypeFlags.TypeParameter ? emptyObjectType : spread.right as ResolvedType; + inferFromTypes(getTypeDifference(source, object), parameter); + target = object; } - if (isDeeplyNestedGeneric(source, sourceStack, depth) && isDeeplyNestedGeneric(target, targetStack, depth)) { - return; - } - const key = source.id + "," + target.id; - if (visited[key]) { - return; - } - visited[key] = true; - if (depth === 0) { - sourceStack = []; - targetStack = []; - } - sourceStack[depth] = source; - targetStack[depth] = target; - depth++; - inferFromProperties(source, target); - inferFromSignatures(source, target, SignatureKind.Call); - inferFromSignatures(source, target, SignatureKind.Construct); - inferFromIndexTypes(source, target); - depth--; + inferFromStructure(source, target); } } } + function inferFromStructure(source: Type, target: Type) { + if (isInProcess(source, target)) { + return; + } + if (isDeeplyNestedGeneric(source, sourceStack, depth) && isDeeplyNestedGeneric(target, targetStack, depth)) { + return; + } + const key = source.id + "," + target.id; + if (visited[key]) { + return; + } + visited[key] = true; + if (depth === 0) { + sourceStack = []; + targetStack = []; + } + sourceStack[depth] = source; + targetStack[depth] = target; + depth++; + inferFromProperties(source, target); + inferFromSignatures(source, target, SignatureKind.Call); + inferFromSignatures(source, target, SignatureKind.Construct); + inferFromIndexTypes(source, target); + depth--; + } + + function getTypeDifference(type: ObjectType, diff: ResolvedType): ResolvedType { + const members = createMap(); + for (const prop of getPropertiesOfObjectType(type)) { + if (!(prop.name in diff.members)) { + members[prop.name] = prop; + } + } + const stringIndexInfo = getIndexInfoOfType(diff, IndexKind.String) ? undefined : getIndexInfoOfType(type, IndexKind.String); + const numberIndexInfo = getIndexInfoOfType(diff, IndexKind.Number) ? undefined : getIndexInfoOfType(type, IndexKind.Number); + return createAnonymousType(type.symbol, members, emptyArray, emptyArray, stringIndexInfo, numberIndexInfo); + } + function inferFromProperties(source: Type, target: Type) { const properties = getPropertiesOfObjectType(target); for (const targetProp of properties) {