Skip to content

Commit

Permalink
Merge pull request #408 from sjrd/fix-parametric-value-class-erasure
Browse files Browse the repository at this point in the history
Fix #405: Completely overhaul erasure of value classes.
  • Loading branch information
sjrd authored Dec 1, 2023
2 parents b5c32dd + 912dabe commit 06ffbf0
Show file tree
Hide file tree
Showing 10 changed files with 375 additions and 108 deletions.
4 changes: 4 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ lazy val tastyQuery =
import com.typesafe.tools.mima.core.*
Seq(
// private, not an issue
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass"),
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass$"),
ProblemFilters.exclude[MissingClassProblem]("tastyquery.TypeOps$TypeFold"),
// private[tastyquery], not an issue
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Signatures#Signature.toSigName"),
// Everything in tastyquery.reader is private[tastyquery] at most
ProblemFilters.exclude[Problem]("tastyquery.reader.*"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,15 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
lazy val CharClass = scalaPackage.requiredClass("Char")
lazy val UnitClass = scalaPackage.requiredClass("Unit")

private[tastyquery] lazy val BoxedBooleanClass = javaLangPackage.requiredClass("Boolean")
private[tastyquery] lazy val BoxedCharClass = javaLangPackage.requiredClass("Character")
private[tastyquery] lazy val BoxedByteClass = javaLangPackage.requiredClass("Byte")
private[tastyquery] lazy val BoxedShortClass = javaLangPackage.requiredClass("Short")
private[tastyquery] lazy val BoxedIntClass = javaLangPackage.requiredClass("Integer")
private[tastyquery] lazy val BoxedLongClass = javaLangPackage.requiredClass("Long")
private[tastyquery] lazy val BoxedFloatClass = javaLangPackage.requiredClass("Float")
private[tastyquery] lazy val BoxedDoubleClass = javaLangPackage.requiredClass("Double")

lazy val StringClass = javaLangPackage.requiredClass("String")

lazy val ProductClass = scalaPackage.requiredClass("Product")
Expand Down
244 changes: 184 additions & 60 deletions tasty-query/shared/src/main/scala/tastyquery/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ import tastyquery.Types.*
import tastyquery.Types.ErasedTypeRef.*

private[tastyquery] object Erasure:
// TODO: improve this to match dotty:
// - use correct type erasure algorithm from Scala 3, with specialisations
// for Java types and Scala 2 types (i.e. varargs, value-classes)

@deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1")
def erase(tpe: Type)(using Context): ErasedTypeRef =
erase(tpe, SourceLanguage.Scala3)
Expand All @@ -27,44 +23,41 @@ private[tastyquery] object Erasure:
finishErase(preErase(tpe, keepUnit))
end erase

/** First pass of erasure, where some special types are preserved as is.
private[tastyquery] def eraseForSigName(tpe: Type, language: SourceLanguage, keepUnit: Boolean)(
using Context
): ErasedTypeRef =
given SourceLanguage = language

val patchedPreErased = preErase(tpe, keepUnit) match
case ArrayTypeRef(ClassRef(cls), dimensions) if cls.isDerivedValueClass =>
// Hack! dotc's `sigName` does *not* correspond to erasure in this case!
val patchedBase =
if cls.typeParams.isEmpty then preEraseMonoValueClass(cls)
else preErasePolyValueClass(cls, cls.typeParams.map(_.localRef))
patchedBase.underlying.multiArrayOf(dimensions)
case typeRef =>
typeRef

finishErase(patchedPreErased)
end eraseForSigName

private final case class ErasedValueClass(valueClass: ClassSymbol, underlying: ErasedTypeRef)

private type PreErasedTypeRef = ErasedTypeRef | ErasedValueClass

/** First pass of erasure, where some special types are preserved as is,
* and where value classes become `ErasedValueClass`es.
*
* In particular, `Any` is preserved as `Any`, instead of becoming
* `java.lang.Object`.
*/
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): ErasedTypeRef =
def arrayOfBounds(bounds: TypeBounds): ErasedTypeRef =
preErase(bounds.high, keepUnit = false) match
case ClassRef(cls) if cls.isAny || cls.isAnyVal =>
ClassRef(defn.ObjectClass)
case typeRef =>
typeRef.arrayOf()

def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef = tpe match
case tpe: AppliedType =>
tpe.tycon match
case TypeRef.OfClass(cls) =>
if cls.isArray then
val List(targ) = tpe.args: @unchecked
arrayOf(targ).arrayOf()
else ClassRef(cls).arrayOf()
case _ =>
arrayOf(tpe.translucentSuperType)
case TypeRef.OfClass(cls) =>
if cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass).arrayOf()
else ClassRef(cls).arrayOf()
case tpe: TypeRef =>
tpe.optSymbol match
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
arrayOf(tpe.translucentSuperType)
case _ =>
tpe.bounds match
case bounds: AbstractTypeBounds => arrayOfBounds(bounds)
case TypeAlias(alias) => arrayOf(alias)
case tpe: TypeParamRef => arrayOfBounds(tpe.bounds)
case tpe: Type => preErase(tpe, keepUnit = false).arrayOf()
case tpe: WildcardTypeArg => arrayOfBounds(tpe.bounds)
end arrayOf
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): PreErasedTypeRef =
def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef =
if isGenericArrayElement(tpe) then ClassRef(defn.ObjectClass)
else
preErase(tpe.highIfWildcard, keepUnit = false) match
case base: ErasedTypeRef => base.arrayOf()
case ErasedValueClass(valueClass, _) => ClassRef(valueClass).arrayOf()

