From 912dabe6b188958be7913895eb077331b506324c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Doeraene?= Date: Thu, 30 Nov 2023 15:31:20 +0100 Subject: [PATCH] Fix #405: Completely overhaul erasure of value classes. Including when they contain arrays or are contained in arrays. --- build.sbt | 4 + .../main/scala/tastyquery/Definitions.scala | 9 + .../src/main/scala/tastyquery/Erasure.scala | 244 +++++++++++++----- .../main/scala/tastyquery/Signatures.scala | 30 ++- .../src/main/scala/tastyquery/Symbols.scala | 72 ++++-- .../src/main/scala/tastyquery/Types.scala | 25 +- .../scala/tastyquery/SignatureSuite.scala | 38 ++- .../main/scala/inheritance/MyArrayOps.scala | 9 + .../src/main/scala/inheritance/MyFlags.scala | 7 + .../ValueClassWithDependentErasure.scala | 45 ++++ 10 files changed, 375 insertions(+), 108 deletions(-) create mode 100644 test-sources/src/main/scala/simple_trees/ValueClassWithDependentErasure.scala diff --git a/build.sbt b/build.sbt index 2395b8ff..6d1b804b 100644 --- a/build.sbt +++ b/build.sbt @@ -114,7 +114,11 @@ lazy val tastyQuery = import com.typesafe.tools.mima.core.* Seq( // private, not an issue + ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass"), + ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass$"), ProblemFilters.exclude[MissingClassProblem]("tastyquery.TypeOps$TypeFold"), + // private[tastyquery], not an issue + ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Signatures#Signature.toSigName"), // Everything in tastyquery.reader is private[tastyquery] at most ProblemFilters.exclude[Problem]("tastyquery.reader.*"), ) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala index f62168d2..69917e5d 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Definitions.scala @@ -396,6 +396,15 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS lazy val CharClass = scalaPackage.requiredClass("Char") lazy val UnitClass = scalaPackage.requiredClass("Unit") + private[tastyquery] lazy val BoxedBooleanClass = javaLangPackage.requiredClass("Boolean") + private[tastyquery] lazy val BoxedCharClass = javaLangPackage.requiredClass("Character") + private[tastyquery] lazy val BoxedByteClass = javaLangPackage.requiredClass("Byte") + private[tastyquery] lazy val BoxedShortClass = javaLangPackage.requiredClass("Short") + private[tastyquery] lazy val BoxedIntClass = javaLangPackage.requiredClass("Integer") + private[tastyquery] lazy val BoxedLongClass = javaLangPackage.requiredClass("Long") + private[tastyquery] lazy val BoxedFloatClass = javaLangPackage.requiredClass("Float") + private[tastyquery] lazy val BoxedDoubleClass = javaLangPackage.requiredClass("Double") + lazy val StringClass = javaLangPackage.requiredClass("String") lazy val ProductClass = scalaPackage.requiredClass("Product") diff --git a/tasty-query/shared/src/main/scala/tastyquery/Erasure.scala b/tasty-query/shared/src/main/scala/tastyquery/Erasure.scala index d031da85..f5726db3 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Erasure.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Erasure.scala @@ -11,10 +11,6 @@ import tastyquery.Types.* import tastyquery.Types.ErasedTypeRef.* private[tastyquery] object Erasure: - // TODO: improve this to match dotty: - // - use correct type erasure algorithm from Scala 3, with specialisations - // for Java types and Scala 2 types (i.e. varargs, value-classes) - @deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1") def erase(tpe: Type)(using Context): ErasedTypeRef = erase(tpe, SourceLanguage.Scala3) @@ -27,44 +23,41 @@ private[tastyquery] object Erasure: finishErase(preErase(tpe, keepUnit)) end erase - /** First pass of erasure, where some special types are preserved as is. + private[tastyquery] def eraseForSigName(tpe: Type, language: SourceLanguage, keepUnit: Boolean)( + using Context + ): ErasedTypeRef = + given SourceLanguage = language + + val patchedPreErased = preErase(tpe, keepUnit) match + case ArrayTypeRef(ClassRef(cls), dimensions) if cls.isDerivedValueClass => + // Hack! dotc's `sigName` does *not* correspond to erasure in this case! + val patchedBase = + if cls.typeParams.isEmpty then preEraseMonoValueClass(cls) + else preErasePolyValueClass(cls, cls.typeParams.map(_.localRef)) + patchedBase.underlying.multiArrayOf(dimensions) + case typeRef => + typeRef + + finishErase(patchedPreErased) + end eraseForSigName + + private final case class ErasedValueClass(valueClass: ClassSymbol, underlying: ErasedTypeRef) + + private type PreErasedTypeRef = ErasedTypeRef | ErasedValueClass + + /** First pass of erasure, where some special types are preserved as is, + * and where value classes become `ErasedValueClass`es. * * In particular, `Any` is preserved as `Any`, instead of becoming * `java.lang.Object`. */ - private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): ErasedTypeRef = - def arrayOfBounds(bounds: TypeBounds): ErasedTypeRef = - preErase(bounds.high, keepUnit = false) match - case ClassRef(cls) if cls.isAny || cls.isAnyVal => - ClassRef(defn.ObjectClass) - case typeRef => - typeRef.arrayOf() - - def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef = tpe match - case tpe: AppliedType => - tpe.tycon match - case TypeRef.OfClass(cls) => - if cls.isArray then - val List(targ) = tpe.args: @unchecked - arrayOf(targ).arrayOf() - else ClassRef(cls).arrayOf() - case _ => - arrayOf(tpe.translucentSuperType) - case TypeRef.OfClass(cls) => - if cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass).arrayOf() - else ClassRef(cls).arrayOf() - case tpe: TypeRef => - tpe.optSymbol match - case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias => - arrayOf(tpe.translucentSuperType) - case _ => - tpe.bounds match - case bounds: AbstractTypeBounds => arrayOfBounds(bounds) - case TypeAlias(alias) => arrayOf(alias) - case tpe: TypeParamRef => arrayOfBounds(tpe.bounds) - case tpe: Type => preErase(tpe, keepUnit = false).arrayOf() - case tpe: WildcardTypeArg => arrayOfBounds(tpe.bounds) - end arrayOf + private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): PreErasedTypeRef = + def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef = + if isGenericArrayElement(tpe) then ClassRef(defn.ObjectClass) + else + preErase(tpe.highIfWildcard, keepUnit = false) match + case base: ErasedTypeRef => base.arrayOf() + case ErasedValueClass(valueClass, _) => ClassRef(valueClass).arrayOf() tpe match case tpe: AppliedType => @@ -73,11 +66,13 @@ private[tastyquery] object Erasure: if cls.isArray then val List(targ) = tpe.args: @unchecked arrayOf(targ) + else if cls.isDerivedValueClass then preErasePolyValueClass(cls, tpe.args) else ClassRef(cls) case _ => preErase(tpe.translucentSuperType, keepUnit) case TypeRef.OfClass(cls) => if !keepUnit && cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass) + else if cls.isDerivedValueClass then preEraseMonoValueClass(cls) else ClassRef(cls) case tpe: TypeRef => preErase(tpe.translucentSuperType, keepUnit) @@ -90,7 +85,10 @@ private[tastyquery] object Erasure: case Some(reduced) => preErase(reduced, keepUnit) case None => preErase(tpe.bound, keepUnit) case tpe: OrType => - erasedLub(preErase(tpe.first, keepUnit = false), preErase(tpe.second, keepUnit = false)) + erasedLub( + finishErase(preErase(tpe.first, keepUnit = false)), + finishErase(preErase(tpe.second, keepUnit = false)) + ) case tpe: AndType => summon[SourceLanguage] match case SourceLanguage.Java => @@ -120,16 +118,32 @@ private[tastyquery] object Erasure: throw IllegalArgumentException(s"Unexpected type in erasure: $tpe") end preErase - private def finishErase(typeRef: ErasedTypeRef)(using Context): ErasedTypeRef = + private def finishErase(typeRef: PreErasedTypeRef)(using Context, SourceLanguage): ErasedTypeRef = typeRef match - case ClassRef(cls) => - if cls.isDerivedValueClass then finishEraseValueClass(cls) - else cls.erasure - case ArrayTypeRef(ClassRef(cls), dimensions) => - ArrayTypeRef(cls.erasure, dimensions) + case ClassRef(cls) => cls.erasure + case ArrayTypeRef(ClassRef(cls), dimensions) => ArrayTypeRef(cls.erasure, dimensions) + case ErasedValueClass(_, underlying) => finishErase(underlying) end finishErase - private def finishEraseValueClass(cls: ClassSymbol)(using Context): ErasedTypeRef = + private def preEraseMonoValueClass(cls: ClassSymbol)(using Context, SourceLanguage): ErasedValueClass = + val ctor = cls.findNonOverloadedDecl(nme.Constructor) + + val underlying = ctor.declaredType match + case tpe: MethodType if tpe.paramNames.sizeIs == 1 => + tpe.paramTypes.head + case _ => + throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}") + + // The underlying of value classes are never value classes themselves (by language spec) + val erasedUnderlying = preErase(underlying, keepUnit = false).asInstanceOf[ErasedTypeRef] + + ErasedValueClass(cls, erasedUnderlying) + end preEraseMonoValueClass + + private def preErasePolyValueClass(cls: ClassSymbol, targs: List[TypeOrWildcard])( + using Context, + SourceLanguage + ): ErasedValueClass = val ctor = cls.findNonOverloadedDecl(nme.Constructor) def illegalConstructorType(): Nothing = @@ -137,12 +151,124 @@ private[tastyquery] object Erasure: def ctorParamType(tpe: TypeOrMethodic): Type = tpe match case tpe: MethodType if tpe.paramTypes.sizeIs == 1 => tpe.paramTypes.head - case tpe: MethodType => illegalConstructorType() - case tpe: PolyType => ctorParamType(tpe.resultType) - case tpe: Type => illegalConstructorType() + case _ => illegalConstructorType() + + val ctorPolyType = ctor.declaredType match + case tpe: PolyType => tpe + case _ => illegalConstructorType() + + val genericUnderlying = ctorParamType(ctorPolyType.resultType) + val specializedUnderlying = ctorParamType(ctorPolyType.instantiate(targs)) + + // The underlying of value classes are never value classes themselves (by language spec) + val erasedGenericUnderlying = preErase(genericUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef] + val erasedSpecializedUnderlying = preErase(specializedUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef] - erase(ctorParamType(ctor.declaredType), ctor.sourceLanguage) - end finishEraseValueClass + def isPrimitive(typeRef: ErasedTypeRef): Boolean = typeRef match + case ClassRef(cls) => cls.isPrimitiveValueClass + case _: ArrayTypeRef => false + + /* Ideally, we would just use `erasedSpecializedUnderlying` as the erasure of `tp`. + * However, there are two special cases for polymorphic value classes, which + * historically come from Scala 2: + * + * - Given `class Foo[A](x: A) extends AnyVal`, `Foo[X]` should erase like + * `X`, except if its a primitive in which case it erases to the boxed + * version of this primitive. + * - Given `class Bar[A](x: Array[A]) extends AnyVal`, `Bar[X]` will be + * erased like `Array[A]` as seen from its definition site, no matter + * the `X` (same if `A` is bounded). + */ + val erasedValueClass = + if isPrimitive(erasedSpecializedUnderlying) && !isPrimitive(erasedGenericUnderlying) then + ClassRef(erasedSpecializedUnderlying.asInstanceOf[ClassRef].cls.boxedClass) + else if genericUnderlying.baseType(defn.ArrayClass).isDefined then erasedGenericUnderlying + else erasedSpecializedUnderlying + + ErasedValueClass(cls, erasedValueClass) + end preErasePolyValueClass + + /** Is `Array[tp]` a generic Array that needs to be erased to `Object`? + * This is true if among the subtypes of `Array[tp]` there is either: + * - both a reference array type and a primitive array type + * (e.g. `Array[_ <: Int | String]`, `Array[_ <: Any]`) + * - or two different primitive array types (e.g. `Array[_ <: Int | Double]`) + * In both cases the erased lub of those array types on the JVM is `Object`. + * + * In addition, if `isScala2` is true, we mimic the Scala 2 erasure rules and + * also return true for element types upper-bounded by a non-reference type + * such as in `Array[_ <: Int]` or `Array[_ <: UniversalTrait]`. + */ + private def isGenericArrayElement(tp: TypeOrWildcard)(using Context, SourceLanguage): Boolean = + /** A symbol that represents the sort of JVM array that values of type `tp` can be stored in: + * - If we can always store such values in a reference array, return `j.l.Object`. + * - If we can always store them in a specific primitive array, return the corresponding primitive class. + * - Otherwise, return `None`. + */ + def arrayUpperBound(tp: Type): Option[ClassSymbol] = tp.dealias match + case TypeRef.OfClass(cls) => + def isScala2SpecialCase: Boolean = + summon[SourceLanguage] == SourceLanguage.Scala2 + && !cls.isNull + && !cls.isSubClass(defn.ObjectClass) + + // Only a few classes have both primitives and references as subclasses. + if cls.isAny || cls.isAnyVal || cls.isMatchable || cls.isSingleton || isScala2SpecialCase then None + else if cls.isPrimitiveValueClass then Some(cls) + else + // Derived value classes in arrays are always boxed, so they end up here as well + Some(defn.ObjectClass) + + case tp: TypeProxy => + arrayUpperBound(tp.translucentSuperType) + case tp: AndType => + arrayUpperBound(tp.first).orElse(arrayUpperBound(tp.second)) + case tp: OrType => + val firstBound = arrayUpperBound(tp.first) + val secondBound = arrayUpperBound(tp.first) + if firstBound == secondBound then firstBound + else None + case _: NothingType | _: AnyKindType | _: TypeLambda => + None + case tp: CustomTransientGroundType => + throw IllegalArgumentException(s"Unexpected transient type: $tp") + end arrayUpperBound + + /** Can one of the JVM Array type store all possible values of type `tp`? */ + def fitsInJVMArray(tp: Type): Boolean = arrayUpperBound(tp).isDefined + + tp match + case tp: WildcardTypeArg => + !fitsInJVMArray(tp.bounds.high) + + case tp: Type => + tp.dealias match + case tp: TypeRef => + tp.optSymbol match + case Some(cls: ClassSymbol) => + false + case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias => + isGenericArrayElement(tp.translucentSuperType) + case _ => + tp.bounds match + case TypeAlias(alias) => isGenericArrayElement(alias) + case AbstractTypeBounds(_, high) => !fitsInJVMArray(high) + case tp: TypeParamRef => + !fitsInJVMArray(tp) + case tp: MatchType => + val cases = tp.cases + cases.nonEmpty && !fitsInJVMArray(cases.map(_.result).reduce(OrType(_, _))) + case tp: TypeProxy => + isGenericArrayElement(tp.translucentSuperType) + case tp: AndType => + isGenericArrayElement(tp.first) && isGenericArrayElement(tp.second) + case tp: OrType => + isGenericArrayElement(tp.first) || isGenericArrayElement(tp.second) + case _: NothingType | _: AnyKindType | _: TypeLambda => + false + case tp: CustomTransientGroundType => + throw IllegalArgumentException(s"Unexpected transient type: $tp") + end isGenericArrayElement /** The erased least upper bound of two erased types is computed as follows. * @@ -224,7 +350,7 @@ private[tastyquery] object Erasure: * - Associativity and commutativity, because this method acts as the minimum * of the total order induced by `compareErasedGlb`. */ - private def erasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): ErasedTypeRef = + private def erasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): PreErasedTypeRef = if compareErasedGlb(tp1, tp2) <= 0 then tp1 else tp2 @@ -248,7 +374,7 @@ private[tastyquery] object Erasure: * * @see erasedGlb */ - private def compareErasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): Int = + private def compareErasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): Int = def compareClasses(cls1: ClassSymbol, cls2: ClassSymbol): Int = if cls1.isSubClass(cls2) then -1 else if cls2.isSubClass(cls1) then 1 @@ -260,13 +386,11 @@ private[tastyquery] object Erasure: // fast path 0 - case (ClassRef(cls1), _) if cls1.isDerivedValueClass => - tp2 match - case ClassRef(cls2) if cls2.isDerivedValueClass => - compareClasses(cls1, cls2) - case _ => - -1 - case (_, ClassRef(cls2)) if cls2.isDerivedValueClass => + case (ErasedValueClass(cls1, _), ErasedValueClass(cls2, _)) => + compareClasses(cls1, cls2) + case (ErasedValueClass(cls1, _), _) => + -1 + case (_, ErasedValueClass(cls2, _)) => 1 case (tp1: ArrayTypeRef, tp2: ArrayTypeRef) => diff --git a/tasty-query/shared/src/main/scala/tastyquery/Signatures.scala b/tasty-query/shared/src/main/scala/tastyquery/Signatures.scala index c574076b..f8521c90 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Signatures.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Signatures.scala @@ -4,6 +4,7 @@ import tastyquery.Contexts.* import tastyquery.Names.* import tastyquery.Symbols.* import tastyquery.Types.* +import tastyquery.Types.ErasedTypeRef.* object Signatures: enum ParamSig: @@ -27,6 +28,24 @@ object Signatures: end Signature object Signature { + private def sigName(tpe: Type, language: SourceLanguage, keepUnit: Boolean)(using Context): SignatureName = + toSigName(Erasure.eraseForSigName(tpe, language, keepUnit)) + + private[tastyquery] def toSigName(typeRef: ErasedTypeRef): SignatureName = typeRef match + case ClassRef(cls) => + cls.signatureName + + case ArrayTypeRef(base, dimensions) => + val suffix = "[]" * dimensions + val baseName = base.cls.signatureName + val suffixedLast = baseName.items.last match + case ObjectClassName(baseModuleName) => + baseModuleName.append(suffix).withObjectSuffix + case last: SimpleName => + last.append(suffix) + SignatureName(baseName.items.init :+ suffixedLast) + end toSigName + private[tastyquery] def fromType( info: TypeOrMethodic, language: SourceLanguage, @@ -35,14 +54,15 @@ object Signatures: def rec(info: TypeOrMethodic, acc: List[ParamSig]): Signature = info match { case info: MethodType => - val erased = info.paramTypes.map(tpe => ParamSig.Term(ErasedTypeRef.erase(tpe, language).toSigFullName)) - rec(info.resultType, acc ::: erased) + val paramSigs = info.paramTypes.map(tpe => ParamSig.Term(sigName(tpe, language, keepUnit = false))) + rec(info.resultType, acc ::: paramSigs) case info: PolyType => - rec(info.resultType, acc ::: ParamSig.TypeLen(info.paramTypeBounds.length) :: Nil) + val typeLenSig = ParamSig.TypeLen(info.paramTypeBounds.length) + rec(info.resultType, acc ::: typeLenSig :: Nil) case tpe: Type => val retType = optCtorReturn.map(_.appliedRefInsideThis).getOrElse(tpe) - val erasedRetType = ErasedTypeRef.erase(retType, language, keepUnit = true) - Signature(acc, erasedRetType.toSigFullName) + val resSig = sigName(retType, language, keepUnit = true) + Signature(acc, resSig) } rec(info, Nil) diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index fcdae60a..e82b5f03 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -1005,11 +1005,12 @@ object Symbols { specialKind == SpecialKind.None && isValueClass def isPrimitiveValueClass: Boolean = - specialKind == SpecialKind.Unit || specialKind == SpecialKind.NonUnitPrimitive + specialKind >= SpecialKind.FirstPrimitive && specialKind <= SpecialKind.LastPrimitive def isTupleNClass: Boolean = specialKind == SpecialKind.TupleN private[tastyquery] def isAny: Boolean = specialKind == SpecialKind.Any + private[tastyquery] def isMatchable: Boolean = specialKind == SpecialKind.Matchable private[tastyquery] def isObject: Boolean = specialKind == SpecialKind.Object private[tastyquery] def isAnyVal: Boolean = specialKind == SpecialKind.AnyVal private[tastyquery] def isUnit: Boolean = specialKind == SpecialKind.Unit @@ -1200,6 +1201,19 @@ object Symbols { ErasedTypeRef.ClassRef(this) end computeErasure + private[tastyquery] final def boxedClass(using Context): ClassSymbol = specialKind match + case SpecialKind.Unit => defn.ErasedBoxedUnitClass + case SpecialKind.Boolean => defn.BoxedBooleanClass + case SpecialKind.Char => defn.BoxedCharClass + case SpecialKind.Byte => defn.BoxedByteClass + case SpecialKind.Short => defn.BoxedShortClass + case SpecialKind.Int => defn.BoxedIntClass + case SpecialKind.Long => defn.BoxedLongClass + case SpecialKind.Float => defn.BoxedFloatClass + case SpecialKind.Double => defn.BoxedDoubleClass + case _ => this + end boxedClass + // DeclaringSymbol implementation private[Symbols] final def addDecl(decl: TermOrTypeSymbol): Unit = @@ -1671,21 +1685,31 @@ object Symbols { inline val Object = 3 inline val AnyVal = 4 inline val Unit = 5 - inline val NonUnitPrimitive = 6 - inline val String = 7 - inline val Null = 8 - inline val Singleton = 9 - inline val Array = 10 - inline val PolyFunction = 11 - inline val Tuple = 12 - inline val NonEmptyTuple = 13 - inline val TupleCons = 14 - inline val EmptyTuple = 15 - inline val FunctionN = 16 - inline val ContextFunctionN = 17 - inline val TupleN = 18 - inline val JavaEnum = 19 - inline val Refinement = 20 + inline val Boolean = 6 + inline val Char = 7 + inline val Byte = 8 + inline val Short = 9 + inline val Int = 10 + inline val Long = 11 + inline val Float = 12 + inline val Double = 13 + inline val String = 14 + inline val Null = 15 + inline val Singleton = 16 + inline val Array = 17 + inline val PolyFunction = 18 + inline val Tuple = 19 + inline val NonEmptyTuple = 20 + inline val TupleCons = 21 + inline val EmptyTuple = 22 + inline val FunctionN = 23 + inline val ContextFunctionN = 24 + inline val TupleN = 25 + inline val JavaEnum = 26 + inline val Refinement = 27 + + inline val FirstPrimitive = Unit + inline val LastPrimitive = Double end SpecialKind private def computeSpecialKind(name: ClassTypeName, owner: Symbol): SpecialKind = @@ -1700,14 +1724,14 @@ object Symbols { case tpnme.Matchable => SpecialKind.Matchable case tpnme.AnyVal => SpecialKind.AnyVal case tpnme.Unit => SpecialKind.Unit - case tpnme.Boolean => SpecialKind.NonUnitPrimitive - case tpnme.Char => SpecialKind.NonUnitPrimitive - case tpnme.Byte => SpecialKind.NonUnitPrimitive - case tpnme.Short => SpecialKind.NonUnitPrimitive - case tpnme.Int => SpecialKind.NonUnitPrimitive - case tpnme.Long => SpecialKind.NonUnitPrimitive - case tpnme.Float => SpecialKind.NonUnitPrimitive - case tpnme.Double => SpecialKind.NonUnitPrimitive + case tpnme.Boolean => SpecialKind.Boolean + case tpnme.Char => SpecialKind.Char + case tpnme.Byte => SpecialKind.Byte + case tpnme.Short => SpecialKind.Short + case tpnme.Int => SpecialKind.Int + case tpnme.Long => SpecialKind.Long + case tpnme.Float => SpecialKind.Float + case tpnme.Double => SpecialKind.Double case tpnme.Null => SpecialKind.Null case tpnme.Singleton => SpecialKind.Singleton case tpnme.Array => SpecialKind.Array diff --git a/tasty-query/shared/src/main/scala/tastyquery/Types.scala b/tasty-query/shared/src/main/scala/tastyquery/Types.scala index 237a751a..a9459a06 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Types.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Types.scala @@ -130,25 +130,16 @@ object Types { case ClassRef(cls) => cls.signatureName.toString() case ArrayTypeRef(base, dimensions) => base.toString() + "[]" * dimensions - def arrayOf(): ArrayTypeRef = this match - case classRef: ClassRef => ArrayTypeRef(classRef, 1) - case ArrayTypeRef(base, dimensions) => ArrayTypeRef(base, dimensions + 1) + def arrayOf(): ArrayTypeRef = multiArrayOf(dims = 1) + + private[tastyquery] def multiArrayOf(dims: Int): ArrayTypeRef = this match + case classRef: ClassRef => ArrayTypeRef(classRef, dims) + case ArrayTypeRef(base, dimensions) => ArrayTypeRef(base, dimensions + dims) /** The `SignatureName` for this `ErasedTypeRef` as found in the `TermSig`s of `Signature`s. */ - def toSigFullName: SignatureName = this match - case ClassRef(cls) => - cls.signatureName - - case ArrayTypeRef(base, dimensions) => - val suffix = "[]" * dimensions - val baseName = base.cls.signatureName - val suffixedLast = baseName.items.last match - case ObjectClassName(baseModuleName) => - baseModuleName.append(suffix).withObjectSuffix - case last: SimpleName => - last.append(suffix) - SignatureName(baseName.items.init :+ suffixedLast) - end toSigFullName + @deprecated("is is not a meaningful operation; it does not compute the right SignatureName's", since = "1.1.1") + def toSigFullName: SignatureName = + Signature.toSigName(this) end ErasedTypeRef object ErasedTypeRef: diff --git a/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala index 43f915ca..261e9a4e 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala @@ -282,13 +282,47 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite: testWithContext("value-class-monomorphic-arrayOf") { val MyFlags = ctx.findTopLevelModuleClass("inheritance.MyFlags") val mergeAll = MyFlags.findNonOverloadedDecl(name"mergeAll") - assertSigned(mergeAll, "(inheritance.MyFlags[]):scala.Long") + assertSigned(mergeAll, "(scala.Long[]):scala.Long") } testWithContext("value-class-polymorphic-arrayOf") { val MyArrayOps = ctx.findTopLevelModuleClass("inheritance.MyArrayOps") val arrayOfIntArrayOps = MyArrayOps.findNonOverloadedDecl(name"arrayOfIntArrayOps") - assertSigned(arrayOfIntArrayOps, "(scala.Int[][]):inheritance.MyArrayOps[]") + assertSigned(arrayOfIntArrayOps, "(scala.Int[][]):java.lang.Object[]") + } + + testWithContext("value-class-dependent") { + val ValueClassWithDependentErasureClass = ctx.findTopLevelModuleClass("simple_trees.ValueClassWithDependentErasure") + + val ofGeneric = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofGeneric")) + assertSigned(ofGeneric, "(1,java.lang.Object):java.lang.Object") + + val ofString = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofString")) + assertSigned(ofString, "(java.lang.String):java.lang.String") + + val ofInt = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofInt")) + assertSigned(ofInt, "(java.lang.Integer):scala.Int") + + val arrayOfGeneric = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("arrayOfGeneric")) + assertSigned(arrayOfGeneric, "(1,java.lang.Object[]):java.lang.Object") + + val arrayOfString = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("arrayOfString")) + assertSigned(arrayOfString, "(java.lang.Object[]):java.lang.String") + + val arrayOfInt = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("arrayOfInt")) + assertSigned(arrayOfInt, "(java.lang.Object[]):scala.Int") + + val ofGenericArray = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofGenericArray")) + assertSigned(ofGenericArray, "(1,java.lang.Object):java.lang.Object") + + val ofGenericSeqArray = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofGenericSeqArray")) + assertSigned(ofGenericSeqArray, "(1,scala.collection.immutable.Seq[]):scala.collection.immutable.Seq[]") + + val ofStringArray = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofStringArray")) + assertSigned(ofStringArray, "(java.lang.String[]):java.lang.String[]") + + val ofIntArray = ValueClassWithDependentErasureClass.findNonOverloadedDecl(termName("ofIntArray")) + assertSigned(ofIntArray, "(scala.Int[]):scala.Int[]") } testWithContext("package-ref-from-tasty") { diff --git a/test-sources/src/main/scala/inheritance/MyArrayOps.scala b/test-sources/src/main/scala/inheritance/MyArrayOps.scala index 7aff5034..64fb7003 100644 --- a/test-sources/src/main/scala/inheritance/MyArrayOps.scala +++ b/test-sources/src/main/scala/inheritance/MyArrayOps.scala @@ -7,3 +7,12 @@ object MyArrayOps: def genericArrayOps[T](xs: Array[T]): MyArrayOps[T] = new MyArrayOps(xs) def arrayOfIntArrayOps(xss: Array[Array[Int]]): Array[MyArrayOps[Int]] = xss.map(intArrayOps) +end MyArrayOps + +object MyArrayOpsTest: + def test(): Unit = + MyArrayOps.intArrayOps(Array(1)) + MyArrayOps.genericArrayOps(Array("foo")) + MyArrayOps.arrayOfIntArrayOps(Array(Array(1))) + end test +end MyArrayOpsTest diff --git a/test-sources/src/main/scala/inheritance/MyFlags.scala b/test-sources/src/main/scala/inheritance/MyFlags.scala index 66f1a3fa..0a69e347 100644 --- a/test-sources/src/main/scala/inheritance/MyFlags.scala +++ b/test-sources/src/main/scala/inheritance/MyFlags.scala @@ -7,3 +7,10 @@ object MyFlags: val Private: MyFlags = new MyFlags(1L << 0) def mergeAll(xs: Array[MyFlags]): MyFlags = xs.reduce(_.merge(_)) +end MyFlags + +object MyFlagsTest: + def test(): Unit = + MyFlags.mergeAll(Array(MyFlags.Private)) + end test +end MyFlagsTest diff --git a/test-sources/src/main/scala/simple_trees/ValueClassWithDependentErasure.scala b/test-sources/src/main/scala/simple_trees/ValueClassWithDependentErasure.scala new file mode 100644 index 00000000..08992aa6 --- /dev/null +++ b/test-sources/src/main/scala/simple_trees/ValueClassWithDependentErasure.scala @@ -0,0 +1,45 @@ +package simple_trees + +final class ValueClassWithDependentErasure[T](val value: T) extends AnyVal + +object ValueClassWithDependentErasure: + def ofGeneric[T](vc: ValueClassWithDependentErasure[T]): T = vc.value + + def ofString(vc: ValueClassWithDependentErasure[String]): String = vc.value + + def ofInt(vc: ValueClassWithDependentErasure[Int]): Int = vc.value + + def arrayOfGeneric[T](vcs: Array[ValueClassWithDependentErasure[T]]): T = vcs(0).value + + def arrayOfString(vcs: Array[ValueClassWithDependentErasure[String]]): String = vcs(0).value + + def arrayOfInt(vcs: Array[ValueClassWithDependentErasure[Int]]): Int = vcs(0).value + + def ofGenericArray[T](vc: ValueClassWithDependentErasure[Array[T]]): Array[T] = vc.value + + def ofGenericSeqArray[T](vc: ValueClassWithDependentErasure[Array[? <: Seq[T]]]): Array[? <: Seq[T]] = vc.value + + def ofStringArray(vc: ValueClassWithDependentErasure[Array[String]]): Array[String] = vc.value + + def ofIntArray(vc: ValueClassWithDependentErasure[Array[Int]]): Array[Int] = vc.value +end ValueClassWithDependentErasure + +object ValueClassWithDependentErasureTest: + def test(): Unit = + ValueClassWithDependentErasure.ofGeneric(new ValueClassWithDependentErasure(Some(5))) + ValueClassWithDependentErasure.ofGeneric(new ValueClassWithDependentErasure(5)) + ValueClassWithDependentErasure.ofString(new ValueClassWithDependentErasure("hello")) + ValueClassWithDependentErasure.ofInt(new ValueClassWithDependentErasure(5)) + + ValueClassWithDependentErasure.arrayOfGeneric(Array(new ValueClassWithDependentErasure(Some(5)))) + ValueClassWithDependentErasure.arrayOfGeneric(Array(new ValueClassWithDependentErasure(5))) + ValueClassWithDependentErasure.arrayOfString(Array(new ValueClassWithDependentErasure("hello"))) + ValueClassWithDependentErasure.arrayOfInt(Array(new ValueClassWithDependentErasure(5))) + + ValueClassWithDependentErasure.ofGenericArray(new ValueClassWithDependentErasure(Array(Some(5)))) + ValueClassWithDependentErasure.ofGenericArray(new ValueClassWithDependentErasure(Array(5))) + ValueClassWithDependentErasure.ofGenericSeqArray(new ValueClassWithDependentErasure(Array(List(3)))) + ValueClassWithDependentErasure.ofStringArray(new ValueClassWithDependentErasure(Array("hello"))) + ValueClassWithDependentErasure.ofIntArray(new ValueClassWithDependentErasure(Array(5))) + end test +end ValueClassWithDependentErasureTest