Skip to content

Commit

Permalink
Have explicit boundary/suspend capture sets
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukagami committed Aug 27, 2024
1 parent 80d09d3 commit ba5a2e7
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 26 deletions.
29 changes: 17 additions & 12 deletions jvm/src/main/scala/async/VThreadSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ object VThreadScheduler extends Scheduler:
.name("gears.async.VThread-", 0L)
.factory()

override def execute(body: Runnable^): Unit =
override def execute(body: Runnable): Unit =
val th = VTFactory.newThread(body)
th.start()
()

override def schedule(delay: FiniteDuration, body: Runnable^): Cancellable =
private[gears] inline def unsafeExecute(body: Runnable^): Unit = execute(caps.unsafe.unsafeAssumePure(body))

override def schedule(delay: FiniteDuration, body: Runnable): Cancellable =
import caps.unsafe.unsafeAssumePure

val sr = ScheduledRunnable(delay, body)
// SAFETY: should not be able to access body, only for cancellation
sr.unsafeAssumePure: Cancellable

private final class ScheduledRunnable(delay: FiniteDuration, body: Runnable^) extends Cancellable:
private final class ScheduledRunnable(delay: FiniteDuration, body: Runnable) extends Cancellable:
@volatile var interruptGuard = true // to avoid interrupting the body

val th = VTFactory.newThread: () =>
Expand All @@ -50,7 +52,7 @@ object VThreadScheduler extends Scheduler:
object VThreadSupport extends AsyncSupport:
type Scheduler = VThreadScheduler.type

private final class VThreadLabel[R]():
private final class VThreadLabel[R]() extends caps.Capability:
private var result: Option[R] = None
private val lock = ReentrantLock()
private val cond = lock.newCondition()
Expand All @@ -74,11 +76,11 @@ object VThreadSupport extends AsyncSupport:
result.get
finally lock.unlock()

override opaque type Label[R] = VThreadLabel[R]
override opaque type Label[R, Cap^] <: caps.Capability = VThreadLabel[R]

// outside boundary: waiting on label
// inside boundary: waiting on suspension
private final class VThreadSuspension[-T, +R](using private[VThreadSupport] val l: Label[R] @uncheckedVariance)
private final class VThreadSuspension[-T, +R](using private[VThreadSupport] val l: VThreadLabel[R] @uncheckedVariance)
extends gears.async.Suspension[T, R]:
private var nextInput: Option[T] = None
private val lock = ReentrantLock()
Expand Down Expand Up @@ -107,9 +109,9 @@ object VThreadSupport extends AsyncSupport:

override opaque type Suspension[-T, +R] <: gears.async.Suspension[T, R] = VThreadSuspension[T, R]

override def boundary[R](body: (Label[R]) ?=> R): R =
override def boundary[R, Cap^](body: Label[R, Cap]^ ?->{Cap^} R): R =
val label = VThreadLabel[R]()
VThreadScheduler.execute: () =>
VThreadScheduler.unsafeExecute: () =>
val result = body(using label)
label.setResult(result)

Expand All @@ -119,13 +121,16 @@ object VThreadSupport extends AsyncSupport:
suspension.l.clearResult()
suspension.setInput(arg)

override def scheduleBoundary(body: (Label[Unit]) ?=> Unit)(using Scheduler): Unit =
override def scheduleBoundary[Cap^](body: Label[Unit, Cap] ?-> Unit)(using Scheduler): Unit =
VThreadScheduler.execute: () =>
val label = VThreadLabel[Unit]()
body(using label)

override def suspend[T, R](body: Suspension[T, R] => R)(using l: Label[R]): T =
val sus = new VThreadSuspension[T, R]()
override def suspend[T, R, Cap^](body: Suspension[T, R]^{Cap^} => R^{Cap^})(using l: Label[R, Cap]^): T =
val sus = new VThreadSuspension[T, R](using caps.unsafe.unsafeAssumePure(l))
val res = body(sus)
l.setResult(res)
l.setResult(
// SAFETY: will only be stored and returned by the Suspension resumption mechanism
caps.unsafe.unsafeAssumePure(res)
)
sus.waitInput()
12 changes: 6 additions & 6 deletions shared/src/main/scala/async/AsyncSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ trait Suspension[-T, +R]:
/** Support for suspension capabilities through a delimited continuation interface. */
trait SuspendSupport:
/** A marker for the "limit" of "delimited continuation". */
type Label[R]
type Label[R, Cap^] <: caps.Capability

/** The provided suspension type. */
type Suspension[-T, +R] <: gears.async.Suspension[T, R]

/** Set the suspension marker as the body's caller, and execute `body`. */
def boundary[R](body: Label[R] ?=> R): R
def boundary[R, Cap^](body: Label[R, Cap] ?->{Cap^} R): R^{Cap^}