tpe match
case tpe: AppliedType =>
Expand All @@ -73,11 +66,13 @@ private[tastyquery] object Erasure:
if cls.isArray then
val List(targ) = tpe.args: @unchecked
arrayOf(targ)
else if cls.isDerivedValueClass then preErasePolyValueClass(cls, tpe.args)
else ClassRef(cls)
case _ =>
preErase(tpe.translucentSuperType, keepUnit)
case TypeRef.OfClass(cls) =>
if !keepUnit && cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass)
else if cls.isDerivedValueClass then preEraseMonoValueClass(cls)
else ClassRef(cls)
case tpe: TypeRef =>
preErase(tpe.translucentSuperType, keepUnit)
Expand All @@ -90,7 +85,10 @@ private[tastyquery] object Erasure:
case Some(reduced) => preErase(reduced, keepUnit)
case None => preErase(tpe.bound, keepUnit)
case tpe: OrType =>
erasedLub(preErase(tpe.first, keepUnit = false), preErase(tpe.second, keepUnit = false))
erasedLub(
finishErase(preErase(tpe.first, keepUnit = false)),
finishErase(preErase(tpe.second, keepUnit = false))
)
case tpe: AndType =>
summon[SourceLanguage] match
case SourceLanguage.Java =>
Expand Down Expand Up @@ -120,29 +118,157 @@ private[tastyquery] object Erasure:
throw IllegalArgumentException(s"Unexpected type in erasure: $tpe")
end preErase

private def finishErase(typeRef: ErasedTypeRef)(using Context): ErasedTypeRef =
private def finishErase(typeRef: PreErasedTypeRef)(using Context, SourceLanguage): ErasedTypeRef =
typeRef match
case ClassRef(cls) =>
if cls.isDerivedValueClass then finishEraseValueClass(cls)
else cls.erasure
case ArrayTypeRef(ClassRef(cls), dimensions) =>
ArrayTypeRef(cls.erasure, dimensions)
case ClassRef(cls) => cls.erasure
case ArrayTypeRef(ClassRef(cls), dimensions) => ArrayTypeRef(cls.erasure, dimensions)
case ErasedValueClass(_, underlying) => finishErase(underlying)
end finishErase

private def finishEraseValueClass(cls: ClassSymbol)(using Context): ErasedTypeRef =
private def preEraseMonoValueClass(cls: ClassSymbol)(using Context, SourceLanguage): ErasedValueClass =
val ctor = cls.findNonOverloadedDecl(nme.Constructor)

