diff --git a/src/services/refactors/convertToNamedParameters.ts b/src/services/refactors/convertToNamedParameters.ts index fa46f56b169..a6d75662e47 100644 --- a/src/services/refactors/convertToNamedParameters.ts +++ b/src/services/refactors/convertToNamedParameters.ts @@ -10,7 +10,7 @@ namespace ts.refactor.convertToNamedParameters { function getAvailableActions(context: RefactorContext): ReadonlyArray { const { file, startPosition } = context; - const func = getFunctionDeclarationAtPosition(file, startPosition); + const func = getFunctionDeclarationAtPosition(file, startPosition, context.program.getTypeChecker()); if (!func) return emptyArray; return [{ @@ -25,32 +25,102 @@ namespace ts.refactor.convertToNamedParameters { function getEditsForAction(context: RefactorContext, actionName: string): RefactorEditInfo | undefined { Debug.assert(actionName === actionNameNamedParameters); - const { file, startPosition, program, cancellationToken } = context; - const func = getFunctionDeclarationAtPosition(file, startPosition); + const { file, startPosition, program, cancellationToken, host } = context; + const func = getFunctionDeclarationAtPosition(file, startPosition, program.getTypeChecker()); if (!func || !cancellationToken) return undefined; - const newParamDeclaration = createObjectParameter(func); - // const funcRefs = FindAllReferences.getReferenceEntriesForNode(-1, func.name, program, program.getSourceFiles(), cancellationToken); - - const edits = textChanges.ChangeTracker.with(context, t => t.replaceNodeRange(file, first(func.parameters), last(func.parameters), newParamDeclaration)); + const edits = textChanges.ChangeTracker.with(context, t => doChange(file, program, cancellationToken, host, t, func)); return { renameFilename: undefined, renameLocation: undefined, edits }; } - function getFunctionDeclarationAtPosition(file: SourceFile, startPosition: number): ValidFunctionDeclaration | undefined { + function doChange(sourceFile: SourceFile, program: Program, cancellationToken: CancellationToken, host: LanguageServiceHost, changes: textChanges.ChangeTracker, functionDeclaration: ValidFunctionDeclaration): void { + const newParamDeclaration = getSynthesizedDeepClone(createObjectParameter(functionDeclaration, program, host)); + changes.replaceNodeRange(sourceFile, first(functionDeclaration.parameters), last(functionDeclaration.parameters), newParamDeclaration); + + const nameNode = getFunctionDeclarationName(functionDeclaration); + const functionRefs = FindAllReferences.getReferenceEntriesForNode(-1, nameNode, program, program.getSourceFiles(), cancellationToken); + const functionCalls = getDirectFunctionCalls(functionRefs); + + forEach(functionCalls, call => { + if (call.arguments && call.arguments.length) { + const newArguments = getSynthesizedDeepClone(createArgumentObject(functionDeclaration, call)); + changes.replaceNodeRange(getSourceFileOfNode(call), first(call.arguments), last(call.arguments), newArguments); + }}); + } + + function createArgumentObject(func: ValidFunctionDeclaration, funcCall: CallExpression | NewExpression): ObjectLiteralExpression { + const properties = map(funcCall.arguments, (arg, i) => createPropertyAssignment(getParameterName(func.parameters[i]), arg)); + return createObjectLiteral(properties, /*multiLine*/ false); + } + + function getDirectFunctionCalls(referenceEntries: ReadonlyArray | undefined): ReadonlyArray { + return mapDefined(referenceEntries, (entry) => { + if (entry.kind !== FindAllReferences.EntryKind.Span && entry.node.parent) { + const functionRef = entry.node; + const parent = functionRef.parent; + switch (parent.kind) { + // Function call (foo(...)) + case SyntaxKind.CallExpression: + const callExpression = tryCast(parent, isCallExpression); + if (callExpression && callExpression.expression === functionRef) { + return callExpression; + } + break; + // Constructor call (new Foo(...)) + case SyntaxKind.NewExpression: + const newExpression = tryCast(parent, isNewExpression); + if (newExpression && newExpression.expression === functionRef) { + return newExpression; + } + break; + // Method call (x.foo(...)) + case SyntaxKind.PropertyAccessExpression: + const propertyAccessExpression = tryCast(parent, isPropertyAccessExpression); + if (propertyAccessExpression && propertyAccessExpression.parent && propertyAccessExpression.name === functionRef) { + const callExpression = tryCast(propertyAccessExpression.parent, isCallExpression); + if (callExpression && callExpression.expression === propertyAccessExpression) { + return callExpression; + } + } + break; + // Method call (x['foo'](...)) + case SyntaxKind.ElementAccessExpression: + const elementAccessExpression = tryCast(parent, isElementAccessExpression); + if (elementAccessExpression && elementAccessExpression.parent && elementAccessExpression.argumentExpression === functionRef) { + const callExpression = tryCast(elementAccessExpression.parent, isCallExpression); + if (callExpression && callExpression.expression === elementAccessExpression) { + return callExpression; + } + } + break; + } + } + return undefined; + }); + } + + function getFunctionDeclarationAtPosition(file: SourceFile, startPosition: number, checker: TypeChecker): ValidFunctionDeclaration | undefined { const node = getTokenAtPosition(file, startPosition); const func = getContainingFunction(node); - // TODO: check range - if (!func || !isValidFunctionDeclaration(func)) return undefined; + if (!func || !isValidFunctionDeclaration(func, checker) || !rangeContainsRange(func, node) || (func.body && rangeContainsRange(func.body, node))) return undefined; return func; } - function isValidFunctionDeclaration(func: SignatureDeclaration): func is ValidFunctionDeclaration { + function isValidFunctionDeclaration(func: SignatureDeclaration, checker: TypeChecker): func is ValidFunctionDeclaration { switch (func.kind) { case SyntaxKind.FunctionDeclaration: case SyntaxKind.MethodDeclaration: + return !!func.name && isPropertyName(func.name) && isValidParameterNodeArray(func.parameters) && !!func.body && !checker.isImplementationOfOverload(func); case SyntaxKind.Constructor: - return !!(func.name && isValidParameterNodeArray(func.parameters)); - default: + if (isClassDeclaration(func.parent)) { + return !!func.parent.name && isValidParameterNodeArray(func.parameters) && !!func.body && !checker.isImplementationOfOverload(func); + } + else { + return isVariableDeclaration(func.parent.parent) && isVarConst(func.parent.parent) && isValidParameterNodeArray(func.parameters) && !!func.body && !checker.isImplementationOfOverload(func); + } + case SyntaxKind.FunctionExpression: + case SyntaxKind.ArrowFunction: + return isVariableDeclaration(func.parent) && isVarConst(func.parent) && isValidParameterNodeArray(func.parameters); } return false; } @@ -60,29 +130,25 @@ namespace ts.refactor.convertToNamedParameters { } function isValidParameterDeclaration(paramDecl: ParameterDeclaration): paramDecl is ValidParameterDeclaration { - return !paramDecl.modifiers && !paramDecl.dotDotDotToken && isIdentifier(paramDecl.name) && !paramDecl.initializer; + return !paramDecl.modifiers && !paramDecl.dotDotDotToken && isIdentifier(paramDecl.name); } - function createParamTypeNode(func: ValidFunctionDeclaration): TypeLiteralNode { - const members = map(func.parameters, createPropertySignatureFromParameterDeclaration); - const typeNode = addEmitFlags(createTypeLiteralNode(members), EmitFlags.SingleLine); - // TODO: add emit flags on create function in factory - return typeNode; - } - - function createPropertySignatureFromParameterDeclaration(paramDeclaration: ValidParameterDeclaration): PropertySignature { - return createPropertySignature( - /*modifiers*/ undefined, - paramDeclaration.name, - paramDeclaration.questionToken, - paramDeclaration.type, - paramDeclaration.initializer); - } - - function createObjectParameter(func: ValidFunctionDeclaration): ParameterDeclaration { - const bindingElements = map(func.parameters, param => createBindingElement(/*dotDotDotToken*/ undefined, /*propertyName*/ undefined, getTextOfIdentifierOrLiteral(param.name))); + function createObjectParameter(functionDeclaration: ValidFunctionDeclaration, program: Program, host: LanguageServiceHost): ParameterDeclaration { + const bindingElements = map( + functionDeclaration.parameters, + paramDecl => { + return createBindingElement( + /*dotDotDotToken*/ undefined, + /*propertyName*/ undefined, + getParameterName(paramDecl), + paramDecl.initializer); }); const paramName = createObjectBindingPattern(bindingElements); - const paramType = createParamTypeNode(func); + const paramType = createParamTypeNode(functionDeclaration); + + let objectInitializer: Expression | undefined; + if (every(functionDeclaration.parameters, param => !!param.initializer || !!param.questionToken)) { + objectInitializer = createObjectLiteral(); + } return createParameter( /*decorators*/ undefined, @@ -90,21 +156,86 @@ namespace ts.refactor.convertToNamedParameters { /*dotDotDotToken*/ undefined, paramName, /*questionToken*/ undefined, - paramType); + paramType, + objectInitializer); + + function createParamTypeNode(func: ValidFunctionDeclaration): TypeLiteralNode { + const members = map(func.parameters, createPropertySignatureFromParameterDeclaration); + const typeNode = addEmitFlags(createTypeLiteralNode(members), EmitFlags.SingleLine); + return typeNode; + } + + function createPropertySignatureFromParameterDeclaration(paramDeclaration: ValidParameterDeclaration): PropertySignature { + let paramType = paramDeclaration.type; + if (paramDeclaration.initializer && !paramType) { + const checker = program.getTypeChecker(); + const type = checker.getBaseTypeOfLiteralType(checker.getTypeAtLocation(paramDeclaration.initializer)); + paramType = getTypeNodeIfAccessible(type, paramDeclaration, program, host); + } + return createPropertySignature( + /*modifiers*/ undefined, + paramDeclaration.name, + paramDeclaration.initializer ? createToken(SyntaxKind.QuestionToken) : paramDeclaration.questionToken, + paramType, + /*initializer*/ undefined); + } } - interface ValidFunctionDeclaration extends MethodDeclaration { - name: PropertyName; - body?: FunctionBody; - typeParameters?: NodeArray; + function getParameterName(paramDecl: ValidParameterDeclaration): string { + return getTextOfIdentifierOrLiteral(paramDecl.name); + } + + function getFunctionDeclarationName(functionDeclaration: ValidFunctionDeclaration): Node { + switch (functionDeclaration.kind) { + case SyntaxKind.FunctionDeclaration: + case SyntaxKind.MethodDeclaration: + return functionDeclaration.name; + case SyntaxKind.Constructor: + switch (functionDeclaration.parent.kind) { + case SyntaxKind.ClassDeclaration: + return functionDeclaration.parent.name; + case SyntaxKind.ClassExpression: + return functionDeclaration.parent.parent.name; + default: return Debug.assertNever(functionDeclaration.parent); + } + case SyntaxKind.ArrowFunction: + case SyntaxKind.FunctionExpression: + return functionDeclaration.parent.name; + } + } + + interface ValidConstructor extends ConstructorDeclaration { + parent: (ClassDeclaration & { name: Identifier }) | (ClassExpression & { parent: VariableDeclaration }); + parameters: NodeArray; + body: FunctionBody; + } + + interface ValidFunction extends FunctionDeclaration { + name: Identifier; + parameters: NodeArray; + body: FunctionBody; + } + + interface ValidMethod extends MethodDeclaration { + parameters: NodeArray; + body: FunctionBody; + } + + interface ValidFunctionExpression extends FunctionExpression { + parent: VariableDeclaration; parameters: NodeArray; } + interface ValidArrowFunction extends ArrowFunction { + parent: VariableDeclaration; + parameters: NodeArray; + } + + type ValidFunctionDeclaration = ValidConstructor | ValidFunction | ValidMethod | ValidArrowFunction | ValidFunctionExpression; + interface ValidParameterDeclaration extends ParameterDeclaration { name: Identifier; - type: TypeNode; dotDotDotToken: undefined; modifiers: undefined; - initializer: undefined; } } \ No newline at end of file