Skip to content

Commit

Permalink
Merge pull request #218 from sjrd/target-name
Browse files Browse the repository at this point in the history
Fix #128: Support @TargetNAME.
  • Loading branch information
bishabosha authored Dec 2, 2022
2 parents 392e2f0 + 44635cd commit af9a13b
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 22 deletions.
99 changes: 99 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Annotations.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,108 @@
package tastyquery

import scala.annotation.tailrec

import tastyquery.Constants.*
import tastyquery.Contexts.*
import tastyquery.Exceptions.*
import tastyquery.Names.*
import tastyquery.Symbols.*
import tastyquery.Trees.*
import tastyquery.Types.*

object Annotations:
final class Annotation(val tree: TermTree):
private var mySymbol: ClassSymbol | Null = null
private var myArguments: List[TermTree] | Null = null

/** The annotation class symbol. */
def symbol(using Context): ClassSymbol =
val local = mySymbol
if local != null then local
else
val computed = computeAnnotSymbol(tree)
mySymbol = computed
computed
end symbol

/** The symbol of the constructor used in the annotation. */
def annotConstructor(using Context): TermSymbol =
computeAnnotConstructor(tree)

/** All the term arguments to the annotation's constructor.
*
* If the constructor has several parameter lists, the arguments are
* flattened in a single list.
*
* `NamedArg`s are not visible with this method. They are replaced by
* their right-hand-side.
*/
def arguments(using Context): List[TermTree] =
val local = myArguments
if local != null then local
else
val computed = computeAnnotArguments(tree)
myArguments = computed
computed
end arguments

def argCount(using Context): Int = arguments.size

def argIfConstant(idx: Int)(using Context): Option[Constant] =
arguments(idx) match
case Literal(constant) => Some(constant)
case _ => None

override def toString(): String = s"Annotation($tree)"
end Annotation

private def computeAnnotSymbol(tree: TermTree)(using Context): ClassSymbol =
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation class in $tree")

@tailrec
def loop(tree: TermTree): ClassSymbol = tree match
case Apply(fun, _) => loop(fun)
case New(tpt) => tpt.toType.classSymbol.getOrElse(invalid())
case Select(qual, _) => loop(qual)
case TypeApply(fun, _) => loop(fun)
case Block(_, expr) => loop(expr)
case _ => invalid()

loop(tree)
end computeAnnotSymbol

private def computeAnnotConstructor(tree: TermTree)(using Context): TermSymbol =
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation constructor in $tree")

@tailrec
def loop(tree: TermTree): TermSymbol = tree match
case Apply(fun, _) => loop(fun)
case tree @ Select(New(tpt), _) => tree.tpe.asInstanceOf[TermRef].symbol
case TypeApply(fun, _) => loop(fun)
case Block(_, expr) => loop(expr)
case _ => invalid()

loop(tree)
end computeAnnotConstructor

private def computeAnnotArguments(tree: TermTree)(using Context): List[TermTree] =
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation arguments in $tree")

@tailrec
def loop(tree: TermTree, tail: List[TermTree]): List[TermTree] = tree match
case Apply(fun, args) => loop(fun, args ::: tail)
case Select(New(tpt), _) => tail
case TypeApply(fun, _) => loop(fun, tail)
case Block(_, expr) => loop(expr, tail)
case New(tpt) => tail // for some ancient TASTy with raw New's
case _ => invalid()

loop(tree, Nil).map {
case NamedArg(_, arg) => arg
case arg => arg
}
end computeAnnotArguments
end Annotations
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
private val javaPackage = RootPackage.getPackageDeclOrCreate(nme.javaPackageName)
val javaLangPackage = javaPackage.getPackageDeclOrCreate(nme.langPackageName)

private val scalaAnnotationPackage =
scalaPackage.getPackageDeclOrCreate(termName("annotation"))
private val scalaCollectionPackage =
scalaPackage.getPackageDeclOrCreate(termName("collection"))
private val scalaCollectionImmutablePackage =
Expand Down Expand Up @@ -199,6 +201,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS

extension (pkg: PackageSymbol)
private def requiredClass(name: String): ClassSymbol = pkg.getDecl(typeName(name)).get.asClass
private def optionalClass(name: String): Option[ClassSymbol] = pkg.getDecl(typeName(name)).map(_.asClass)

lazy val ObjectClass = javaLangPackage.requiredClass("Object")

Expand All @@ -219,6 +222,8 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS

lazy val StringClass = javaLangPackage.requiredClass("String")

private[tastyquery] lazy val targetNameAnnotClass = scalaAnnotationPackage.optionalClass("targetName")

def isPrimitiveValueClass(sym: ClassSymbol): Boolean =
sym == IntClass || sym == LongClass || sym == FloatClass || sym == DoubleClass ||
sym == BooleanClass || sym == ByteClass || sym == ShortClass || sym == CharClass ||
Expand Down
38 changes: 34 additions & 4 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ object Symbols {
case scope: ClassSymbol => scope.hasOverloads(name)
case _ => false

final def hasAnnotation(annotClass: ClassSymbol)(using Context): Boolean =
annotations.exists(_.symbol == annotClass)

final def getAnnotations(annotClass: ClassSymbol)(using Context): List[Annotation] =
annotations.filter(_.symbol == annotClass)

final def getAnnotation(annotClass: ClassSymbol)(using Context): Option[Annotation] =
annotations.find(_.symbol == annotClass)

override def toString: String = {
val kind = this match
case _: PackageSymbol => "package "
Expand Down Expand Up @@ -297,6 +306,7 @@ object Symbols {

// Cache fields
private var mySignature: Option[Signature] | Null = null
private var myTargetName: TermName | Null = null
private var myParamRefss: List[Either[List[TermParamRef], List[TypeParamRef]]] | Null = null

protected override def doCheckCompleted(): Unit =
Expand Down Expand Up @@ -347,11 +357,29 @@ object Symbols {
mySignature = sig
sig

private[tastyquery] final def targetName(using Context): TermName =
val local = myTargetName
if local != null then local
else
val computed = computeTargetName()
myTargetName = computed
computed
end targetName

private def computeTargetName()(using Context): TermName =
if annotations.isEmpty then name
else
defn.targetNameAnnotClass match
case None => name
case Some(targetNameAnnotClass) =>
getAnnotation(targetNameAnnotClass) match
case None => name
case Some(annot) => termName(annot.argIfConstant(0).get.stringValue)
end computeTargetName

/** If this symbol has a `MethodicType`, returns a `SignedName`, otherwise a `Name`. */
final def signedName(using Context): Name =
signature.fold(name) { sig =>
val name = this.name.asSimpleName
val targetName = name // TODO We may have to take `@targetName` into account here, one day
SignedName(name, sig, targetName)
}

Expand Down Expand Up @@ -670,8 +698,10 @@ object Symbols {
myDeclarations.get(overloaded.underlying) match
case Some(overloads) =>
overloads.find {
case decl: TermSymbol => decl.signature.exists(_ == overloaded.sig)
case _ => false
case decl: TermSymbol =>
decl.signature.exists(_ == overloaded.sig) && decl.targetName == overloaded.target
case _ =>
false
}
case None => None
end distinguishOverloaded
Expand Down
14 changes: 13 additions & 1 deletion tasty-query/shared/src/test/scala/tastyquery/ReadTreeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2025,6 +2025,16 @@ class ReadTreeSuite extends RestrictedUnpicklingSuite {
) =>
}

def deprecatedAnnotBothNamedCheck(msg: String, since: String): StructureCheck = {
case Apply(
SimpleAnnotCtorNamed("deprecated"),
List(
NamedArg(SimpleName("message"), Literal(Constant(`msg`))),
NamedArg(SimpleName("since"), Literal(Constant(`since`)))
)
) =>
}

def implicitNotFoundAnnotCheck(msg: String): StructureCheck = {
case Apply(SimpleAnnotCtorNamed("implicitNotFound"), List(Literal(Constant(`msg`)))) =>
}
Expand Down Expand Up @@ -2053,7 +2063,9 @@ class ReadTreeSuite extends RestrictedUnpicklingSuite {
sym
}
assert(clue(deprecatedValSym.annotations).sizeIs == 1)
assert(containsSubtree(deprecatedAnnotNamedCheck("reason", "forever"))(clue(deprecatedValSym.annotations(0).tree)))
assert(
containsSubtree(deprecatedAnnotBothNamedCheck("reason", "forever"))(clue(deprecatedValSym.annotations(0).tree))
)

val myTypeClassSym = findTree(tree) { case ClassDef(TypeName(SimpleName("MyTypeClass")), _, sym) =>
sym
Expand Down
13 changes: 13 additions & 0 deletions tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@ import TestUtils.*
class SignatureSuite extends UnrestrictedUnpicklingSuite:

def assertIsSignedName(actual: Name, simpleName: String, signature: String)(using Location): Unit =
assertIsSignedName(actual, simpleName, signature, simpleName)

def assertIsSignedName(actual: Name, simpleName: String, signature: String, targetName: String)(
using Location
): Unit =
actual match
case name: SignedName =>
assert(clue(name.underlying) == clue(termName(simpleName)))
assert(clue(name.sig.toString) == clue(signature))
assert(clue(name.target) == clue(termName(targetName)))
case _ =>
fail("not a Signed name", clues(actual))
end assertIsSignedName
Expand Down Expand Up @@ -58,6 +64,13 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
assertIsSignedName(identity.signedName, "identity", "(1,java.lang.Object):java.lang.Object")
}

testWithContext("targetName") {
val GenericMethod = ctx.findTopLevelClass("simple_trees.GenericMethod")

val identity = GenericMethod.findNonOverloadedDecl(name"otherIdentity")
assertIsSignedName(identity.signedName, "otherIdentity", "(1,java.lang.Object):java.lang.Object", "otherName")
}

testWithContext("JavaInnerClass") {
val TreeMap = ctx.findTopLevelClass("java.util.TreeMap")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import TestUtils.*
class SymbolSuite extends RestrictedUnpicklingSuite {

/** Needed for correct resolving of ctor signatures */
val fundamentalClasses: Seq[String] = Seq("java.lang.Object", "scala.Unit", "scala.AnyVal")
val fundamentalClasses: Seq[String] =
Seq("java.lang.Object", "scala.Unit", "scala.AnyVal", "scala.annotation.targetName")

def testWithContext(name: String, rootSymbolPath: String, extraRootSymbolPaths: String*)(using munit.Location)(
body: Context ?=> Unit
Expand Down
79 changes: 66 additions & 13 deletions tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tastyquery

import scala.collection.mutable

import tastyquery.Constants.*
import tastyquery.Contexts.*
import tastyquery.Flags.*
import tastyquery.Names.*
Expand Down Expand Up @@ -85,12 +86,11 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
assert(clue(parentClasses) == List(defn.ObjectClass, ProductClass, SerializableClass))
}

def applyOverloadedTest(name: String)(callMethod: String, paramCls: Context ?=> Symbol)(using munit.Location): Unit =
def applyOverloadedTest(name: String)(callMethod: String, checkParamType: Context ?=> Type => Boolean): Unit =
testWithContext(name) {
val OverloadedApplyClass = ctx.findTopLevelClass("simple_trees.OverloadedApply")

val callSym = OverloadedApplyClass.findDecl(termName(callMethod))
val Acls = paramCls

val Some(callTree @ _: DefDef) = callSym.tree: @unchecked

Expand All @@ -102,31 +102,44 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
callCount += 1
assert(app.tpe.isRef(defn.UnitClass), clue(app))
val fooSym = fooRef.tpe.asInstanceOf[TermRef].symbol
val List(Left(List(aRef)), _*) = fooSym.paramRefss: @unchecked
assert(aRef.isRef(Acls), clues(Acls.fullName, aRef))
val mt = fooSym.declaredType.asInstanceOf[MethodType]
assert(clue(mt.resultType).isRef(defn.UnitClass))
assert(checkParamType(clue(mt.paramTypes.head)))
case _ => ()
}

assert(callCount == 1)
}

applyOverloadedTest("apply-overloaded-int")("callA", defn.IntClass)
applyOverloadedTest("apply-overloaded-int")("callA", _.isRef(defn.IntClass))
applyOverloadedTest("apply-overloaded-gen")(
"callB",
ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box")
_.isApplied(
_.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box")),
List(_.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num")))
)
)
applyOverloadedTest("apply-overloaded-nestedObj")(
"callC",
ctx
.findTopLevelClass("simple_trees.OverloadedApply")
.findDecl(moduleClassName("Foo"))
.asClass
.findDecl(termName("Bar"))
_.isRef(
ctx
.findTopLevelClass("simple_trees.OverloadedApply")
.findDecl(moduleClassName("Foo"))
.asClass
.findDecl(termName("Bar"))
)
)
applyOverloadedTest("apply-overloaded-arrayObj")("callD", defn.ArrayClass)
applyOverloadedTest("apply-overloaded-arrayObj")("callD", _.isRef(defn.ArrayClass))
applyOverloadedTest("apply-overloaded-byName")(
"callE",
ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num")
_.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Num"))
)
applyOverloadedTest("apply-overloaded-gen-target-name")(
"callG",
_.isApplied(
_.isRef(ctx.findTopLevelClass("simple_trees.OverloadedApply").findDecl(tname"Box")),
List(_.isRef(defn.IntClass))
)
)

testWithContext("apply-overloaded-not-method") {
Expand Down Expand Up @@ -1530,4 +1543,44 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
assert(innerIdentSym.is(ParamAccessor))
}

testWithContext("annotations") {
val AnnotationsClass = ctx.findTopLevelClass("simple_trees.Annotations")
val inlineClass = ctx.findTopLevelClass("scala.inline")
val deprecatedClass = ctx.findTopLevelClass("scala.deprecated")

locally {
val inlineMethodSym = AnnotationsClass.findNonOverloadedDecl(termName("inlineMethod"))
val List(inlineAnnot) = inlineMethodSym.annotations
assert(clue(inlineAnnot.symbol) == inlineClass)
assert(clue(inlineAnnot.arguments).isEmpty)

assert(inlineMethodSym.hasAnnotation(inlineClass))
assert(!inlineMethodSym.hasAnnotation(deprecatedClass))

assert(inlineMethodSym.getAnnotations(inlineClass) == List(inlineAnnot))
assert(inlineMethodSym.getAnnotations(deprecatedClass) == Nil)

assert(inlineMethodSym.getAnnotation(inlineClass) == Some(inlineAnnot))
assert(inlineMethodSym.getAnnotation(deprecatedClass) == None)
}

locally {
val deprecatedValSym = AnnotationsClass.findNonOverloadedDecl(termName("deprecatedVal"))
val List(deprecatedAnnot) = deprecatedValSym.annotations

assert(clue(deprecatedAnnot.symbol) == deprecatedClass)
assert(clue(deprecatedAnnot.annotConstructor) == deprecatedClass.findNonOverloadedDecl(nme.Constructor))
assert(clue(deprecatedAnnot.argCount) == 2)

deprecatedAnnot.arguments match
case List(Literal(Constant("reason")), Literal(Constant("forever"))) =>
() // OK
case args =>
fail("unexpected arguments", clues(args))

assert(clue(deprecatedAnnot.argIfConstant(0)) == Some(Constant("reason")))
assert(clue(deprecatedAnnot.argIfConstant(1)) == Some(Constant("forever")))
}
}

}
Loading

0 comments on commit af9a13b

Please sign in to comment.