val underlying = ctor.declaredType match
case tpe: MethodType if tpe.paramNames.sizeIs == 1 =>
tpe.paramTypes.head
case _ =>
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")

// The underlying of value classes are never value classes themselves (by language spec)
val erasedUnderlying = preErase(underlying, keepUnit = false).asInstanceOf[ErasedTypeRef]

ErasedValueClass(cls, erasedUnderlying)
end preEraseMonoValueClass

private def preErasePolyValueClass(cls: ClassSymbol, targs: List[TypeOrWildcard])(
using Context,
SourceLanguage
): ErasedValueClass =
val ctor = cls.findNonOverloadedDecl(nme.Constructor)

def illegalConstructorType(): Nothing =
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")

def ctorParamType(tpe: TypeOrMethodic): Type = tpe match
case tpe: MethodType if tpe.paramTypes.sizeIs == 1 => tpe.paramTypes.head
case tpe: MethodType => illegalConstructorType()
case tpe: PolyType => ctorParamType(tpe.resultType)
case tpe: Type => illegalConstructorType()
case _ => illegalConstructorType()

val ctorPolyType = ctor.declaredType match
case tpe: PolyType => tpe
case _ => illegalConstructorType()

val genericUnderlying = ctorParamType(ctorPolyType.resultType)
val specializedUnderlying = ctorParamType(ctorPolyType.instantiate(targs))

// The underlying of value classes are never value classes themselves (by language spec)
val erasedGenericUnderlying = preErase(genericUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]
val erasedSpecializedUnderlying = preErase(specializedUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]

erase(ctorParamType(ctor.declaredType), ctor.sourceLanguage)
end finishEraseValueClass
def isPrimitive(typeRef: ErasedTypeRef): Boolean = typeRef match
case ClassRef(cls) => cls.isPrimitiveValueClass
case _: ArrayTypeRef => false

/* Ideally, we would just use `erasedSpecializedUnderlying` as the erasure of `tp`.
* However, there are two special cases for polymorphic value classes, which
* historically come from Scala 2:
*
* - Given `class Foo[A](x: A) extends AnyVal`, `Foo[X]` should erase like
* `X`, except if its a primitive in which case it erases to the boxed
* version of this primitive.
* - Given `class Bar[A](x: Array[A]) extends AnyVal`, `Bar[X]` will be
* erased like `Array[A]` as seen from its definition site, no matter
* the `X` (same if `A` is bounded).
*/
val erasedValueClass =
if isPrimitive(erasedSpecializedUnderlying) && !isPrimitive(erasedGenericUnderlying) then
ClassRef(erasedSpecializedUnderlying.asInstanceOf[ClassRef].cls.boxedClass)
else if genericUnderlying.baseType(defn.ArrayClass).isDefined then erasedGenericUnderlying
else erasedSpecializedUnderlying

ErasedValueClass(cls, erasedValueClass)
end preErasePolyValueClass

/** Is `Array[tp]` a generic Array that needs to be erased to `Object`?
* This is true if among the subtypes of `Array[tp]` there is either:
* - both a reference array type and a primitive array type
* (e.g. `Array[_ <: Int | String]`, `Array[_ <: Any]`)
* - or two different primitive array types (e.g. `Array[_ <: Int | Double]`)
* In both cases the erased lub of those array types on the JVM is `Object`.
*
* In addition, if `isScala2` is true, we mimic the Scala 2 erasure rules and
* also return true for element types upper-bounded by a non-reference type
* such as in `Array[_ <: Int]` or `Array[_ <: UniversalTrait]`.
*/
private def isGenericArrayElement(tp: TypeOrWildcard)(using Context, SourceLanguage): Boolean =
/** A symbol that represents the sort of JVM array that values of type `tp` can be stored in:
* - If we can always store such values in a reference array, return `j.l.Object`.
* - If we can always store them in a specific primitive array, return the corresponding primitive class.
* - Otherwise, return `None`.
*/
def arrayUpperBound(tp: Type): Option[ClassSymbol] = tp.dealias match
case TypeRef.OfClass(cls) =>
def isScala2SpecialCase: Boolean =
summon[SourceLanguage] == SourceLanguage.Scala2
&& !cls.isNull
&& !cls.isSubClass(defn.ObjectClass)

