diff --git a/src/services/findAllReferences.ts b/src/services/findAllReferences.ts index c1a7408fae7..67a967ec8b9 100644 --- a/src/services/findAllReferences.ts +++ b/src/services/findAllReferences.ts @@ -496,11 +496,10 @@ namespace ts.FindAllReferences.Core { const { text = stripQuotes(getDeclaredName(this.checker, symbol, location)), allSearchSymbols = undefined } = searchOptions; const escapedText = escapeIdentifier(text); const parents = this.options.implementations && getParentSymbolsOfPropertyAccess(location, symbol, this.checker); - return { location, symbol, comingFrom, text, escapedText, parents, includes }; - - function includes(referenceSymbol: Symbol): boolean { - return allSearchSymbols ? contains(allSearchSymbols, referenceSymbol) : referenceSymbol === symbol; - } + return { + location, symbol, comingFrom, text, escapedText, parents, + includes: referenceSymbol => allSearchSymbols ? contains(allSearchSymbols, referenceSymbol) : referenceSymbol === symbol, + }; } private readonly symbolIdToReferences: Entry[][] = []; diff --git a/src/services/importTracker.ts b/src/services/importTracker.ts index 585af7ebd31..6a0b9167997 100644 --- a/src/services/importTracker.ts +++ b/src/services/importTracker.ts @@ -436,8 +436,8 @@ namespace ts.FindAllReferences { if (parent.kind === SyntaxKind.PropertyAccessExpression) { // When accessing an export of a JS module, there's no alias. The symbol will still be flagged as an export even though we're at the use. // So check that we are at the declaration. - return symbol.declarations.some(d => d === parent) && parent.parent.kind === ts.SyntaxKind.BinaryExpression - ? getSpecialPropertyExport(parent.parent as ts.BinaryExpression, /*useLhsSymbol*/ false) + return symbol.declarations.some(d => d === parent) && isBinaryExpression(parent.parent) + ? getSpecialPropertyExport(parent.parent, /*useLhsSymbol*/ false) : undefined; } else { @@ -449,31 +449,41 @@ namespace ts.FindAllReferences { else { const exportNode = getExportNode(parent); if (exportNode && hasModifier(exportNode, ModifierFlags.Export)) { - if (exportNode.kind === SyntaxKind.ImportEqualsDeclaration && (exportNode as ImportEqualsDeclaration).moduleReference === node) { + if (isImportEqualsDeclaration(exportNode) && exportNode.moduleReference === node) { // We're at `Y` in `export import X = Y`. This is not the exported symbol, the left-hand-side is. So treat this as an import statement. if (comingFromExport) { return undefined; } - const lhsSymbol = checker.getSymbolAtLocation((exportNode as ImportEqualsDeclaration).name); + const lhsSymbol = checker.getSymbolAtLocation(exportNode.name); return { kind: ImportExport.Import, symbol: lhsSymbol, isNamedImport: false }; } else { return exportInfo(symbol, getExportKindForDeclaration(exportNode)); } } - else if (parent.kind === SyntaxKind.ExportAssignment) { - // Get the symbol for the `export =` node; its parent is the module it's the export of. - const exportingModuleSymbol = parent.symbol.parent; - Debug.assert(!!exportingModuleSymbol); - return { kind: ImportExport.Export, symbol, exportInfo: { exportingModuleSymbol, exportKind: ExportKind.ExportEquals } }; + // If we are in `export = a;`, `parent` is the export assignment. + else if (isExportAssignment(parent)) { + return getExportAssignmentExport(parent); } - else if (parent.kind === ts.SyntaxKind.BinaryExpression) { - return getSpecialPropertyExport(parent as ts.BinaryExpression, /*useLhsSymbol*/ true); + // If we are in `export = class A {};` at `A`, `parent.parent` is the export assignment. + else if (isExportAssignment(parent.parent)) { + return getExportAssignmentExport(parent.parent); } - else if (parent.parent.kind === SyntaxKind.BinaryExpression) { - return getSpecialPropertyExport(parent.parent as ts.BinaryExpression, /*useLhsSymbol*/ true); + // Similar for `module.exports =` and `exports.A =`. + else if (isBinaryExpression(parent)) { + return getSpecialPropertyExport(parent, /*useLhsSymbol*/ true); } + else if (isBinaryExpression(parent.parent)) { + return getSpecialPropertyExport(parent.parent, /*useLhsSymbol*/ true); + } + } + + function getExportAssignmentExport(ex: ExportAssignment): ExportedSymbol { + // Get the symbol for the `export =` node; its parent is the module it's the export of. + const exportingModuleSymbol = ex.symbol.parent; + Debug.assert(!!exportingModuleSymbol); + return { kind: ImportExport.Export, symbol, exportInfo: { exportingModuleSymbol, exportKind: ExportKind.ExportEquals } }; } function getSpecialPropertyExport(node: ts.BinaryExpression, useLhsSymbol: boolean): ExportedSymbol | undefined { @@ -496,21 +506,21 @@ namespace ts.FindAllReferences { function getImport(): ImportedSymbol | undefined { const isImport = isNodeImport(node); - if (!isImport) return; + if (!isImport) return undefined; // A symbol being imported is always an alias. So get what that aliases to find the local symbol. let importedSymbol = checker.getImmediateAliasedSymbol(symbol); - if (importedSymbol) { - // Search on the local symbol in the exporting module, not the exported symbol. - importedSymbol = skipExportSpecifierSymbol(importedSymbol, checker); - // Similarly, skip past the symbol for 'export =' - if (importedSymbol.name === "export=") { - importedSymbol = checker.getImmediateAliasedSymbol(importedSymbol); - } + if (!importedSymbol) return undefined; - if (symbolName(importedSymbol) === symbol.name) { // If this is a rename import, do not continue searching. - return { kind: ImportExport.Import, symbol: importedSymbol, ...isImport }; - } + // Search on the local symbol in the exporting module, not the exported symbol. + importedSymbol = skipExportSpecifierSymbol(importedSymbol, checker); + // Similarly, skip past the symbol for 'export =' + if (importedSymbol.name === "export=") { + importedSymbol = getExportEqualsLocalSymbol(importedSymbol, checker); + } + + if (symbolName(importedSymbol) === symbol.name) { // If this is a rename import, do not continue searching. + return { kind: ImportExport.Import, symbol: importedSymbol, ...isImport }; } } @@ -525,6 +535,22 @@ namespace ts.FindAllReferences { } } + function getExportEqualsLocalSymbol(importedSymbol: Symbol, checker: TypeChecker): Symbol { + if (importedSymbol.flags & SymbolFlags.Alias) { + return checker.getImmediateAliasedSymbol(importedSymbol); + } + + const decl = importedSymbol.valueDeclaration; + if (isExportAssignment(decl)) { // `export = class {}` + return decl.expression.symbol; + } + else if (isBinaryExpression(decl)) { // `module.exports = class {}` + return decl.right.symbol; + } + Debug.fail(); + } + + // If a reference is a class expression, the exported node would be its parent. // If a reference is a variable declaration, the exported node would be the variable statement. function getExportNode(parent: Node): Node | undefined { if (parent.kind === SyntaxKind.VariableDeclaration) { diff --git a/tests/cases/fourslash/findAllRefsClassExpression0.ts b/tests/cases/fourslash/findAllRefsClassExpression0.ts new file mode 100644 index 00000000000..50abfae0230 --- /dev/null +++ b/tests/cases/fourslash/findAllRefsClassExpression0.ts @@ -0,0 +1,16 @@ +/// + +// @Filename: /a.ts +////export = class [|{| "isWriteAccess": true, "isDefinition": true |}A|] { +//// m() { [|A|]; } +////}; + +// @Filename: /b.ts +////import [|{| "isWriteAccess": true, "isDefinition": true |}A|] = require("./a"); +////[|A|]; + +const [r0, r1, r2, r3] = test.ranges(); +const defs = { definition: "(local class) A", ranges: [r0, r1] }; +const imports = { definition: 'import A = require("./a")', ranges: [r2, r3] }; +verify.referenceGroups([r0, r1], [defs, imports]); +verify.referenceGroups([r2, r3], [imports, defs]); diff --git a/tests/cases/fourslash/findAllRefsClassExpression1.ts b/tests/cases/fourslash/findAllRefsClassExpression1.ts new file mode 100644 index 00000000000..bd581871842 --- /dev/null +++ b/tests/cases/fourslash/findAllRefsClassExpression1.ts @@ -0,0 +1,17 @@ +/// + +// @allowJs: true + +// @Filename: /a.js +////module.exports = class [|{| "isWriteAccess": true, "isDefinition": true |}A|] {}; + +// @Filename: /b.js +////import [|{| "isWriteAccess": true, "isDefinition": true |}A|] = require("./a"); +////[|A|]; + +const [r0, r1, r2] = test.ranges(); +const defs = { definition: "(local class) A", ranges: [r0] }; +const imports = { definition: 'import A = require("./a")', ranges: [r1, r2] }; +verify.referenceGroups([r0], [defs, imports]); +verify.referenceGroups([r1, r2], [imports, defs]); + diff --git a/tests/cases/fourslash/findAllRefsClassExpression2.ts b/tests/cases/fourslash/findAllRefsClassExpression2.ts new file mode 100644 index 00000000000..ce2fc8bbf3d --- /dev/null +++ b/tests/cases/fourslash/findAllRefsClassExpression2.ts @@ -0,0 +1,16 @@ +/// + +// @allowJs: true + +// @Filename: /a.js +////exports.[|{| "isWriteAccess": true, "isDefinition": true |}A|] = class {}; + +// @Filename: /b.js +////import { [|{| "isWriteAccess": true, "isDefinition": true |}A|] } from "./a"; +////[|A|]; + +const [r0, r1, r2] = test.ranges(); +const defs = { definition: "(property) A: typeof (Anonymous class)", ranges: [r0] }; +const imports = { definition: "import A", ranges: [r1, r2] }; +verify.referenceGroups([r0], [defs, imports]); +verify.referenceGroups([r1, r2], [imports, defs]);