Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow exports in extension clauses #14497

Merged
merged 7 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 52 additions & 41 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,16 @@ object desugar {
vparam.withMods(mods & (GivenOrImplicit | Erased | hasDefault) | Param)
}

def mkApply(fn: Tree, paramss: List[ParamClause])(using Context): Tree =
paramss.foldLeft(fn) { (fn, params) => params match
case TypeDefs(params) =>
TypeApply(fn, params.map(refOfDef))
case (vparam: ValDef) :: _ if vparam.mods.is(Given) =>
Apply(fn, params.map(refOfDef)).setApplyKind(ApplyKind.Using)
case _ =>
Apply(fn, params.map(refOfDef))
}

/** The expansion of a class definition. See inline comments for what is involved */
def classDef(cdef: TypeDef)(using Context): Tree = {
val impl @ Template(constr0, _, self, _) = cdef.rhs
Expand Down Expand Up @@ -588,7 +598,7 @@ object desugar {
}

// new C[Ts](paramss)
lazy val creatorExpr = {
lazy val creatorExpr =
val vparamss = constrVparamss match
case (vparam :: _) :: _ if vparam.mods.is(Implicit) => // add a leading () to match class parameters
Nil :: constrVparamss
Expand All @@ -607,7 +617,6 @@ object desugar {
}
}
ensureApplied(nu)
}

val copiedAccessFlags = if migrateTo3 then EmptyFlags else AccessFlags

Expand Down Expand Up @@ -892,48 +901,50 @@ object desugar {
}
}

def extMethod(mdef: DefDef, extParamss: List[ParamClause])(using Context): DefDef =
cpy.DefDef(mdef)(
name = normalizeName(mdef, mdef.tpt).asTermName,
paramss =
if mdef.name.isRightAssocOperatorName then
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters

paramss match
case params :: paramss1 => // `params` must have a single parameter and without `given` flag

def badRightAssoc(problem: String) =
report.error(i"right-associative extension method $problem", mdef.srcPos)
extParamss ++ mdef.paramss

params match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
// we merge the extension parameters with the method parameters,
// swapping the operator arguments:
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will be encoded as
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
val (leadingUsing, otherExtParamss) = extParamss.span(isUsingOrTypeParamClause)
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
else
badRightAssoc("cannot start with using clause")
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
// no value parameters, so not an infix operator.
extParamss ++ mdef.paramss
else
extParamss ++ mdef.paramss
).withMods(mdef.mods | ExtensionMethod)

/** Transform extension construct to list of extension methods */
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
for mdef <- ext.methods yield
defDef(
cpy.DefDef(mdef)(
name = normalizeName(mdef, ext).asTermName,
paramss =
if mdef.name.isRightAssocOperatorName then
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters

paramss match
case params :: paramss1 => // `params` must have a single parameter and without `given` flag

def badRightAssoc(problem: String) =
report.error(i"right-associative extension method $problem", mdef.srcPos)
ext.paramss ++ mdef.paramss

params match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
// we merge the extension parameters with the method parameters,
// swapping the operator arguments:
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will be encoded as
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
val (leadingUsing, otherExtParamss) = ext.paramss.span(isUsingOrTypeParamClause)
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
else
badRightAssoc("cannot start with using clause")
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
// no value parameters, so not an infix operator.
ext.paramss ++ mdef.paramss
else
ext.paramss ++ mdef.paramss
).withMods(mdef.mods | ExtensionMethod)
)
ext.methods map {
case exp: Export => exp
case mdef: DefDef => defDef(extMethod(mdef, ext.paramss))
}
}

