Skip to content

Commit

Permalink
improvement: cancel current request in batched functions (#5432)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek authored Oct 2, 2023
1 parent dce10fd commit 60a1d16
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package scala.meta.internal.metals

import java.util.concurrent.CancellationException
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.util.Failure
import scala.util.Success
import scala.util.Try
import scala.util.control.NonFatal

import scala.meta.internal.async.ConcurrentQueue
Expand All @@ -22,6 +25,7 @@ final class BatchedFunction[A, B](
fn: Seq[A] => CancelableFuture[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit ec: ExecutionContext)
extends (Seq[A] => Future[B])
with Function2[Seq[A], () => Unit, Future[B]]
Expand Down Expand Up @@ -75,8 +79,17 @@ final class BatchedFunction[A, B](
}

def cancelAll(): Unit = {
queue.clear()
unlock()
val requests = ConcurrentQueue.pollAll(queue)
requests.foreach(_.result.complete(defaultResult))
cancelCurrent()
}

def cancelCurrent(): Unit = {
lock.get() match {
case None =>
case Some(promise) =>
promise.tryFailure(new BatchedFunction.BatchedFunctionCancelation)
}
}

def currentFuture(): Future[B] = {
Expand All @@ -97,22 +110,28 @@ final class BatchedFunction[A, B](
callback: () => Unit,
)

private val lock = new AtomicBoolean()
private val lock = new AtomicReference[Option[Promise[B]]](None)

private def unlock(): Unit = {
lock.set(false)
lock.set(None)
if (!queue.isEmpty) {
runAcquire()
}
}
private def runAcquire(): Unit = {
if (!isPaused.get() && lock.compareAndSet(false, true)) {
runRelease()
lazy val promise = {
val p = Promise[B]
p.future.onComplete { _ => unlock() }
p
}
if (!isPaused.get() && lock.compareAndSet(None, Some(promise))) {
runRelease(promise)
} else {
// Do nothing, the submitted arguments will be handled
// by a separate request.
}
}
private def runRelease(): Unit = {
private def runRelease(p: Promise[B]): Unit = {
// Pre-condition: lock is acquired.
// Pos-condition:
// - lock is released
Expand All @@ -128,37 +147,45 @@ final class BatchedFunction[A, B](
this.current.set(result)
val resultF = for {
result <- result.future
_ <- Future {
callbacks.foreach(cb => cb())
}
_ <- Future { callbacks.foreach(cb => cb()) }
} yield result
resultF.onComplete { response =>
unlock()
requests.foreach(_.result.complete(response))
resultF.onComplete(p.tryComplete)
p.future.onComplete {
case Failure(_: BatchedFunction.BatchedFunctionCancelation) =>
result.cancel()
requests.foreach(_.result.complete(defaultResult))
case result =>
requests.foreach(_.result.complete(result))
}
} else {
unlock()
p.tryFailure(new BatchedFunction.BatchedFunctionCancelation)
}
} catch {
case NonFatal(e) =>
unlock()
requests.foreach(_.result.failure(e))
requests.foreach(_.result.tryFailure(e))
scribe.error(s"Unexpected error releasing buffered job", e)
}
}

def defaultResult: Try[B] =
default.map(Success(_)).getOrElse(Failure(new CancellationException))
}

object BatchedFunction {
def fromFuture[A, B](
fn: Seq[A] => Future[B],
functionId: String,
shouldLogQueue: Boolean = false,
default: Option[B] = None,
)(implicit
ec: ExecutionContext
): BatchedFunction[A, B] =
new BatchedFunction(
fn.andThen(CancelableFuture(_)),
functionId,
shouldLogQueue,
default,
)
class BatchedFunctionCancelation extends RuntimeException
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import scala.concurrent.ExecutionContextExecutorService
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.reflect.ClassTag
import scala.util.Success
import scala.util.Try

import scala.meta.internal.builds.MillBuildTool
Expand Down Expand Up @@ -74,7 +75,6 @@ class BuildServerConnection private (

private val ongoingRequests =
new MutableCancelable().addAll(initialConnection.cancelables)
private val ongoingCompilations = new MutableCancelable()

def version: String = _version.get()

Expand Down Expand Up @@ -155,7 +155,6 @@ class BuildServerConnection private (
def compile(params: CompileParams): CompletableFuture[CompileResult] = {
register(
server => server.buildTargetCompile(params),
isCompile = true,
onFail = Some(
(
new CompileResult(StatusCode.CANCELLED),
Expand Down Expand Up @@ -300,14 +299,9 @@ class BuildServerConnection private (
override def cancel(): Unit = {
if (cancelled.compareAndSet(false, true)) {
ongoingRequests.cancel()
ongoingCompilations.cancel()
}
}

def cancelCompilations(): Unit = {
ongoingCompilations.cancel()
}

private def askUser(
original: Future[BuildServerConnection.LauncherConnection]
): Future[BuildServerConnection.LauncherConnection] = {
Expand Down Expand Up @@ -357,24 +351,23 @@ class BuildServerConnection private (
private def register[T: ClassTag](
action: MetalsBuildServer => CompletableFuture[T],
onFail: => Option[(T, String)] = None,
isCompile: Boolean = false,
): CompletableFuture[T] = {

val localCancelable = new MutableCancelable()
def runWithCanceling(
launcherConnection: BuildServerConnection.LauncherConnection
): Future[T] = {
val resultFuture = action(launcherConnection.server)
val cancelable = Cancelable { () =>
Try(resultFuture.cancel(true))
}
if (isCompile) ongoingCompilations.add(cancelable)
else ongoingRequests.add(cancelable)
ongoingRequests.add(cancelable)
localCancelable.add(cancelable)

val result = resultFuture.asScala

result.onComplete { _ =>
if (isCompile) ongoingCompilations.remove(cancelable)
else ongoingRequests.remove(cancelable)
ongoingRequests.remove(cancelable)
localCancelable.remove(cancelable)
}
result
}
Expand Down Expand Up @@ -411,7 +404,14 @@ class BuildServerConnection private (
Future.failed(new MetalsBspException(name, t))
})
}
CancelTokens.future(_ => actionFuture)

CancelTokens.future { token =>
token.onCancel().asScala.onComplete {
case Success(java.lang.Boolean.TRUE) => localCancelable.cancel()
case _ =>
}
actionFuture
}
}

def isBuildServerResponsive: Future[Option[Boolean]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ final class Compilations(
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "compileBatch", shouldLogQueue = true)
](compile, "compileBatch", shouldLogQueue = true, Some(Map.empty))
private val cascadeBatch =
new BatchedFunction[
b.BuildTargetIdentifier,
Map[BuildTargetIdentifier, b.CompileResult],
](compile, "cascadeBatch", shouldLogQueue = true)
](compile, "cascadeBatch", shouldLogQueue = true, Some(Map.empty))
def pauseables: List[Pauseable] = List(compileBatch, cascadeBatch)

private val isCompiling = TrieMap.empty[b.BuildTargetIdentifier, Boolean]
Expand Down Expand Up @@ -115,15 +115,6 @@ final class Compilations(
def cancel(): Unit = {
cascadeBatch.cancelAll()
compileBatch.cancelAll()
buildTargets.all
.flatMap { target =>
buildTargets.buildServerOf(target.getId())
}
.distinct
.foreach { conn =>
conn.cancelCompilations()
}

}

def recompileAll(): Future[Unit] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2202,6 +2202,8 @@ class MetalsLspService(
}

def disconnectOldBuildServer(): Future[Unit] = {
compilations.cancel()
buildTargetClasses.cancel()
diagnostics.reset()
bspSession.foreach(connection =>
scribe.info(s"Disconnecting from ${connection.main.name} session...")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ final class BuildTargetClasses(
: TrieMap[b.BuildTargetIdentifier, b.JvmEnvironmentItem] =
TrieMap.empty[b.BuildTargetIdentifier, b.JvmEnvironmentItem]
val rebuildIndex: BatchedFunction[b.BuildTargetIdentifier, Unit] =
BatchedFunction.fromFuture(fetchClasses, "buildTargetClasses")
BatchedFunction.fromFuture(
fetchClasses,
"buildTargetClasses",
default = Some(()),
)

def classesOf(target: b.BuildTargetIdentifier): Classes = {
index.getOrElse(target, new Classes)
Expand Down Expand Up @@ -175,6 +179,10 @@ final class BuildTargetClasses(
val name = NameTransformer.decode(names.last)
descriptors.map(descriptor => Symbols.Global(prefix, descriptor(name)))
}

def cancel(): Unit = {
rebuildIndex.cancelAll()
}
}

sealed abstract class TestFramework(val canResolveChildren: Boolean)
Expand Down
54 changes: 49 additions & 5 deletions tests/unit/src/test/scala/tests/BatchedFunctionSuite.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
package tests

import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.util.Success

import scala.meta.internal.metals.BatchedFunction
import scala.meta.internal.metals.Cancelable
import scala.meta.internal.metals.CancelableFuture

class BatchedFunctionSuite extends BaseSuite {
test("batch") {
Expand Down Expand Up @@ -95,11 +99,51 @@ class BatchedFunctionSuite extends BaseSuite {

mkString.unpause()

assertDiffEqual(paused.value, None)
assertDiffEqual(paused2.value, None)

val unpaused2 = mkString(List("a", "b"))
assertDiffEqual(unpaused2.value, Some(Success("ab")))
for {
_ <- paused.failed
_ <- paused2.failed
res <- mkString(List("a", "b"))
_ = assertEquals(res, "ab")
} yield ()
}

test("cancel2") {
val executorService = Executors.newFixedThreadPool(10)
val ec2 = ExecutionContext.fromExecutor(executorService)
var i = 1
val stuckExample: BatchedFunction[String, String] =
new BatchedFunction(
(seq: Seq[String]) => {
seq.toList match {
case "loop" :: Nil =>
val future = Future.apply {
while (i == 1) {
Thread.sleep(1)
}
"loop-result"
}(ec2)
CancelableFuture[String](future, Cancelable { () => i = 2 })
case _ =>
CancelableFuture[String](
Future.successful("result"),
Cancelable.empty,
)
}
},
"stuck example",
default = Some("default"),
)(ec2)
val cancelled = stuckExample("loop")
assertEquals(i, 1)
assert(cancelled.value.isEmpty)
val normal = stuckExample("normal")
stuckExample.cancelCurrent()
for {
str <- cancelled
_ = assertEquals(i, 2)
_ = assertEquals(str, "default")
str <- normal
_ = assertEquals(str, "result")
} yield ()
}
}
Loading

0 comments on commit 60a1d16

Please sign in to comment.