Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-use host Python process when compiling from command line #372

Merged
merged 12 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/main/scala/edg/compiler/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class AssignNamer() {
}

object Compiler {
final val kExpectedProtoVersion = 5
final val kExpectedProtoVersion = 6
}

/** Compiler for a particular design, with an associated library to elaborate references from.
Expand Down
104 changes: 42 additions & 62 deletions compiler/src/main/scala/edg/compiler/CompilerServerMain.scala
Original file line number Diff line number Diff line change
@@ -1,51 +1,28 @@
package edg.compiler

import edg.util.{Errorable, StreamUtils}
import edg.util.{Errorable, QueueStream}
import edgrpc.compiler.{compiler => edgcompiler}
import edgrpc.compiler.compiler.{CompilerRequest, CompilerResult}
import edgrpc.hdl.{hdl => edgrpc}
import edg.wir.{DesignPath, IndirectDesignPath, Refinements}
import edgir.elem.elem
import edgir.ref.ref
import edgir.schema.schema

import java.io.{File, PrintWriter, StringWriter}
import java.io.{PrintWriter, StringWriter}

// a PythonInterface that uses the on-event hooks to forward stderr and stdout
// without this, the compiler can freeze on large stdout/stderr data, possibly because of queue sizing
class ForwardingPythonInterface(pythonPaths: Seq[String] = Seq())
extends PythonInterface(pythonPaths = pythonPaths) {
def forwardProcessOutput(): Unit = {
StreamUtils.forAvailable(processOutputStream) { data =>
System.out.print(new String(data))
System.out.flush()
}
StreamUtils.forAvailable(processErrorStream) { data =>
System.err.print(new String(data))
System.err.flush()
}
}
/** A python interface that uses the host stdio - where the host process 'flips' role and serves as the HDL server while
* compilation is running
*/
class HostPythonInterface extends ProtobufInterface {
protected val stdoutStream = new QueueStream()
val outputStream = stdoutStream.getReader

override protected def onLibraryRequestComplete(
element: ref.LibraryPath,
result: Errorable[(schema.Library.NS.Val, Option[edgrpc.Refinements])]
): Unit = {
forwardProcessOutput()
}
protected val outputDeserializer =
new ProtobufStreamDeserializer[edgrpc.HdlResponse](System.in, edgrpc.HdlResponse, stdoutStream)
protected val outputSerializer = new ProtobufStreamSerializer[edgrpc.HdlRequest](System.out)

override protected def onElaborateGeneratorRequestComplete(
element: ref.LibraryPath,
values: Map[ref.LocalPath, ExprValue],
result: Errorable[elem.HierarchyBlock]
): Unit = {
forwardProcessOutput()
}
override def write(message: edgrpc.HdlRequest): Unit = outputSerializer.write(message)

override protected def onRunBackendComplete(
backend: ref.LibraryPath,
result: Errorable[Map[DesignPath, String]]
): Unit = {
forwardProcessOutput()
override def read(): edgrpc.HdlResponse = {
outputDeserializer.read()
}
}

Expand Down Expand Up @@ -94,34 +71,37 @@ object CompilerServerMain {
}

def main(args: Array[String]): Unit = {
val pyIf = new ForwardingPythonInterface()
(pyIf.getProtoVersion() match {
case Errorable.Success(pyVersion) if pyVersion == Compiler.kExpectedProtoVersion => None
case Errorable.Success(pyMismatchVersion) => Some(pyMismatchVersion.toString)
case Errorable.Error(errMsg) => Some(s"error $errMsg")
}).foreach { pyMismatchVersion =>
System.err.println(f"WARNING: Python / compiler version mismatch, Python reported $pyMismatchVersion, " +
f"expected ${Compiler.kExpectedProtoVersion}.")
System.err.println(f"If you get unexpected errors or results, consider updating the Python library or compiler.")
Thread.sleep(kHdlVersionMismatchDelayMs)
}
val pyLib = new PythonInterfaceLibrary() // allow the library cache to persist across requests
while (true) { // handle multiple requests sequentially in the same process
val expectedMagicByte = System.in.read()
if (expectedMagicByte == -1) {
System.exit(0) // end of stream, shut it down
}
require(expectedMagicByte == ProtobufStdioSubprocess.kHeaderMagicByte)
val request = edgcompiler.CompilerRequest.parseDelimitedFrom(System.in)

val protoInterface = new HostPythonInterface()
val compilerInterface = new PythonInterface(protoInterface)
(compilerInterface.getProtoVersion() match {
case Errorable.Success(pyVersion) if pyVersion == Compiler.kExpectedProtoVersion => None
case Errorable.Success(pyMismatchVersion) => Some(pyMismatchVersion.toString)
case Errorable.Error(errMsg) => Some(s"error $errMsg")
}).foreach { pyMismatchVersion =>
System.err.println(f"WARNING: Python / compiler version mismatch, Python reported $pyMismatchVersion, " +
f"expected ${Compiler.kExpectedProtoVersion}.")
System.err.println(
f"If you get unexpected errors or results, consider updating the Python library or compiler."
)
Thread.sleep(kHdlVersionMismatchDelayMs)
}

val pyLib = new PythonInterfaceLibrary()
pyLib.withPythonInterface(pyIf) {
while (true) {
val expectedMagicByte = System.in.read()
require(expectedMagicByte == ProtobufStdioSubprocess.kHeaderMagicByte || expectedMagicByte < 0)
pyLib.withPythonInterface(compilerInterface) {
val result = compile(request.get, pyLib)

val request = edgcompiler.CompilerRequest.parseDelimitedFrom(System.in)
val result = request match {
case Some(request) =>
compile(request, pyLib)
case None => // end of stream
System.exit(0)
throw new NotImplementedError() // provides a return type, shouldn't ever happen
}
assert(protoInterface.outputStream.available() == 0, "unhandled in-band data from HDL compiler")

pyIf.forwardProcessOutput() // in case the hooks didn't catch everything
// this acts as a message indicating end of compilation
protoInterface.write(edgrpc.HdlRequest())

System.out.write(ProtobufStdioSubprocess.kHeaderMagicByte)
result.writeDelimitedTo(System.out)
Expand Down
Loading
Loading