Skip to content

Commit

Permalink
Minimize changes in futures
Browse files Browse the repository at this point in the history
  • Loading branch information
natsukagami committed Jul 12, 2024
1 parent 5830cb2 commit 9ce56ce
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 90 deletions.
2 changes: 1 addition & 1 deletion shared/src/main/scala/async/AsyncOperations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object AsyncOperations:
* [[java.util.concurrent.TimeoutException]] is thrown.
*/
def withTimeout[T](timeout: FiniteDuration)(op: Async ?=> T)(using AsyncOperations, Async): T =
Async.group:
Async.group: spawn ?=>
Async.select(
Future(op).handle(_.get),
Future(sleep(timeout)).handle: _ =>
Expand Down
175 changes: 86 additions & 89 deletions shared/src/main/scala/async/futures.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package gears.async

import language.experimental.captureChecking

import java.util.concurrent.CancellationException
import java.util.concurrent.atomic.AtomicBoolean
import scala.annotation.tailrec
import scala.annotation.unchecked.uncheckedCaptures
import scala.annotation.unchecked.uncheckedVariance
import scala.collection.mutable
import scala.compiletime.uninitialized
import scala.util
import scala.util.control.NonFatal
import scala.util.{Failure, Success, Try}

import language.experimental.captureChecking
import gears.async.Async.SourceSymbol

/** Futures are [[Async.Source Source]]s that has the following properties:
* - They represent a single value: Once resolved, [[Async.await await]]-ing on a [[Future]] should always return the
Expand Down Expand Up @@ -51,10 +51,11 @@ object Future:
* - withResolver: Completion is done by external request set up from a block of code.
*/
private class CoreFuture[+T] extends Future[T]:

@volatile protected var hasCompleted: Boolean = false
protected var cancelRequest = AtomicBoolean(false)
private var result: Try[T] = uninitialized // guaranteed to be set if hasCompleted = true
private val waiting = mutable.Set[(Listener[Try[T]]^) @uncheckedCaptures]()
private val waiting: mutable.Set[Listener[Try[T]]^] = mutable.Set()

// Async.Source method implementations

Expand Down Expand Up @@ -107,49 +108,11 @@ object Future:

end CoreFuture

private class CancelSuspension[U](val src: Async.Source[U]^)(val ac: Async, val suspension: ac.support.Suspension[Try[U], Unit]) extends Cancellable:
self: CancelSuspension[U]^{src, ac} =>
val listener: Listener[U]^{ac} = Listener.acceptingListener[U]: (x, _) =>
val completedBefore = complete()
if !completedBefore then
ac.support.resumeAsync(suspension)(Success(x))
unlink()
var completed = false

def complete() = synchronized:
val completedBefore = completed
completed = true
completedBefore

override def cancel() =
val completedBefore = complete()
if !completedBefore then
src.dropListener(listener)
ac.support.resumeAsync(suspension)(Failure(new CancellationException()))

private class FutureAsync(val group: CompletionGroup)(using ac: Async, label: ac.support.Label[Unit]) extends Async(using ac.support):
/** Await a source first by polling it, and, if that fails, by suspending in a onComplete call.
*/
override def await[U](src: Async.Source[U]^): U =
if group.isCancelled then throw new CancellationException()
src
.poll()
.getOrElse:
val res = ac.support.suspend[Try[U], Unit](k =>
val cancellable: CancelSuspension[U]^{src, ac} = CancelSuspension(src)(ac, k)
// val listener: Listener[U] = Listener.acceptingListener[U]: (x, _) => ???
// val completedBefore = cancellable.complete()
// if !completedBefore then ac.support.resumeAsync(k)(Success(x))
cancellable.link(group) // may resume + remove listener immediately
src.onComplete(cancellable.listener)
)(using label)
res.get

override def withGroup(group: CompletionGroup): Async = FutureAsync(group)

/** A future that is completed by evaluating `body` as a separate asynchronous operation in the given `scheduler`
*/
private class RunnableFuture[+T](body: Async.Spawn ?=> T)(using ac: Async) extends CoreFuture[T]:
private given acSupport: ac.support.type = ac.support
private given acScheduler: ac.support.Scheduler = ac.scheduler
/** RunnableFuture maintains its own inner [[CompletionGroup]], that is separated from the provided Async
* instance's. When the future is cancelled, we only cancel this CompletionGroup. This effectively means any
* `.await` operations within the future is cancelled *only if they link into this group*. The future body run with
Expand All @@ -160,6 +123,47 @@ object Future:
private def checkCancellation(): Unit =
if cancelRequest.get() then throw new CancellationException()

private class FutureAsync(val group: CompletionGroup)(using label: acSupport.Label[Unit])
extends Async(using acSupport, acScheduler):
/** Await a source first by polling it, and, if that fails, by suspending in a onComplete call.
*/
override def await[U](src: Async.Source[U]^): U =
class CancelSuspension extends Cancellable:
var suspension: acSupport.Suspension[Try[U], Unit] = uninitialized
var listener: Listener[U]^{this} = uninitialized
var completed = false

def complete() = synchronized:
val completedBefore = completed
completed = true
completedBefore

override def cancel() =
val completedBefore = complete()
if !completedBefore then
src.dropListener(listener)
acSupport.resumeAsync(suspension)(Failure(new CancellationException()))

if group.isCancelled then throw new CancellationException()

src
.poll()
.getOrElse:
val cancellable = CancelSuspension()
val res = acSupport.suspend[Try[U], Unit](k =>
val listener = Listener.acceptingListener[U]: (x, _) =>
val completedBefore = cancellable.complete()
if !completedBefore then acSupport.resumeAsync(k)(Success(x))
cancellable.suspension = k
cancellable.listener = listener
cancellable.link(group) // may resume + remove listener immediately
src.onComplete(listener)
)
cancellable.unlink()
res.get

override def withGroup(group: CompletionGroup): Async = FutureAsync(group)

override def cancel(): Unit = if setCancelled() then this.innerGroup.cancel()

link()
Expand All @@ -178,8 +182,9 @@ object Future:
/** Create a future that asynchronously executes `body` that wraps its execution in a [[scala.util.Try]]. The returned
* future is linked to the given [[Async.Spawn]] scope by default, i.e. it is cancelled when this scope ends.
*/
def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn)
(using async.type =:= spawnable.type): Future[T]^{body, spawnable} =
def apply[T](body: Async.Spawn ?=> T)(using async: Async, spawnable: Async.Spawn)(
using async.type =:= spawnable.type
): Future[T]^{body, spawnable} =
RunnableFuture(body)(using spawnable)

