Skip to content

Commit

Permalink
Fix support for externally declared records
Browse files Browse the repository at this point in the history
FIxes improper codec codegen for records declared in referenced projects/assemblies. Roslyn does not guarantee the symbols contain the backing fields for generated properties (see dotnet/roslyn#72374 (comment)) and it also doesn't even report `record struct` symbols as records at all (see dotnet/roslyn#69326).

This makes for a very inconsistent experience when dealing with types defined in external assemblies that don't use the Orleans SDK themselves.

We implement a heuristics here to determine primary constructors that can be relied upon to detect them consistently:
1. A ctor with non-zero parameters
2. All parameters match by name exactly with corresponding properties
3. All matching properties have getter AND setter annotated with [CompilerGenerated].

In addition, since the backing field isn't available at all in these records, and the corresponding property isn't settable (it's generated as `init set`), we leverage unsafe accessors (see https://learn.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.unsafeaccessorattribute?view=net-8.0) instead. The code checks whether the `FieldAccessorDescription` has an initializer syntax or not to determine whether to generate the original code or the new accessor version.

The signature of the accessor matches the delegate that is generated for the regular backing field case, so there is no need to modify other call sites.

Fixes #9092
  • Loading branch information
kzu committed Aug 6, 2024
1 parent 93af1ee commit f85d424
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 16 deletions.
35 changes: 31 additions & 4 deletions src/Orleans.CodeGenerator/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ public CompilationUnitSyntax GenerateCode(CancellationToken cancellationToken)
if (symbol.IsRecord)
{
// If there is a primary constructor then that will be declared before the copy constructor
// A record always generates a copy constructor and marks it as implicitly declared
// A record always generates a copy constructor and marks it as compiler generated
// todo: find an alternative to this magic
var potentialPrimaryConstructor = symbol.Constructors[0];
if (!potentialPrimaryConstructor.IsImplicitlyDeclared)
if (!potentialPrimaryConstructor.IsImplicitlyDeclared && !potentialPrimaryConstructor.IsCompilerGenerated())
{
constructorParameters = potentialPrimaryConstructor.Parameters;
}
Expand All @@ -160,6 +160,23 @@ public CompilationUnitSyntax GenerateCode(CancellationToken cancellationToken)
{
constructorParameters = annotatedConstructors[0].Parameters;
}
else
{
// record structs from referenced assemblies do not return IsRecord=true
// above. See https://github.com/dotnet/roslyn/issues/69326
// So we implement the same heuristics from ShouldIncludePrimaryConstructorParameters
// to detect a primary constructor.
var properties = symbol.GetMembers().OfType<IPropertySymbol>().ToImmutableArray();
var primaryConstructor = symbol.GetMembers()
.OfType<IMethodSymbol>()
.Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0)
// Check for a ctor where all parameters have a corresponding compiler-generated prop.
.FirstOrDefault(ctor => ctor.Parameters.All(prm =>
properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated())));

if (primaryConstructor != null)
constructorParameters = primaryConstructor.Parameters;
}
}
}

Expand Down Expand Up @@ -273,8 +290,18 @@ bool ShouldIncludePrimaryConstructorParameters(INamedTypeSymbol t)
}
}

// Default to true for records, false otherwise.
return t.IsRecord;
// Default to true for records.
if (t.IsRecord)
return true;

var properties = t.GetMembers().OfType<IPropertySymbol>().ToImmutableArray();

return t.GetMembers()
.OfType<IMethodSymbol>()
.Where(m => m.MethodKind == MethodKind.Constructor && m.Parameters.Length > 0)
// Check for a ctor where all parameters have a corresponding compiler-generated prop.
.Any(ctor => ctor.Parameters.All(prm =>
properties.Any(prop => prop.Name.Equals(prm.Name, StringComparison.Ordinal) && prop.IsCompilerGenerated())));
}
}
}
Expand Down
30 changes: 29 additions & 1 deletion src/Orleans.CodeGenerator/CopierGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,39 @@ static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription des
{
switch (description)
{
case FieldAccessorDescription accessor:
case FieldAccessorDescription accessor when accessor.InitializationSyntax != null:
return
FieldDeclaration(VariableDeclaration(accessor.FieldType,
SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax)))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
case FieldAccessorDescription accessor when accessor.InitializationSyntax == null:
//[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")]
//extern static void SetAmount(External instance, int value);
return
MethodDeclaration(
PredefinedType(Token(SyntaxKind.VoidKeyword)),
accessor.AccessorName)
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword))
.AddAttributeLists(AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor"))
.AddArgumentListArguments(
AttributeArgument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"),
IdentifierName("Method"))),
AttributeArgument(
LiteralExpression(
SyntaxKind.StringLiteralExpression,
Literal($"set_{accessor.FieldName}")))
.WithNameEquals(NameEquals("Name"))))))
.WithParameterList(
ParameterList(SeparatedList(new[]
{
Parameter(Identifier("instance")).WithType(accessor.ContainingType),
Parameter(Identifier("value")).WithType(description.FieldType)
})))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
default:
return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
Expand Down
12 changes: 12 additions & 0 deletions src/Orleans.CodeGenerator/FieldIdAssignmentHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ private bool ExtractFieldIdAnnotations()
{
_symbols[member] = (id.Value, false);
}
else if (PropertyUtility.GetMatchingPrimaryConstructorParameter(prop, _constructorParameters) is { } prm)
{
id = CodeGenerator.GetId(_libraryTypes, prop);
if (id.HasValue)
{
_symbols[member] = (id.Value, true);
}
else
{
_symbols[member] = ((uint)_constructorParameters.IndexOf(prm), true);
}
}
}

