diff --git a/src/services/refactors/convertToOptionalChainExpression.ts b/src/services/refactors/convertToOptionalChainExpression.ts
index 0333506b1e3..52365fc0ba5 100644
--- a/src/services/refactors/convertToOptionalChainExpression.ts
+++ b/src/services/refactors/convertToOptionalChainExpression.ts
@@ -51,9 +51,11 @@ namespace ts.refactor.convertToOptionalChainExpression {
error: string;
};
+ type Occurrence = PropertyAccessExpression | ElementAccessExpression | Identifier;
+
interface Info {
- finalExpression: PropertyAccessExpression | CallExpression,
- occurrences: (PropertyAccessExpression | Identifier)[],
+ finalExpression: PropertyAccessExpression | ElementAccessExpression | CallExpression,
+ occurrences: Occurrence[],
expression: ValidExpression,
};
@@ -107,7 +109,7 @@ namespace ts.refactor.convertToOptionalChainExpression {
if (!finalExpression || checker.isNullableType(checker.getTypeAtLocation(finalExpression))) {
return { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_convertible_access_expression) };
- };
+ }
if ((isPropertyAccessExpression(condition) || isIdentifier(condition))
&& getMatchingStart(condition, finalExpression.expression)) {
@@ -136,8 +138,8 @@ namespace ts.refactor.convertToOptionalChainExpression {
/**
* Gets a list of property accesses that appear in matchTo and occur in sequence in expression.
*/
- function getOccurrencesInExpression(matchTo: Expression, expression: Expression): (PropertyAccessExpression | Identifier)[] | undefined {
- const occurrences: (PropertyAccessExpression | Identifier)[] = [];
+ function getOccurrencesInExpression(matchTo: Expression, expression: Expression): Occurrence[] | undefined {
+ const occurrences: Occurrence[] = [];
while (isBinaryExpression(expression) && expression.operatorToken.kind === SyntaxKind.AmpersandAmpersandToken) {
const match = getMatchingStart(skipParentheses(matchTo), skipParentheses(expression.right));
if (!match) {
@@ -157,9 +159,11 @@ namespace ts.refactor.convertToOptionalChainExpression {
/**
* Returns subchain if chain begins with subchain syntactically.
*/
- function getMatchingStart(chain: Expression, subchain: Expression): PropertyAccessExpression | Identifier | undefined {
- return (isIdentifier(subchain) || isPropertyAccessExpression(subchain)) &&
- chainStartsWith(chain, subchain) ? subchain : undefined;
+ function getMatchingStart(chain: Expression, subchain: Expression): PropertyAccessExpression | ElementAccessExpression | Identifier | undefined {
+ if (!isIdentifier(subchain) && !isPropertyAccessExpression(subchain) && !isElementAccessExpression(subchain)) {
+ return undefined;
+ }
+ return chainStartsWith(chain, subchain) ? subchain : undefined;
}
/**
@@ -167,14 +171,14 @@ namespace ts.refactor.convertToOptionalChainExpression {
*/
function chainStartsWith(chain: Node, subchain: Node): boolean {
// skip until we find a matching identifier.
- while (isCallExpression(chain) || isPropertyAccessExpression(chain)) {
- const subchainName = isPropertyAccessExpression(subchain) ? subchain.name.getText() : subchain.getText();
- if (isPropertyAccessExpression(chain) && chain.name.getText() === subchainName) break;
+ while (isCallExpression(chain) || isPropertyAccessExpression(chain) || isElementAccessExpression(chain)) {
+ if (getTextOfChainNode(chain) === getTextOfChainNode(subchain)) break;
chain = chain.expression;
}
- // check that the chains match at each access. Call chains in subchain are not valid.
- while (isPropertyAccessExpression(chain) && isPropertyAccessExpression(subchain)) {
- if (chain.name.getText() !== subchain.name.getText()) return false;
+ // check that the chains match at each access. Call chains in subchain are not valid.
+ while ((isPropertyAccessExpression(chain) && isPropertyAccessExpression(subchain)) ||
+ (isElementAccessExpression(chain) && isElementAccessExpression(subchain))) {
+ if (getTextOfChainNode(chain) !== getTextOfChainNode(subchain)) return false;
chain = chain.expression;
subchain = subchain.expression;
}
@@ -182,6 +186,19 @@ namespace ts.refactor.convertToOptionalChainExpression {
return isIdentifier(chain) && isIdentifier(subchain) && chain.getText() === subchain.getText();
}
+ function getTextOfChainNode(node: Node): string | undefined {
+ if (isIdentifier(node) || isStringOrNumericLiteralLike(node)) {
+ return node.getText();
+ }
+ if (isPropertyAccessExpression(node)) {
+ return getTextOfChainNode(node.name);
+ }
+ if (isElementAccessExpression(node)) {
+ return getTextOfChainNode(node.argumentExpression);
+ }
+ return undefined;
+ }
+
/**
* Find the least ancestor of the input node that is a valid type for extraction and contains the input span.
*/
@@ -229,7 +246,7 @@ namespace ts.refactor.convertToOptionalChainExpression {
* it is followed by a different binary operator.
* @param node the right child of a binary expression or a call expression.
*/
- function getFinalExpressionInChain(node: Expression): CallExpression | PropertyAccessExpression | undefined {
+ function getFinalExpressionInChain(node: Expression): CallExpression | PropertyAccessExpression | ElementAccessExpression | undefined {
// foo && |foo.bar === 1|; - here the right child of the && binary expression is another binary expression.
// the rightmost member of the && chain should be the leftmost child of that expression.
node = skipParentheses(node);
@@ -237,7 +254,7 @@ namespace ts.refactor.convertToOptionalChainExpression {
return getFinalExpressionInChain(node.left);
}
// foo && |foo.bar()()| - nested calls are treated like further accesses.
- else if ((isPropertyAccessExpression(node) || isCallExpression(node)) && !isOptionalChain(node)) {
+ else if ((isPropertyAccessExpression(node) || isElementAccessExpression(node) || isCallExpression(node)) && !isOptionalChain(node)) {
return node;
}
return undefined;
@@ -246,8 +263,8 @@ namespace ts.refactor.convertToOptionalChainExpression {
/**
* Creates an access chain from toConvert with '?.' accesses at expressions appearing in occurrences.
*/
- function convertOccurrences(checker: TypeChecker, toConvert: Expression, occurrences: (PropertyAccessExpression | Identifier)[]): Expression {
- if (isPropertyAccessExpression(toConvert) || isCallExpression(toConvert)) {
+ function convertOccurrences(checker: TypeChecker, toConvert: Expression, occurrences: Occurrence[]): Expression {
+ if (isPropertyAccessExpression(toConvert) || isElementAccessExpression(toConvert) || isCallExpression(toConvert)) {
const chain = convertOccurrences(checker, toConvert.expression, occurrences);
const lastOccurrence = occurrences.length > 0 ? occurrences[occurrences.length - 1] : undefined;
const isOccurrence = lastOccurrence?.getText() === toConvert.expression.getText();
@@ -262,6 +279,11 @@ namespace ts.refactor.convertToOptionalChainExpression {
factory.createPropertyAccessChain(chain, factory.createToken(SyntaxKind.QuestionDotToken), toConvert.name) :
factory.createPropertyAccessChain(chain, toConvert.questionDotToken, toConvert.name);
}
+ else if (isElementAccessExpression(toConvert)) {
+ return isOccurrence ?
+ factory.createElementAccessChain(chain, factory.createToken(SyntaxKind.QuestionDotToken), toConvert.argumentExpression) :
+ factory.createElementAccessChain(chain, toConvert.questionDotToken, toConvert.argumentExpression);
+ }
}
return toConvert;
}
@@ -270,7 +292,7 @@ namespace ts.refactor.convertToOptionalChainExpression {
const { finalExpression, occurrences, expression } = info;
const firstOccurrence = occurrences[occurrences.length - 1];
const convertedChain = convertOccurrences(checker, finalExpression, occurrences);
- if (convertedChain && (isPropertyAccessExpression(convertedChain) || isCallExpression(convertedChain))) {
+ if (convertedChain && (isPropertyAccessExpression(convertedChain) || isElementAccessExpression(convertedChain) || isCallExpression(convertedChain))) {
if (isBinaryExpression(expression)) {
changes.replaceNodeRange(sourceFile, firstOccurrence, finalExpression, convertedChain);
}
diff --git a/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression1.ts b/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression1.ts
new file mode 100644
index 00000000000..df11f0e3417
--- /dev/null
+++ b/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression1.ts
@@ -0,0 +1,18 @@
+///
+
+////const a = {
+//// b: { c: 1 }
+////}
+/////*a*/a && a['b'] && a['b']['c']/*b*/
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Convert to optional chain expression",
+ actionName: "Convert to optional chain expression",
+ actionDescription: "Convert to optional chain expression",
+ newContent:
+`const a = {
+ b: { c: 1 }
+}
+a?.['b']?.['c']`
+});
diff --git a/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression2.ts b/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression2.ts
new file mode 100644
index 00000000000..2460fa86c70
--- /dev/null
+++ b/tests/cases/fourslash/refactorConvertToOptionalChainExpression_ElementAccessExpression2.ts
@@ -0,0 +1,18 @@
+///
+
+////const a = {
+//// b: { c: 1 }
+////}
+/////*a*/a && a.b && a.b['c']/*b*/
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Convert to optional chain expression",
+ actionName: "Convert to optional chain expression",
+ actionDescription: "Convert to optional chain expression",
+ newContent:
+`const a = {
+ b: { c: 1 }
+}
+a?.b?.['c']`
+});