diff --git a/src/ErrorProne.NET.Core/KeyValuePairExtensions.cs b/src/ErrorProne.NET.Core/KeyValuePairExtensions.cs new file mode 100644 index 0000000..5d07554 --- /dev/null +++ b/src/ErrorProne.NET.Core/KeyValuePairExtensions.cs @@ -0,0 +1,19 @@ +// -------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// -------------------------------------------------------------------- + +using System.Collections.Generic; + +namespace ErrorProne.NET.Core +{ + public static class KeyValuePairExtensions + { + public static void Deconstruct(this KeyValuePair pair, out TKey key, out TValue value) + { + key = pair.Key; + value = pair.Value; + } + } +} diff --git a/src/ErrorProne.NET.StructAnalyzers.CodeFixes/UseInModifierForReadOnlyStructCodeFixProvider.cs b/src/ErrorProne.NET.StructAnalyzers.CodeFixes/UseInModifierForReadOnlyStructCodeFixProvider.cs index a8b0bb8..03ea531 100644 --- a/src/ErrorProne.NET.StructAnalyzers.CodeFixes/UseInModifierForReadOnlyStructCodeFixProvider.cs +++ b/src/ErrorProne.NET.StructAnalyzers.CodeFixes/UseInModifierForReadOnlyStructCodeFixProvider.cs @@ -1,15 +1,19 @@ -using System.Collections.Immutable; +using System; +using System.Collections.Generic; +using System.Collections.Immutable; using System.Composition; using System.Diagnostics; using System.Linq; using System.Threading; using System.Threading.Tasks; +using ErrorProne.NET.Core; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CodeActions; using Microsoft.CodeAnalysis.CodeFixes; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.FindSymbols; +using Microsoft.CodeAnalysis.Text; namespace ErrorProne.NET.StructAnalyzers { @@ -39,7 +43,7 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) context.RegisterCodeFix( CodeAction.Create( title: Title, - createChangedDocument: c => AddInModifier(context.Document, declaration, c), + createChangedSolution: c => AddInModifier(context.Document, declaration, c), equivalenceKey: Title), diagnostic); } @@ -115,17 +119,233 @@ private async Task ParameterIsUsedInNonInFriendlyManner(ParameterSyntax pa return false; } - private async Task AddInModifier(Document document, ParameterSyntax paramSyntax, CancellationToken cancellationToken) + private async Task AddInModifier(Document document, ParameterSyntax paramSyntax, CancellationToken cancellationToken) { - SyntaxTriviaList trivia = paramSyntax.GetLeadingTrivia(); ; + var arguments = new Dictionary>(); + var parameters = new Dictionary> { { document.Id, new List { paramSyntax.Span } } }; - var newType = paramSyntax - .WithModifiers(paramSyntax.Modifiers.Insert(0, SyntaxFactory.Token(SyntaxKind.InKeyword))) - .WithLeadingTrivia(trivia); + var semanticModel = await document.GetSemanticModelAsync(cancellationToken); + var parameterSymbol = semanticModel.GetDeclaredSymbol(paramSyntax, cancellationToken); + if (parameterSymbol?.ContainingSymbol is IMethodSymbol containingMethod) + { + var parameterIndex = containingMethod.Parameters.IndexOf(parameterSymbol); + var parameterName = parameterSymbol.Name; + + var callers = await SymbolFinder.FindCallersAsync(containingMethod, document.Project.Solution, cancellationToken).ConfigureAwait(false); + foreach (var caller in callers) + { + foreach (var location in caller.Locations) + { + if (!location.IsInSource) + { + continue; + } + + var locationRoot = await location.SourceTree.GetRootAsync(cancellationToken).ConfigureAwait(false); + var node = locationRoot.FindNode(location.SourceSpan, getInnermostNodeForTie: true); + + var invocationExpression = node.Parent as InvocationExpressionSyntax; + if (invocationExpression is null) + { + invocationExpression = (node.Parent as MemberAccessExpressionSyntax)?.Parent as InvocationExpressionSyntax; + } + + if (invocationExpression is object) + { + ArgumentSyntax? argument = null; + var positionalArgument = TryGetArgumentAtPosition(invocationExpression.ArgumentList, parameterIndex); + if (positionalArgument is object && (positionalArgument.NameColon is null || positionalArgument.NameColon.Name.Identifier.Text == parameterName)) + { + argument = positionalArgument; + } + else + { + foreach (var argumentSyntax in invocationExpression.ArgumentList.Arguments) + { + if (argumentSyntax?.NameColon.Name.Identifier.Text != parameterName) + { + continue; + } + + argument = argumentSyntax; + break; + } + } + + if (argument is null) + { + continue; + } + + var documentId = document.Project.Solution.GetDocument(argument.SyntaxTree)?.Id; + if (documentId is null) + { + continue; + } + + if (!arguments.TryGetValue(documentId, out var argumentSpans)) + { + argumentSpans = new List(); + arguments[documentId] = argumentSpans; + } + + argumentSpans.Add(argument.Span); + } + } + } + + var implementations = await SymbolFinder.FindImplementationsAsync(containingMethod, document.Project.Solution, projects: null, cancellationToken).ConfigureAwait(false); + foreach (var implementation in implementations) + { + foreach (var location in implementation.Locations) + { + var locationRoot = await location.SourceTree.GetRootAsync(cancellationToken).ConfigureAwait(false); + var node = locationRoot.FindNode(location.SourceSpan, getInnermostNodeForTie: true); + if (node is MethodDeclarationSyntax methodDeclaration) + { + var parameterSyntax = TryGetParameterAtPosition(methodDeclaration.ParameterList, parameterIndex); + if (parameterSyntax is null) + { + continue; + } + + var documentId = document.Project.Solution.GetDocument(parameterSyntax.SyntaxTree)?.Id; + if (documentId is null) + { + continue; + } + + if (!parameters.TryGetValue(documentId, out var parameterSpans)) + { + parameterSpans = new List(); + parameters[documentId] = parameterSpans; + } + + parameterSpans.Add(parameterSyntax.Span); + } + } + } + + var overrides = await SymbolFinder.FindOverridesAsync(containingMethod, document.Project.Solution, projects: null, cancellationToken).ConfigureAwait(false); + foreach (var @override in overrides) + { + foreach (var location in @override.Locations) + { + var locationRoot = await location.SourceTree.GetRootAsync(cancellationToken).ConfigureAwait(false); + var node = locationRoot.FindNode(location.SourceSpan, getInnermostNodeForTie: true); + if (node is MethodDeclarationSyntax methodDeclaration) + { + var parameterSyntax = TryGetParameterAtPosition(methodDeclaration.ParameterList, parameterIndex); + if (parameterSyntax is null) + { + continue; + } + + var documentId = document.Project.Solution.GetDocument(methodDeclaration.SyntaxTree)?.Id; + if (documentId is null) + { + continue; + } + + if (!parameters.TryGetValue(documentId, out var parameterSpans)) + { + parameterSpans = new List(); + parameters[documentId] = parameterSpans; + } + + parameterSpans.Add(parameterSyntax.Span); + } + } + } + } + + var result = document.Project.Solution; + foreach (var (documentId, spans) in arguments) + { + var originalDocument = result.GetDocument(documentId); + var root = await originalDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + var argumentsToReplace = spans.Select(span => root.FindNode(span, getInnermostNodeForTie: true)).Select(node => node.FirstAncestorOrSelf()); + if (!parameters.TryGetValue(documentId, out var parameterSpans)) + { + parameterSpans = new List(); + } + + var parametersToReplace = parameterSpans.Select(span => root.FindNode(span, getInnermostNodeForTie: true)).Select(node => node.FirstAncestorOrSelf()); + var newRoot = root.ReplaceNodes( + argumentsToReplace.Cast().Concat(parametersToReplace), + (originalNode, rewrittenNode) => + { + if (rewrittenNode is ArgumentSyntax argument) + { + return ((ArgumentSyntax)rewrittenNode).WithRefKindKeyword(SyntaxFactory.Token(SyntaxKind.InKeyword)); + } + else + { + Debug.Assert(rewrittenNode is ParameterSyntax); + var trivia = rewrittenNode.GetLeadingTrivia(); + return ((ParameterSyntax)rewrittenNode) + .WithModifiers(((ParameterSyntax)rewrittenNode).Modifiers.Insert(0, SyntaxFactory.Token(SyntaxKind.InKeyword))) + .WithLeadingTrivia(trivia); + } + }); + + result = result.WithDocumentSyntaxRoot(documentId, newRoot, PreservationMode.PreserveValue); + } + + foreach (var (documentId, spans) in parameters) + { + if (arguments.ContainsKey(documentId)) + { + continue; + } + + var originalDocument = result.GetDocument(documentId); + var root = await originalDocument.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + var parametersToReplace = spans.Select(span => root.FindNode(span, getInnermostNodeForTie: true)).Select(node => node.FirstAncestorOrSelf()); + var newRoot = root.ReplaceNodes( + parametersToReplace, + (originalNode, rewrittenNode) => + { + var trivia = rewrittenNode.GetLeadingTrivia(); + return rewrittenNode + .WithModifiers(rewrittenNode.Modifiers.Insert(0, SyntaxFactory.Token(SyntaxKind.InKeyword))) + .WithLeadingTrivia(trivia); + }); - var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false); + result = result.WithDocumentSyntaxRoot(documentId, newRoot, PreservationMode.PreserveValue); + } + + return result; + } + + private static ParameterSyntax? TryGetParameterAtPosition(BaseParameterListSyntax? parameterList, int index) + { + if (parameterList is null) + { + return null; + } + + if (parameterList.Parameters.Count < index) + { + return null; + } + + return parameterList.Parameters[index]; + } + + private static ArgumentSyntax? TryGetArgumentAtPosition(BaseArgumentListSyntax? argumentList, int index) + { + if (argumentList is null) + { + return null; + } + + if (argumentList.Arguments.Count < index) + { + return null; + } - return document.WithSyntaxRoot(root.ReplaceNode(paramSyntax, newType)); + return argumentList.Arguments[index]; } } } \ No newline at end of file diff --git a/src/ErrorProne.NET.StructAnalyzers.Tests/UseInModifierForReadOnlyStructCodeFixProviderTests.cs b/src/ErrorProne.NET.StructAnalyzers.Tests/UseInModifierForReadOnlyStructCodeFixProviderTests.cs index 065e102..d511650 100644 --- a/src/ErrorProne.NET.StructAnalyzers.Tests/UseInModifierForReadOnlyStructCodeFixProviderTests.cs +++ b/src/ErrorProne.NET.StructAnalyzers.Tests/UseInModifierForReadOnlyStructCodeFixProviderTests.cs @@ -25,6 +25,204 @@ public async Task AddInModifier() }.WithoutGeneratedCodeVerification().RunAsync(); } + [Test] + public async Task AddInModifierToCallSites() + { + string code = @" +readonly struct FooBar +{ + public static void Foo([|FooBar fb|]) {Foo(fb);} + readonly (long, long, long) data; +}"; + + string expected = @" +readonly struct FooBar +{ + public static void Foo(in FooBar fb) {Foo(in fb);} + readonly (long, long, long) data; +}"; + + await new VerifyCS.Test + { + TestState = { Sources = { code } }, + FixedState = { Sources = { expected } }, + }.WithoutGeneratedCodeVerification().RunAsync(); + } + + [Test] + public async Task AddInModifierToImplementingMethods() + { + string code = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +interface IInterface +{ + void Method([|FooBar fb|]); +} + +class Class +{ + public void Method([|FooBar fb|]) => throw null; +} +"; + + string expected = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +interface IInterface +{ + void Method(in FooBar fb); +} + +class Class +{ + public void Method(in FooBar fb) => throw null; +} +"; + + await new VerifyCS.Test + { + TestState = { Sources = { code } }, + FixedState = { Sources = { expected } }, + }.WithoutGeneratedCodeVerification().RunAsync(); + } + + [Test] + public async Task AddInModifierToExplicitImplementingMethods() + { + string code = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +interface IInterface +{ + void Method([|FooBar fb|]); +} + +class Class : IInterface +{ + void IInterface.Method(FooBar fb) => throw null; +} +"; + + string expected = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +interface IInterface +{ + void Method(in FooBar fb); +} + +class Class : IInterface +{ + void IInterface.Method(in FooBar fb) => throw null; +} +"; + + await new VerifyCS.Test + { + TestState = { Sources = { code } }, + FixedState = { Sources = { expected } }, + }.WithoutGeneratedCodeVerification().RunAsync(); + } + + [Test] + public async Task AddInModifierToOverridingMethods() + { + string code = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +abstract class BaseClass +{ + protected abstract void Method([|FooBar fb|]); +} + +class DerivedClass : BaseClass +{ + protected override void Method(FooBar fb) => throw null; +} +"; + + string expected = @" +readonly struct FooBar +{ + readonly (long, long, long) data; +} + +abstract class BaseClass +{ + protected abstract void Method(in FooBar fb); +} + +class DerivedClass : BaseClass +{ + protected override void Method(in FooBar fb) => throw null; +} +"; + + await new VerifyCS.Test + { + TestState = { Sources = { code } }, + FixedState = { Sources = { expected } }, + }.WithoutGeneratedCodeVerification().RunAsync(); + } + + [Test] + public async Task AddInModifierToCallSites2() + { + string code1 = @" +readonly struct FooBar +{ + public static void Foo([|FooBar fb|]) { } + readonly (long, long, long) data; +}"; + string code2 = @" +class Referencer +{ + FooBar _value; + public void Method() + { + FooBar.Foo(_value); + } +}"; + + string expected1 = @" +readonly struct FooBar +{ + public static void Foo(in FooBar fb) { } + readonly (long, long, long) data; +}"; + string expected2 = @" +class Referencer +{ + FooBar _value; + public void Method() + { + FooBar.Foo(in _value); + } +}"; + + await new VerifyCS.Test + { + TestState = { Sources = { code1, code2 } }, + FixedState = { Sources = { expected1, expected2 } }, + }.WithoutGeneratedCodeVerification().RunAsync(); + } + [Test] public async Task AddInModifierWithTrivia() {