diff --git a/src/services/refactors/convertToNamedParameters.ts b/src/services/refactors/convertToNamedParameters.ts index fa285619dab..f0873a56baf 100644 --- a/src/services/refactors/convertToNamedParameters.ts +++ b/src/services/refactors/convertToNamedParameters.ts @@ -4,14 +4,14 @@ namespace ts.refactor.convertToNamedParameters { const refactorDescription = "Convert to named parameters"; const actionNameNamedParameters = "Convert to named parameters"; const actionDescriptionNamedParameters = "Convert to named parameters"; - const minimumParameterLength = 1; + const minimumParameterLength = 2; registerRefactor(refactorName, { getEditsForAction, getAvailableActions }); function getAvailableActions(context: RefactorContext): ReadonlyArray { const { file, startPosition } = context; const isJSFile = isSourceFileJS(file); - if (isJSFile) return emptyArray; + if (isJSFile) return emptyArray; // TODO: GH#30113 const functionDeclaration = getFunctionDeclarationAtPosition(file, startPosition, context.program.getTypeChecker()); if (!functionDeclaration) return emptyArray; @@ -37,7 +37,7 @@ namespace ts.refactor.convertToNamedParameters { return { renameFilename: undefined, renameLocation: undefined, edits }; } - return { edits: [] }; + return { edits: [] }; // TODO: GH#30113 } function doChange(sourceFile: SourceFile, program: Program, host: LanguageServiceHost, changes: textChanges.ChangeTracker, functionDeclaration: ValidFunctionDeclaration, groupedReferences: GroupedReferences): void { @@ -55,7 +55,7 @@ namespace ts.refactor.convertToNamedParameters { }); - const functionCalls = deduplicate(groupedReferences.functionCalls, (a, b) => a === b); + const functionCalls = deduplicate(groupedReferences.functionCalls, equateValues); forEach(functionCalls, call => { if (call.arguments && call.arguments.length) { const newArgument = getSynthesizedDeepClone(createNewArgument(functionDeclaration, call.arguments), /*includeTrivia*/ true); @@ -69,60 +69,57 @@ namespace ts.refactor.convertToNamedParameters { } function getGroupedReferences(functionDeclaration: ValidFunctionDeclaration, program: Program, cancellationToken: CancellationToken): GroupedReferences { - const names = getDeclarationNames(functionDeclaration); - const references = flatMap(names, name => FindAllReferences.getReferenceEntriesForNode(-1, name, program, program.getSourceFiles(), cancellationToken)); - let groupedReferences = groupReferences(references); + const functionNames = getFunctionNames(functionDeclaration); + const classNames = isConstructorDeclaration(functionDeclaration) ? getClassNames(functionDeclaration) : []; + const names = deduplicate([...functionNames, ...classNames], equateValues); + const checker = program.getTypeChecker(); - // if the refactored function is a constructor, we must also go through the references to its class - if (isConstructorDeclaration(functionDeclaration)) { - const className = getClassName(functionDeclaration); - groupedReferences = groupClassReferences(groupedReferences, className); + const references = flatMap(names, name => FindAllReferences.getReferenceEntriesForNode(-1, name, program, program.getSourceFiles(), cancellationToken)); + const isConstructor = isConstructorDeclaration(functionDeclaration); + const groupedReferences = groupReferences(references, isConstructor); + + if (!every(groupedReferences.declarations, decl => contains(names, decl))) { + groupedReferences.valid = false; } - validateReferences(groupedReferences); return groupedReferences; - function getClassName(constructorDeclaration: ValidConstructor): Identifier { - switch (constructorDeclaration.parent.kind) { - case SyntaxKind.ClassDeclaration: - return constructorDeclaration.parent.name; - case SyntaxKind.ClassExpression: - return constructorDeclaration.parent.parent.name; - } - } - - function groupReferences(referenceEntries: ReadonlyArray | undefined): GroupedReferences { - const groupedReferences: GroupedReferences = { functionCalls: [], declarations: [], unhandled: [], valid: true }; - - forEach(referenceEntries, (entry) => { - const decl = entryToDeclaration(entry); - if (decl) { - groupedReferences.declarations.push(decl); - return; - } - - const call = entryToFunctionCall(entry); - if (call) { - groupedReferences.functionCalls.push(call); - return; - } - - groupedReferences.unhandled.push(entry); - }); - return groupedReferences; - } - - function groupClassReferences(groupedReferences: GroupedReferences, className: Identifier): GroupedReferences { + function groupReferences(referenceEntries: ReadonlyArray, isConstructor: boolean): GroupedReferences { const classReferences: ClassReferences = { accessExpressions: [], typeUsages: [] }; - const unhandledEntries = groupedReferences.unhandled; - const newUnhandledEntries: FindAllReferences.Entry[] = []; + const groupedReferences: GroupedReferences = { functionCalls: [], declarations: [], classReferences, valid: true }; + const functionSymbols = map(functionNames, checker.getSymbolAtLocation); + const classSymbols = map(classNames, checker.getSymbolAtLocation); + + for (const entry of referenceEntries) { + if (entry.kind !== FindAllReferences.EntryKind.Node) { + groupedReferences.valid = false; + continue; + } + if (contains(functionSymbols, checker.getSymbolAtLocation(entry.node), symbolComparer)) { + const decl = entryToDeclaration(entry); + if (decl) { + groupedReferences.declarations.push(decl); + continue; + } + + const call = entryToFunctionCall(entry); + if (call) { + groupedReferences.functionCalls.push(call); + continue; + } + } + // if the refactored function is a constructor, we must also check if the references to its class are valid + if (isConstructor && contains(classSymbols, checker.getSymbolAtLocation(entry.node), symbolComparer)) { + const decl = entryToDeclaration(entry); + if (decl) { + groupedReferences.declarations.push(decl); + continue; + } - forEach(unhandledEntries, (entry) => { - if (entry.kind === FindAllReferences.EntryKind.Node && entry.node.symbol === className.symbol) { const accessExpression = entryToAccessExpression(entry); if (accessExpression) { classReferences.accessExpressions.push(accessExpression); - return; + continue; } // Only class declarations are allowed to be used as a type (in a heritage clause), @@ -131,27 +128,29 @@ namespace ts.refactor.convertToNamedParameters { const type = entryToType(entry); if (type) { classReferences.typeUsages.push(type); - return; + continue; } } } - newUnhandledEntries.push(entry); - }); - - return { ...groupedReferences, classReferences, unhandled: newUnhandledEntries }; - } - - function validateReferences(groupedReferences: GroupedReferences): void { - if (groupedReferences.unhandled.length > 0) { - groupedReferences.valid = false; - } - if (!every(groupedReferences.declarations, decl => contains(names, decl))) { groupedReferences.valid = false; } + + return groupedReferences; } - function entryToFunctionCall(entry: FindAllReferences.Entry): CallExpression | NewExpression | undefined { - if (entry.kind === FindAllReferences.EntryKind.Node && entry.node.parent) { + function symbolComparer(a: Symbol, b: Symbol): boolean { + return getSymbolTarget(a) === getSymbolTarget(b); + } + + function entryToDeclaration(entry: FindAllReferences.NodeEntry): Node | undefined { + if (isDeclaration(entry.node.parent)) { + return entry.node; + } + return undefined; + } + + function entryToFunctionCall(entry: FindAllReferences.NodeEntry): CallExpression | NewExpression | undefined { + if (entry.node.parent) { const functionReference = entry.node; const parent = functionReference.parent; switch (parent.kind) { @@ -194,15 +193,8 @@ namespace ts.refactor.convertToNamedParameters { return undefined; } - function entryToDeclaration(entry: FindAllReferences.Entry): Node | undefined { - if (entry.kind === FindAllReferences.EntryKind.Node && contains(names, entry.node)) { - return entry.node; - } - return undefined; - } - - function entryToAccessExpression(entry: FindAllReferences.Entry): ElementAccessExpression | PropertyAccessExpression | undefined { - if (entry.kind === FindAllReferences.EntryKind.Node && entry.node.parent) { + function entryToAccessExpression(entry: FindAllReferences.NodeEntry): ElementAccessExpression | PropertyAccessExpression | undefined { + if (entry.node.parent) { const reference = entry.node; const parent = reference.parent; switch (parent.kind) { @@ -263,7 +255,7 @@ namespace ts.refactor.convertToNamedParameters { return false; function isValidParameterNodeArray(parameters: NodeArray): parameters is ValidParameterNodeArray { - return getRefactorableParametersLength(parameters) > minimumParameterLength && every(parameters, isValidParameterDeclaration); + return getRefactorableParametersLength(parameters) >= minimumParameterLength && every(parameters, isValidParameterDeclaration); } function isValidParameterDeclaration(paramDeclaration: ParameterDeclaration): paramDeclaration is ValidParameterDeclaration { @@ -271,7 +263,7 @@ namespace ts.refactor.convertToNamedParameters { } function isValidVariableDeclaration(node: Node): node is ValidVariableDeclaration { - return isVariableDeclaration(node) && isVarConst(node) && isIdentifier(node.name) && !node.type; + return isVariableDeclaration(node) && isVarConst(node) && isIdentifier(node.name) && !node.type; // TODO: GH#30113 } } @@ -430,25 +422,32 @@ namespace ts.refactor.convertToNamedParameters { return getTextOfIdentifierOrLiteral(paramDeclaration.name); } - function getDeclarationNames(functionDeclaration: ValidFunctionDeclaration): Node[] { + function getClassNames(constructorDeclaration: ValidConstructor): Identifier[] { + switch (constructorDeclaration.parent.kind) { + case SyntaxKind.ClassDeclaration: + const classDeclaration = constructorDeclaration.parent; + return [classDeclaration.name]; + case SyntaxKind.ClassExpression: + const classExpression = constructorDeclaration.parent; + const variableDeclaration = constructorDeclaration.parent.parent; + const className = classExpression.name; + if (className) return [className, variableDeclaration.name]; + return [variableDeclaration.name]; + } + } + + function getFunctionNames(functionDeclaration: ValidFunctionDeclaration): Node[] { switch (functionDeclaration.kind) { case SyntaxKind.FunctionDeclaration: case SyntaxKind.MethodDeclaration: return [functionDeclaration.name]; case SyntaxKind.Constructor: const ctrKeyword = findChildOfKind(functionDeclaration, SyntaxKind.ConstructorKeyword, functionDeclaration.getSourceFile())!; - switch (functionDeclaration.parent.kind) { - case SyntaxKind.ClassDeclaration: - const classDeclaration = functionDeclaration.parent; - return [classDeclaration.name, ctrKeyword]; - case SyntaxKind.ClassExpression: - const classExpression = functionDeclaration.parent; - const variableDeclaration = functionDeclaration.parent.parent; - const className = classExpression.name; - if (className) return [className, ctrKeyword, variableDeclaration.name]; - return [ctrKeyword, variableDeclaration.name]; - default: return Debug.assertNever(functionDeclaration.parent); + if (functionDeclaration.parent.kind === SyntaxKind.ClassExpression) { + const variableDeclaration = functionDeclaration.parent.parent; + return [variableDeclaration.name, ctrKeyword]; } + return [ctrKeyword]; case SyntaxKind.ArrowFunction: return [functionDeclaration.parent.name]; case SyntaxKind.FunctionExpression: @@ -500,7 +499,6 @@ namespace ts.refactor.convertToNamedParameters { functionCalls: (CallExpression | NewExpression)[]; declarations: Node[]; classReferences?: ClassReferences; - unhandled: FindAllReferences.Entry[]; valid: boolean; } interface ClassReferences { diff --git a/src/services/utilities.ts b/src/services/utilities.ts index afa14f62a1d..004135a526e 100644 --- a/src/services/utilities.ts +++ b/src/services/utilities.ts @@ -1664,6 +1664,18 @@ namespace ts { return ensureScriptKind(fileName, host && host.getScriptKind && host.getScriptKind(fileName)); } + export function getSymbolTarget(symbol: Symbol): Symbol { + let next: Symbol = symbol; + while (isTransientSymbol(next) && next.target) { + next = next.target; + } + return next; + } + + function isTransientSymbol(symbol: Symbol): symbol is TransientSymbol { + return (symbol.flags & SymbolFlags.Transient) !== 0; + } + export function getUniqueSymbolId(symbol: Symbol, checker: TypeChecker) { return getSymbolId(skipAlias(symbol, checker)); }