Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let AsynchronousIo support any domains that support Lift[Unit, ?] #580

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 45 additions & 65 deletions Dsl/src/main/scala/com/thoughtworks/dsl/Dsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,43 +67,20 @@ private[dsl] trait LowPriorityDsl1 { this: Dsl.type =>

private[dsl] trait LowPriorityDsl0 extends LowPriorityDsl1 { this: Dsl.type =>

// // FIXME: Shift
// implicit def continuationDsl[Keyword, LeftDomain, RightDomain, Value](
// implicit restDsl: Dsl.Original[Keyword, LeftDomain, Value],
// shiftDsl2: Dsl.Original[Shift[LeftDomain, RightDomain], LeftDomain, RightDomain]
// ): Dsl.Original[Keyword, LeftDomain !! RightDomain, Value] = {
// new Dsl.Original[Keyword, LeftDomain !! RightDomain, Value] {
// def cpsApply(keyword: Keyword, handler: Value => LeftDomain !! RightDomain): LeftDomain !! RightDomain = {
// (continue: RightDomain => LeftDomain) =>
// restDsl.cpsApply(keyword, { a =>
// restDsl2.cpsApply(handler(a), continue)
// })
// }
// }
// }

private def throwableContinuationDsl[Keyword, LeftDomain, Value](implicit
restDsl: Dsl.Searching[Keyword, LeftDomain, Value]
): Dsl[Keyword, LeftDomain !! Throwable, Value] = Dsl {
(keyword, handler) => continue =>
(keyword, handler) => (continue: Throwable => LeftDomain) =>
restDsl(
keyword,
// Use `new` to support the `return`
new (Value => LeftDomain) {
def apply(value: Value): LeftDomain = {
val protectedContinuation =
try {
handler(value)
} catch {
case NonFatal(e) =>
return continue(e)
}
// FIXME: Shift[Domain, Throwable]
protectedContinuation(continue)
}
{ value =>
TrampolineContinuation { () =>
handler(value)
}(continue)
}
)
}

given [Keyword, LeftDomain, Value](using
Dsl.IsStackSafe[LeftDomain],
Dsl.Searching[Keyword, LeftDomain, Value]
Expand All @@ -125,6 +102,45 @@ object Dsl extends LowPriorityDsl0 {
) => Domain
) =:= Dsl[Keyword, Domain, Value] =
summon
private[dsl] abstract class TrampolineFunction1[-A, +R] extends (A => R) {
protected def step(): A => R
@tailrec
protected final def last(): A => R = {
step() match {
case trampoline: TrampolineFunction1[A, R] =>
trampoline.last()
case notTrampoline =>
notTrampoline
}
}

def apply(state: A): R = {
last()(state)
}

}
object TrampolineFunction1 {
def apply[A, R](trampoline: TrampolineFunction1[A, R]) = trampoline
}

private[dsl] abstract class TrampolineContinuation[LeftDomain]
extends TrampolineFunction1[Throwable => LeftDomain, LeftDomain] {

override final def apply(handler: Throwable => LeftDomain): LeftDomain = {
val protectedContinuation: LeftDomain !! Throwable =
try {
last()
} catch {
case NonFatal(e) =>
return handler(e)
}
protectedContinuation(handler)
}
}
private[dsl] object TrampolineContinuation {
def apply[LeftDomain](continuation: TrampolineContinuation[LeftDomain]) =
continuation
}

trait IsStackSafe[Domain]
object IsStackSafe extends IsStackSafe.LowPriority0:
Expand Down Expand Up @@ -359,42 +375,6 @@ object Dsl extends LowPriorityDsl0 {
): Dsl.Derived.StackUnsafe[Keyword, TailRec[Domain], Value] =
Dsl.Derived.StackUnsafe(derivedTailRecDsl)

private def derivedThrowableTailRecDsl[Keyword, LeftDomain, Value](implicit
restDsl: Dsl.Searching[Keyword, LeftDomain !! Throwable, Value]
): Dsl[Keyword, TailRec[LeftDomain] !! Throwable, Value] =
Dsl {
(
keyword: Keyword,
handler: (Value => TailRec[LeftDomain] !! Throwable)
) => (tailRecFailureHandler: Throwable => TailRec[LeftDomain]) =>
TailCalls.done(
restDsl(
keyword,
{ value => failureHandler =>
handler(value) { e =>
TailCalls.done(failureHandler(e))
}.result
}
) { e =>
tailRecFailureHandler(e).result
}
)
}
given [Keyword, LeftDomain, TailRecValue](using
Dsl.IsStackSafe[LeftDomain],
Dsl.Searching[Keyword, LeftDomain !! Throwable, TailRecValue]
): Dsl.Derived.StackSafe[Keyword, TailRec[
LeftDomain
] !! Throwable, TailRecValue] =
Dsl.Derived.StackSafe(derivedThrowableTailRecDsl)
given [Keyword, LeftDomain, TailRecValue](using
util.NotGiven[Dsl.IsStackSafe[LeftDomain]],
Dsl.Searching[Keyword, LeftDomain !! Throwable, TailRecValue]
): Dsl.Derived.StackUnsafe[Keyword, TailRec[
LeftDomain
] !! Throwable, TailRecValue] =
Dsl.Derived.StackUnsafe(derivedThrowableTailRecDsl)

private[dsl] type !![R, +A] = (A => R) => R

@FunctionalInterface
Expand Down
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ lazy val `keywords-AsynchronousIo` =
crossProject(JSPlatform, JVMPlatform)
.crossType(CrossType.Pure)
.dependsOn(
`keywords-Shift`,
`keywords-Fence`,
`keywords-Shift` % Test,
`keywords-Each` % Test,
`keywords-Using` % Test,
`domains-Task` % Test
Expand Down Expand Up @@ -151,6 +152,7 @@ lazy val `keywords-Await` =
.dependsOn(
Dsl,
`domains-Continuation`,
`domains-Fence`,
`macros-Reset` % Test,
`domains-Task` % Test,
`keywords-Get` % Test,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,29 +147,32 @@ object AsynchronousIo {
}
}

private def completionHandler[Value](successHandler: Value => (Unit !! Throwable)) = {
new CompletionHandler[Value, Throwable => Unit] {
def failed(exc: Throwable, failureHandler: Throwable => Unit): Unit = {
failureHandler(exc)
}
given [Domain, Value](using
fenceDsl: Dsl.Searching[Fence.type, Domain, Unit],
liftUnit: Dsl.Lift[Unit, Domain]
): Dsl.Original[AsynchronousIo[Value], Domain, Value] =
Dsl.Original { (keyword, handler) =>
liftUnit(
keyword.start(
handler,
new CompletionHandler[Value, Value => Domain] {
def failed(
exception: Throwable,
handler: Value => Domain
): Unit = {
fenceDsl(Fence, (_: Unit) => throw exception)
}

def completed(result: Value, failureHandler: Throwable => Unit): Unit = {
val protectedContinuation =
try {
successHandler(result)
} catch {
case NonFatal(e) =>
val () = failureHandler(e)
return
def completed(
result: Value,
handler: Value => Domain
): Unit = {
fenceDsl(Fence, (_: Unit) => handler(result))
}
}
protectedContinuation(failureHandler)
}
}
}
)
)

given [Value]: Dsl.Original[AsynchronousIo[Value], Unit !! Throwable, Value] =
Dsl.Original { (keyword, handler) =>
keyword.start(_, completionHandler(handler))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scala.concurrent.Await.result
import scala.concurrent.duration.Duration
import scala.concurrent.{ExecutionContext, Future}
import scala.language.implicitConversions
import scala.util.*

/** [[Await]] is a [[Dsl.Keyword Keyword]] to extract value from a
* [[scala.concurrent.Future]].
Expand Down Expand Up @@ -113,27 +114,41 @@ import scala.language.implicitConversions
*/
opaque type Await[+AwaitableValue] <: Dsl.Keyword.Opaque =
Dsl.Keyword.Opaque.Of[AwaitableValue]
object Await extends AwaitJS {
object Await extends AwaitJS with Await.LowPriority0 {
@inline def apply[AwaitableValue]: AwaitableValue =:= Await[AwaitableValue] =
Dsl.Keyword.Opaque.Of.apply
given [FutureResult]: IsKeyword[Await[Future[FutureResult]], FutureResult] with {}
given [FutureResult]: IsKeyword[Await[Future[FutureResult]], FutureResult]
with {}

given [FutureResult, That](using
ExecutionContext
): Dsl.Original[Await[Future[FutureResult]], Future[That], FutureResult] = Dsl.Original(
_ flatMap _)
): Dsl.Original[Await[Future[FutureResult]], Future[That], FutureResult] =
Dsl.Original(_ flatMap _)

// // TODO:
// implicit def tailRecContinuationAwaitDsl[Value](implicit
// executionContext: ExecutionContext
// ): Dsl.Original[Await[Value], TailRec[Unit] !! Throwable, Value]
private[Await] trait LowPriority0:
inline given [Domain, Value](using
fenceDsl: Dsl.Searching[Fence.type, Domain, Unit],
liftUnit: Dsl.Lift[Unit, Domain],
executionContext: ExecutionContext,
): Dsl.Composed[Await[Future[Value]], Domain, Value] = Dsl.Composed {
println(compiletime.codeOf(
{(x: Domain) => ()}
))

(keyword: Await[Future[Value]], handler: Value => Domain) =>
keyword.onComplete {
case Failure(e) =>
fenceDsl(Fence, (_: Unit) => throw e)
case Success(result) =>
fenceDsl(Fence, (_: Unit) => handler(result))
}
liftUnit(())
}

given [Value](using
ExecutionContext
): Dsl.Original[Await[Future[Value]], Unit !! Throwable, Value] = Dsl.Original {
(keyword: Await[Future[Value]], handler: Value => Unit !! Throwable) =>
!!.fromTryContinuation[Unit, Value](keyword.onComplete)(handler)
}
extension [FA, A](inline fa: FA)(using
inline notKeyword: util.NotGiven[
FA <:< Dsl.Keyword
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,37 +142,11 @@ object Shift extends LowPriorityShift0 {

}

private abstract class TrampolineContinuation[LeftDomain]
extends (LeftDomain !! Throwable) {
protected def step(): LeftDomain !! Throwable

@tailrec
private final def last(): LeftDomain !! Throwable = {
step() match {
case trampoline: TrampolineContinuation[LeftDomain] =>
trampoline.last()
case notTrampoline =>
notTrampoline
}
}

final def apply(handler: Throwable => LeftDomain): LeftDomain = {
val protectedContinuation: LeftDomain !! Throwable =
try {
last()
} catch {
case NonFatal(e) =>
return handler(e)
}
protectedContinuation(handler)
}
}

private def suspend[LeftDomain, Value](
continuation: LeftDomain !! Throwable !! Value,
handler: Value => LeftDomain !! Throwable
): TrampolineContinuation[LeftDomain] =
new TrampolineContinuation[LeftDomain] {
): Dsl.TrampolineContinuation[LeftDomain] =
new Dsl.TrampolineContinuation[LeftDomain] {
protected def step() = continuation(handler)
}

Expand All @@ -192,8 +166,8 @@ object Shift extends LowPriorityShift0 {
handler: Value => LeftDomain !! Throwable !! RightDomain,
value: Value,
continue: RightDomain => LeftDomain !! Throwable
): TrampolineContinuation[LeftDomain] =
new TrampolineContinuation[LeftDomain] {
): Dsl.TrampolineContinuation[LeftDomain] =
new Dsl.TrampolineContinuation[LeftDomain] {
protected def step() = {
handler(value)(continue)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.thoughtworks.dsl
package keywords
import Dsl.IsKeyword
import Dsl.!!
import Dsl.cpsApply
import scala.concurrent.Future
import scala.concurrent.ExecutionContext
Expand All @@ -17,19 +18,26 @@ object Suspend extends Suspend.LowPriority0 {
upstreamIsKeyword: => IsKeyword[Upstream, UpstreamValue]
): IsKeyword[Suspend[Upstream], UpstreamValue] with {}

private[Suspend] trait LowPriority0:
private[Suspend] trait LowPriority1:
given [Keyword, Domain, Value](using
Dsl.Searching[Keyword, Domain, Value]
): Dsl.Composed[Suspend[Keyword], Domain, Value] = Dsl.Composed {
(keyword: Suspend[Keyword], handler: Value => Domain) =>
keyword().cpsApply(handler)
}
private[Suspend] trait LowPriority0 extends LowPriority1:
given [Keyword, State, Domain, Value](using
Dsl.Searching[Keyword, State => Domain, Value]
): Dsl.Composed[Suspend[Keyword], State => Domain, Value] = Dsl.Composed {
(keyword: Suspend[Keyword], handler: Value => State => Domain) =>
Dsl.TrampolineFunction1(() => keyword().cpsApply(handler))
}

given [Keyword, State, Domain, Value](using
Dsl.Searching[Keyword, State => Domain, Value]
): Dsl.Composed[Suspend[Keyword], State => Domain, Value] = Dsl.Composed {
(keyword: Suspend[Keyword], handler: Value => State => Domain) => value =>
keyword().cpsApply(handler)(value)
given [Keyword, Domain, Value](using
Dsl.Searching[Keyword, Domain !! Throwable, Value]
): Dsl.Composed[Suspend[Keyword], Domain !! Throwable, Value] = Dsl.Composed {
(keyword: Suspend[Keyword], handler: Value => (Domain !! Throwable)) =>
Dsl.TrampolineContinuation(() => keyword().cpsApply(handler))
}

given [Keyword, Result, Value](using
Expand Down
Loading