/** A future that is immediately completed with the given result. */
Expand All @@ -197,11 +202,11 @@ object Future:
/** A future that immediately rejects with the given exception. Similar to `Future.now(Failure(exception))`. */
inline def rejected(exception: Throwable): Future[Nothing] = now(Failure(exception))

extension [T](f1: Future[T]^)
extension [T](f1: Future[T])
/** Parallel composition of two futures. If both futures succeed, succeed with their values in a pair. Otherwise,
* fail with the failure that was returned first.
*/
def zip[U](f2: Future[U]^): Future[(T, U)]^{f1, f2} =
def zip[U](f2: Future[U]): Future[(T, U)] =
Future.withResolver: r =>
Async
.either(f1, f2)
Expand Down Expand Up @@ -234,20 +239,20 @@ object Future:
* @see
* [[orWithCancel]] for an alternative version where the slower future is cancelled.
*/
def or(f2: Future[T]^): Future[T]^{f1, f2} = orImpl(false)(f2)
def or(f2: Future[T]): Future[T] = orImpl(false)(f2)

/** Like `or` but the slower future is cancelled. If either task succeeds, succeed with the success that was
* returned first and the other is cancelled. Otherwise, fail with the failure that was returned last.
*/
def orWithCancel(f2: Future[T]^): Future[T]^{f1, f2} = orImpl(true)(f2)
def orWithCancel(f2: Future[T]): Future[T] = orImpl(true)(f2)