/** Should return immediately if resume is called from within body */
def suspend[T, R](body: Suspension[T, R] => R)(using Label[R]): T
def suspend[T, R, Cap^](body: Suspension[T, R]^{Cap^} => R^{Cap^})(using Label[R, Cap]): T

/** Extends [[SuspendSupport]] with "asynchronous" boundary/resume functions, in the presence of a [[Scheduler]] */
trait AsyncSupport extends SuspendSupport:
Expand All @@ -34,13 +34,13 @@ trait AsyncSupport extends SuspendSupport:
s.execute(() => suspension.resume(arg))

/** Schedule a computation with the suspension boundary already created. */
private[async] def scheduleBoundary(body: Label[Unit] ?=> Unit)(using s: Scheduler): Unit =
private[async] def scheduleBoundary[Cap^](body: Label[Unit, Cap] ?-> Unit)(using s: Scheduler): Unit =
s.execute(() => boundary(body))

/** A scheduler implementation, with the ability to execute a computation immediately or after a delay. */
trait Scheduler:
def execute(body: Runnable^): Unit
def schedule(delay: FiniteDuration, body: Runnable^): Cancellable
def execute(body: Runnable): Unit
def schedule(delay: FiniteDuration, body: Runnable): Cancellable

object AsyncSupport:
inline def apply()(using ac: AsyncSupport) = ac
24 changes: 16 additions & 8 deletions shared/src/main/scala/async/futures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ object Future:

/** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler`
*/
private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]:
private class RunnableFuture[+T](body: Async.Spawn ?-> T)(using ac: Async) extends CoreFuture[T]:
private given acSupport: ac.support.type = ac.support
private given acScheduler: ac.support.Scheduler = ac.scheduler
/** RunnableFuture maintains its own inner [[CompletionGroup]], that is separated from the provided Async
Expand All @@ -123,14 +123,14 @@ object Future:
private def checkCancellation(): Unit =
if cancelRequest.get() then throw new CancellationException()

private class FutureAsync(val group: CompletionGroup)(using label: acSupport.Label[Unit])
private class FutureAsync[Cap^](val group: CompletionGroup)(using label: acSupport.Label[Unit, Cap])
extends Async(using acSupport, acScheduler):
/** Await a source first by polling it, and, if that fails, by suspending in a onComplete call.
*/
override def await[U](src: Async.Source[U]^): U =
class CancelSuspension extends Cancellable:
var suspension: acSupport.Suspension[Try[U], Unit] = uninitialized
var listener: Listener[U]^{this} = uninitialized
var suspension: acSupport.Suspension[Try[U], Unit]^{Cap^} = uninitialized
var listener: Listener[U]^{this, Cap^} = uninitialized
var completed = false

def complete() = synchronized:
Expand All @@ -142,18 +142,22 @@ object Future:
val completedBefore = complete()
if !completedBefore then
src.dropListener(listener)
acSupport.resumeAsync(suspension)(Failure(new CancellationException()))
// SAFETY: we always await for this suspension to end
val pureSusp = caps.unsafe.unsafeAssumePure(suspension)
acSupport.resumeAsync(pureSusp)(Failure(new CancellationException()))

if group.isCancelled then throw new CancellationException()

src
.poll()
.getOrElse:
val cancellable = CancelSuspension()
val res = acSupport.suspend[Try[U], Unit](k =>
val res = acSupport.suspend[Try[U], Unit, Cap](k =>
val listener = Listener.acceptingListener[U]: (x, _) =>
val completedBefore = cancellable.complete()
if !completedBefore then acSupport.resumeAsync(k)(Success(x))
// SAFETY: Future should already capture Cap^
val purek = caps.unsafe.unsafeAssumePure(k)
if !completedBefore then acSupport.resumeAsync(purek)(Success(x))
cancellable.suspension = k
cancellable.listener = listener
cancellable.link(group) // may resume + remove listener immediately
Expand All @@ -179,13 +183,17 @@ object Future:

end RunnableFuture


/** Create a future that asynchronously executes `body` that wraps its execution in a [[scala.util.Try]]. The returned
* future is linked to the given [[Async.Spawn]] scope by default, i.e. it is cancelled when this scope ends.
*/
def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn)(
using async.type =:= spawnable.type
): Future[T]^{body, spawnable} =
RunnableFuture(body)(using spawnable)
val f = (async: Async.Spawn) => body(using async)
val puref = caps.unsafe.unsafeAssumePure(f)
// SAFETY: body is recorded in the capture set of Future, which should be cancelled when gone out of scope.
RunnableFuture(async ?=> puref(async))(using spawnable)

/** A future that is immediately completed with the given result. */
def now[T](result: Try[T]): Future[T] =
Expand Down

0 comments on commit ba5a2e7

Please sign in to comment.