if (member is IFieldSymbol field)
Expand Down
16 changes: 16 additions & 0 deletions src/Orleans.CodeGenerator/PropertyUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,22 @@ public static class PropertyUtility
return GetMatchingProperty(field, field.ContainingType.GetMembers());
}

public static bool IsCompilerGenerated(this ISymbol? symbol)
=> symbol?.GetAttributes().Any(a => a.AttributeClass?.Name == "CompilerGeneratedAttribute") == true;

public static bool IsCompilerGenerated(this IPropertySymbol? property)
=> property?.GetMethod.IsCompilerGenerated() == true && property.SetMethod.IsCompilerGenerated();

public static IParameterSymbol? GetMatchingPrimaryConstructorParameter(IPropertySymbol property, IEnumerable<IParameterSymbol> constructorParameters)
{
if (!property.IsCompilerGenerated())
return null;

return constructorParameters.FirstOrDefault(p =>
string.Equals(p.Name, property.Name, StringComparison.Ordinal) &&
SymbolEqualityComparer.Default.Equals(p.Type, property.Type));
}

public static IPropertySymbol? GetMatchingProperty(IFieldSymbol field, IEnumerable<ISymbol> memberSymbols)
{
var propertyName = PropertyMatchRegex.Match(field.Name);
Expand Down
59 changes: 49 additions & 10 deletions src/Orleans.CodeGenerator/SerializerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,39 @@ static MemberDeclarationSyntax GetFieldDeclaration(GeneratedFieldDescription des
SingletonSeparatedList(VariableDeclarator(type.FieldName)
.WithInitializer(EqualsValueClause(TypeOfExpression(type.CodecFieldType))))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
case FieldAccessorDescription accessor:
case FieldAccessorDescription accessor when accessor.InitializationSyntax != null:
return
FieldDeclaration(VariableDeclaration(accessor.FieldType,
SingletonSeparatedList(VariableDeclarator(accessor.FieldName).WithInitializer(EqualsValueClause(accessor.InitializationSyntax)))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.StaticKeyword), Token(SyntaxKind.ReadOnlyKeyword));
case FieldAccessorDescription accessor when accessor.InitializationSyntax == null:
//[UnsafeAccessor(UnsafeAccessorKind.Method, Name = "set_Amount")]
//extern static void SetAmount(External instance, int value);
return
MethodDeclaration(
PredefinedType(Token(SyntaxKind.VoidKeyword)),
accessor.AccessorName)
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ExternKeyword), Token(SyntaxKind.StaticKeyword))
.AddAttributeLists(AttributeList(SingletonSeparatedList(
Attribute(IdentifierName("System.Runtime.CompilerServices.UnsafeAccessor"))
.AddArgumentListArguments(
AttributeArgument(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
IdentifierName("System.Runtime.CompilerServices.UnsafeAccessorKind"),
IdentifierName("Method"))),
AttributeArgument(
LiteralExpression(
SyntaxKind.StringLiteralExpression,
Literal($"set_{accessor.FieldName}")))
.WithNameEquals(NameEquals("Name"))))))
.WithParameterList(
ParameterList(SeparatedList(new[]
{
Parameter(Identifier("instance")).WithType(accessor.ContainingType),
Parameter(Identifier("value")).WithType(description.FieldType)
})))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken));
default:
return FieldDeclaration(VariableDeclaration(description.FieldType, SingletonSeparatedList(VariableDeclarator(description.FieldName))))
.AddModifiers(Token(SyntaxKind.PrivateKeyword), Token(SyntaxKind.ReadOnlyKeyword));
Expand Down Expand Up @@ -1069,11 +1097,16 @@ public CodecFieldTypeFieldDescription(TypeSyntax fieldType, string fieldName, Ty

internal sealed class FieldAccessorDescription : GeneratedFieldDescription
{
public FieldAccessorDescription(TypeSyntax fieldType, string fieldName, ExpressionSyntax initializationSyntax) : base(fieldType, fieldName)
=> InitializationSyntax = initializationSyntax;
public FieldAccessorDescription(TypeSyntax containingType, TypeSyntax fieldType, string fieldName, string accessorName, ExpressionSyntax initializationSyntax = null) : base(fieldType, fieldName)
{
ContainingType = containingType;
AccessorName = accessorName;
InitializationSyntax = initializationSyntax;
}

public override bool IsInjected => false;

public readonly string AccessorName;
public readonly TypeSyntax ContainingType;
public readonly ExpressionSyntax InitializationSyntax;
}

Expand Down Expand Up @@ -1325,7 +1358,6 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va
return AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
instance.Member(Field.Name),

value);
}