inline def orImpl(inline withCancel: Boolean)(f2: Future[T]^): Future[T]^{f1, f2} = Future.withResolver: r =>
inline def orImpl(inline withCancel: Boolean)(f2: Future[T]): Future[T] = Future.withResolver: r =>
Async
.raceWithOrigin(f1, f2)
.onComplete(Listener { case ((v, which), _) =>
v match
case Success(value) =>
inline if withCancel then (if which == f1.symbol then f2 else f1).cancel()
inline if withCancel then (if which == f1 then f2 else f1).cancel()
r.resolve(value)
case Failure(_) =>
(if which == f1.symbol then f2 else f1).onComplete(Listener((v, _) => r.complete(v)))
Expand Down Expand Up @@ -300,7 +305,7 @@ object Future:
* may be used. The handler should eventually complete the Future using one of complete/resolve/reject*. The
* default handler is set up to [[rejectAsCancelled]] immediately.
*/
def onCancel(handler: () -> Unit): Unit
def onCancel(handler: () => Unit): Unit
end Resolver

/** Create a promise that may be completed asynchronously using external means.
Expand All @@ -310,16 +315,16 @@ object Future:
*
* If the external operation supports cancellation, the body can register one handler using [[Resolver.onCancel]].
*/
def withResolver[T](body: Resolver[T]^ => Unit): Future[T] =
val future = new CoreFuture[T] with Resolver[T] with Promise[T] {
@volatile var cancelHandle: (() -> Unit) = () => rejectAsCancelled()
override def onCancel(handler: () -> Unit): Unit = cancelHandle = handler
def withResolver[T](body: Resolver[T] => Unit): Future[T] =
val future = new CoreFuture[T] with Resolver[T] with Promise[T]:
@volatile var cancelHandle: () -> Unit = () => rejectAsCancelled()
override def onCancel(handler: () => Unit): Unit = cancelHandle = caps.unsafe.unsafeAssumePure(handler)
override def complete(result: Try[T]): Unit = super.complete(result)

override def cancel(): Unit =
if setCancelled() then cancelHandle()
}
body(future)
end future
body(future: Resolver[T])
future
end withResolver

Expand All @@ -338,51 +343,46 @@ object Future:
* [[Future.awaitAll]] and [[Future.awaitFirst]] for simple usage of the collectors to get all results or the first
* succeeding one.
*/
class Collector[T](val futures: (Future[T]^)*):
class Collector[T](futures: (Future[T]^)*):
private val ch = UnboundedChannel[Future[T]^{futures*}]()

private val futureRefs = mutable.Map[Async.SourceSymbol[Try[T]], Future[T]^{futures*}]()
private val futMap = mutable.Map[SourceSymbol[Try[T]], Future[T]^{futures*}]()

/** Output channels of all finished futures. */
final def results: ReadableChannel[Future[T]^{futures*}] = ch.asReadable

private val listener = Listener((_, futRef) =>
private val listener = Listener((_, fut) =>
// safe, as we only attach this listener to Future[T]
val ref = futRef.asInstanceOf[Async.SourceSymbol[Try[T]]]
val fut = futureRefs.synchronized:
// futureRefs.remove(ref).get
futureRefs(ref)
ch.sendImmediately(futureRefs(fut.symbol))
val future = futMap.synchronized:
futMap.remove(fut.asInstanceOf[SourceSymbol[Try[T]]]).get
ch.sendImmediately(future)
)

protected final def addFuture(future: Future[T]^{futures*}) =
futureRefs.synchronized:
futureRefs += (future.symbol -> future)
futMap.synchronized { futMap += (future.symbol -> future) }
future.onComplete(listener)

futures.foreach(addFuture)
end Collector

/** Like [[Collector]], but exposes the ability to add futures after creation. */
class MutableCollector[T](futures: (Future[T]^)*) extends Collector[T](futures*):
class MutableCollector[T](futures: Future[T]*) extends Collector[T](futures*):
/** Add a new [[Future]] into the collector. */
def add(future: Future[T]^{futures*}): Unit = addFuture(future)
def +=(future: Future[T]^{futures*}) = add(future)
inline def add(future: Future[T]^) = addFuture(future)
inline def +=(future: Future[T]^) = add(future)

extension [T](@caps.unbox fs: Seq[Future[T]^])
extension [T](fs: Seq[Future[T]])
/** `.await` for all futures in the sequence, returns the results in a sequence, or throws if any futures fail. */
def awaitAll(using Async) =
val collector = Collector(fs*)
for _ <- fs do
val fut: Future[T]^{fs*} = collector.results.read().right.get
fut.await
for _ <- fs do collector.results.read().right.get.await
fs.map(_.await)

/** Like [[awaitAll]], but cancels all futures as soon as one of them fails. */
def awaitAllOrCancel(using Async) =
val collector = Collector[T](fs*)
val collector = Collector(fs*)
try
for _ <- fs do ??? // collector.results.read().right.get.await
for _ <- fs do collector.results.read().right.get.await
fs.map(_.await)
catch
case NonFatal(e) =>
Expand All @@ -391,22 +391,20 @@ object Future:

/** Race all futures, returning the first successful value. Throws the last exception received, if everything fails.
*/
def awaitFirst(using Async): T = impl.awaitFirstImpl[T](fs, false)
def awaitFirst(using Async): T = awaitFirstImpl(false)

/** Like [[awaitFirst]], but cancels all other futures as soon as the first future succeeds. */
def awaitFirstWithCancel(using Async): T = impl.awaitFirstImpl[T](fs, true)
def awaitFirstWithCancel(using Async): T = awaitFirstImpl(true)

private object impl:
def awaitFirstImpl[T](@caps.unbox fs: Seq[Future[T]^], withCancel: Boolean)(using Async): T =
val collector = Collector[T](fs*)
private inline def awaitFirstImpl(withCancel: Boolean)(using Async): T =
val collector = Collector(fs*)
@scala.annotation.tailrec
def loop(attempt: Int): T =
val fut: Future[T]^{fs*} = collector.results.read().right.get
fut.awaitResult match
collector.results.read().right.get.awaitResult match
case Failure(exception) =>
if attempt == fs.length then /* everything failed */ throw exception else loop(attempt + 1)
case Success(value) =>
if withCancel then fs.foreach(_.cancel())
inline if withCancel then fs.foreach(_.cancel())
value
loop(1)
end Future
Expand All @@ -432,11 +430,10 @@ class Task[+T](val body: (Async, AsyncOperations) ?=> T):
def run()(using Async, AsyncOperations): T = body

/** Start a future computed from the `body` of this task */
def start()(using async: Async, spawn: Async.Spawn, asyncOps: AsyncOperations)
(using async.type =:= spawn.type): Future[T]^{this, spawn} =
def start()(using async: Async, spawn: Async.Spawn)(using asyncOps: AsyncOperations)(using async.type =:= spawn.type): Future[T]^{body, spawn} =
Future(body)(using async, spawn)

def schedule(s: TaskSchedule): Task[T]^{this} =
def schedule(s: TaskSchedule): Task[T]^{body} =
s match {
case TaskSchedule.Every(millis, maxRepetitions) =>
assert(millis >= 1)
Expand Down

0 comments on commit 9ce56ce

Please sign in to comment.