// Only a few classes have both primitives and references as subclasses.
if cls.isAny || cls.isAnyVal || cls.isMatchable || cls.isSingleton || isScala2SpecialCase then None
else if cls.isPrimitiveValueClass then Some(cls)
else
// Derived value classes in arrays are always boxed, so they end up here as well
Some(defn.ObjectClass)

case tp: TypeProxy =>
arrayUpperBound(tp.translucentSuperType)
case tp: AndType =>
arrayUpperBound(tp.first).orElse(arrayUpperBound(tp.second))
case tp: OrType =>
val firstBound = arrayUpperBound(tp.first)
val secondBound = arrayUpperBound(tp.first)
if firstBound == secondBound then firstBound
else None
case _: NothingType | _: AnyKindType | _: TypeLambda =>
None
case tp: CustomTransientGroundType =>
throw IllegalArgumentException(s"Unexpected transient type: $tp")
end arrayUpperBound

/** Can one of the JVM Array type store all possible values of type `tp`? */
def fitsInJVMArray(tp: Type): Boolean = arrayUpperBound(tp).isDefined

tp match
case tp: WildcardTypeArg =>
!fitsInJVMArray(tp.bounds.high)

case tp: Type =>
tp.dealias match
case tp: TypeRef =>
tp.optSymbol match
case Some(cls: ClassSymbol) =>
false
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
isGenericArrayElement(tp.translucentSuperType)
case _ =>
tp.bounds match
case TypeAlias(alias) => isGenericArrayElement(alias)
case AbstractTypeBounds(_, high) => !fitsInJVMArray(high)
case tp: TypeParamRef =>
!fitsInJVMArray(tp)
case tp: MatchType =>
val cases = tp.cases
cases.nonEmpty && !fitsInJVMArray(cases.map(_.result).reduce(OrType(_, _)))
case tp: TypeProxy =>
isGenericArrayElement(tp.translucentSuperType)
case tp: AndType =>
isGenericArrayElement(tp.first) && isGenericArrayElement(tp.second)
case tp: OrType =>
isGenericArrayElement(tp.first) || isGenericArrayElement(tp.second)
case _: NothingType | _: AnyKindType | _: TypeLambda =>
false
case tp: CustomTransientGroundType =>
throw IllegalArgumentException(s"Unexpected transient type: $tp")
end isGenericArrayElement

/** The erased least upper bound of two erased types is computed as follows.
*
Expand Down Expand Up @@ -224,7 +350,7 @@ private[tastyquery] object Erasure:
* - Associativity and commutativity, because this method acts as the minimum
* of the total order induced by `compareErasedGlb`.
*/
private def erasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): ErasedTypeRef =
private def erasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): PreErasedTypeRef =
if compareErasedGlb(tp1, tp2) <= 0 then tp1
else tp2

Expand All @@ -248,7 +374,7 @@ private[tastyquery] object Erasure:
*
* @see erasedGlb
*/
private def compareErasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): Int =
private def compareErasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): Int =
def compareClasses(cls1: ClassSymbol, cls2: ClassSymbol): Int =
if cls1.isSubClass(cls2) then -1
else if cls2.isSubClass(cls1) then 1
Expand All @@ -260,13 +386,11 @@ private[tastyquery] object Erasure:
// fast path
0

case (ClassRef(cls1), _) if cls1.isDerivedValueClass =>
tp2 match
case ClassRef(cls2) if cls2.isDerivedValueClass =>
compareClasses(cls1, cls2)
case _ =>
-1
case (_, ClassRef(cls2)) if cls2.isDerivedValueClass =>
case (ErasedValueClass(cls1, _), ErasedValueClass(cls2, _)) =>
compareClasses(cls1, cls2)
case (ErasedValueClass(cls1, _), _) =>
-1
case (_, ErasedValueClass(cls2, _)) =>
1

case (tp1: ArrayTypeRef, tp2: ArrayTypeRef) =>
Expand Down
Loading

0 comments on commit 06ffbf0

Please sign in to comment.