Expand Down Expand Up @@ -1357,20 +1389,26 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va
public FieldAccessorDescription GetGetterFieldDescription()
{
if (IsGettableField || IsGettableProperty) return null;
return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, GetterFieldName, LibraryTypes, false);
return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, GetterFieldName, LibraryTypes, false,
IsPrimaryConstructorParameter && IsProperty);
}

public FieldAccessorDescription GetSetterFieldDescription()
{
if (IsSettableField || IsSettableProperty) return null;
return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, SetterFieldName, LibraryTypes, true);
return GetFieldAccessor(ContainingType, TypeSyntax, MemberName, SetterFieldName, LibraryTypes, true,
IsPrimaryConstructorParameter && IsProperty);
}

public static FieldAccessorDescription GetFieldAccessor(INamedTypeSymbol containingType, TypeSyntax fieldType, string fieldName, string accessorName, LibraryTypes library, bool setter)
public static FieldAccessorDescription GetFieldAccessor(INamedTypeSymbol containingType, TypeSyntax fieldType, string fieldName, string accessorName, LibraryTypes library, bool setter, bool useUnsafeAccessor = false)
{
var valueType = containingType.IsValueType;
var containingTypeSyntax = containingType.ToTypeSyntax();

if (useUnsafeAccessor)
return new(containingTypeSyntax, fieldType, fieldName, accessorName);

var valueType = containingType.IsValueType;

var delegateType = (setter ? (valueType ? library.ValueTypeSetter_2 : library.Action_2) : (valueType ? library.ValueTypeGetter_2 : library.Func_2))
.ToTypeSyntax(containingTypeSyntax, fieldType);

Expand All @@ -1381,7 +1419,8 @@ public static FieldAccessorDescription GetFieldAccessor(INamedTypeSymbol contain
InvocationExpression(fieldAccessorUtility.Member(accessorMethod))
.AddArgumentListArguments(Argument(TypeOfExpression(containingTypeSyntax)), Argument(fieldName.GetLiteralExpression())));

return new(delegateType, accessorName, accessorInvoke);
// Existing case, accessor is the field in both cases
return new(containingTypeSyntax, delegateType, accessorName, accessorName, accessorInvoke);
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion test/Misc/TestSerializerExternalModels/Models.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@
namespace UnitTests.SerializerExternalModels;

[GenerateSerializer]
public record struct Person2External(int Age, string Name)
public record struct Person2ExternalStruct(int Age, string Name)
{
[Id(0)]
public string FavouriteColor { get; set; }

[Id(1)]
public string StarSign { get; set; }
}

#if NET6_0_OR_GREATER
[GenerateSerializer]
public record Person2External(int Age, string Name)
{
[Id(0)]
public string FavouriteColor { get; set; }

[Id(1)]
public string StarSign { get; set; }
}
#endif
24 changes: 24 additions & 0 deletions test/Orleans.Serialization.UnitTests/GeneratedSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
using UnitTests.SerializerExternalModels;
using Orleans;

[assembly: GenerateCodeForDeclaringAssembly(typeof(Person2ExternalStruct))]
#if NET6_0_OR_GREATER
[assembly: GenerateCodeForDeclaringAssembly(typeof(Person2External))]
#endif

namespace Orleans.Serialization.UnitTests;

Expand Down Expand Up @@ -130,6 +133,25 @@ public void GeneratedRecordWithPCtorSerializersRoundTripThroughCodec()
Assert.Equal(original.StarSign, result.StarSign);
}

[Fact]
public void GeneratedLibExternalRecordStructWithPCtorSerializersRoundTripThroughCodec()
{
var original = new Person2ExternalStruct(2, "harry")
{
FavouriteColor = "redborine",
StarSign = "Aquaricorn"
};

var result = RoundTripThroughCodec(original);

Assert.Equal(original.Age, result.Age);
Assert.Equal(original.Name, result.Name);
Assert.Equal(original.FavouriteColor, result.FavouriteColor);
Assert.Equal(original.StarSign, result.StarSign);
}

#if NET6_0_OR_GREATER

[Fact]
public void GeneratedLibExternalRecordWithPCtorSerializersRoundTripThroughCodec()
{
Expand All @@ -147,6 +169,8 @@ public void GeneratedLibExternalRecordWithPCtorSerializersRoundTripThroughCodec(
Assert.Equal(original.StarSign, result.StarSign);
}

#endif

#if NET6_0_OR_GREATER
[Fact]
public void RequiredMembersAreSupported()
Expand Down

0 comments on commit f85d424

Please sign in to comment.