/** Transforms
*
* <mods> type t >: Low <: Hi
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
case class ExtMethods(paramss: List[ParamClause], methods: List[DefDef])(implicit @constructorOnly src: SourceFile) extends Tree
case class ExtMethods(paramss: List[ParamClause], methods: List[Tree])(implicit @constructorOnly src: SourceFile) extends Tree
case class MacroTree(expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree

case class ImportSelector(imported: Ident, renamed: Tree = EmptyTree, bound: Tree = EmptyTree)(implicit @constructorOnly src: SourceFile) extends Tree {
Expand Down Expand Up @@ -640,7 +640,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case tree: PatDef if (mods eq tree.mods) && (pats eq tree.pats) && (tpt eq tree.tpt) && (rhs eq tree.rhs) => tree
case _ => finalize(tree, untpd.PatDef(mods, pats, tpt, rhs)(tree.source))
}
def ExtMethods(tree: Tree)(paramss: List[ParamClause], methods: List[DefDef])(using Context): Tree = tree match
def ExtMethods(tree: Tree)(paramss: List[ParamClause], methods: List[Tree])(using Context): Tree = tree match
case tree: ExtMethods if (paramss eq tree.paramss) && (methods == tree.methods) => tree
case _ => finalize(tree, untpd.ExtMethods(paramss, methods)(tree.source))
def ImportSelector(tree: Tree)(imported: Ident, renamed: Tree, bound: Tree)(using Context): Tree = tree match {
Expand Down
45 changes: 28 additions & 17 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3123,7 +3123,7 @@ object Parsers {
/** Import ::= `import' ImportExpr {‘,’ ImportExpr}
* Export ::= `export' ImportExpr {‘,’ ImportExpr}
*/
def importClause(leading: Token, mkTree: ImportConstr): List[Tree] = {
def importOrExportClause(leading: Token, mkTree: ImportConstr): List[Tree] = {
val offset = accept(leading)
commaSeparated(importExpr(mkTree)) match {
case t :: rest =>
Expand All @@ -3136,6 +3136,12 @@ object Parsers {
}
}

def exportClause() =
importOrExportClause(EXPORT, Export(_,_))

def importClause(outermost: Boolean = false) =
importOrExportClause(IMPORT, mkImport(outermost))

/** Create an import node and handle source version imports */
def mkImport(outermost: Boolean = false): ImportConstr = (tree, selectors) =>
val imp = Import(tree, selectors)
Expand Down Expand Up @@ -3685,8 +3691,10 @@ object Parsers {
if in.isColon() then
syntaxError("no `:` expected here")
in.nextToken()
val methods =
if isDefIntro(modifierTokens) then
val methods: List[Tree] =
if in.token == EXPORT then
exportClause()
else if isDefIntro(modifierTokens) then
extMethod(nparams) :: Nil
else
in.observeIndented()
Expand All @@ -3696,12 +3704,13 @@ object Parsers {
val result = atSpan(start)(ExtMethods(joinParams(tparams, leadParamss.toList), methods))
val comment = in.getDocComment(start)
if comment.isDefined then
for meth <- methods do
for case meth: DefDef <- methods do
if !meth.rawComment.isDefined then meth.setComment(comment)
result
end extension

/** ExtMethod ::= {Annotation [nl]} {Modifier} ‘def’ DefDef
* | Export
*/
def extMethod(numLeadParams: Int): DefDef =
val start = in.offset
Expand All @@ -3711,16 +3720,18 @@ object Parsers {

/** ExtMethods ::= ExtMethod | [nl] ‘{’ ExtMethod {semi ExtMethod ‘}’
*/
def extMethods(numLeadParams: Int): List[DefDef] = checkNoEscapingPlaceholders {
val meths = new ListBuffer[DefDef]
def extMethods(numLeadParams: Int): List[Tree] = checkNoEscapingPlaceholders {
val meths = new ListBuffer[Tree]
while
val start = in.offset
val mods = defAnnotsMods(modifierTokens)
in.token != EOF && {
accept(DEF)
meths += defDefOrDcl(start, mods, numLeadParams)
in.token != EOF && statSepOrEnd(meths, what = "extension method")
}
if in.token == EXPORT then
meths ++= exportClause()
else
val mods = defAnnotsMods(modifierTokens)
if in.token != EOF then
accept(DEF)
meths += defDefOrDcl(start, mods, numLeadParams)
in.token != EOF && statSepOrEnd(meths, what = "extension method")
do ()
if meths.isEmpty then syntaxErrorOrIncomplete("`def` expected")
meths.toList
Expand Down Expand Up @@ -3868,9 +3879,9 @@ object Parsers {
else stats += packaging(start)
}
else if (in.token == IMPORT)
stats ++= importClause(IMPORT, mkImport(outermost))
stats ++= importClause(outermost)
else if (in.token == EXPORT)
stats ++= importClause(EXPORT, Export(_,_))
stats ++= exportClause()
else if isIdent(nme.extension) && followingIsExtension() then
stats += extension()
else if isDefIntro(modifierTokens) then
Expand Down Expand Up @@ -3916,9 +3927,9 @@ object Parsers {
while
var empty = false
if (in.token == IMPORT)
stats ++= importClause(IMPORT, mkImport())
stats ++= importClause()
else if (in.token == EXPORT)
stats ++= importClause(EXPORT, Export(_,_))
stats ++= exportClause()
else if isIdent(nme.extension) && followingIsExtension() then
stats += extension()
else if (isDefIntro(modifierTokensOrCase))
Expand Down Expand Up @@ -3994,7 +4005,7 @@ object Parsers {
while
var empty = false
if (in.token == IMPORT)
stats ++= importClause(IMPORT, mkImport())
stats ++= importClause()
else if (isExprIntro)
stats += expr(Location.InBlock)
else if in.token == IMPLICIT && !in.inModifierPosition() then
Expand Down
Loading