diff --git a/compiler/src/main/scala/edg/compiler/Compiler.scala b/compiler/src/main/scala/edg/compiler/Compiler.scala index cb7c025ca..b8d6788a4 100644 --- a/compiler/src/main/scala/edg/compiler/Compiler.scala +++ b/compiler/src/main/scala/edg/compiler/Compiler.scala @@ -117,7 +117,7 @@ class AssignNamer() { } object Compiler { - final val kExpectedProtoVersion = 5 + final val kExpectedProtoVersion = 6 } /** Compiler for a particular design, with an associated library to elaborate references from. diff --git a/compiler/src/main/scala/edg/compiler/CompilerServerMain.scala b/compiler/src/main/scala/edg/compiler/CompilerServerMain.scala index da0a005ca..97a6868c0 100644 --- a/compiler/src/main/scala/edg/compiler/CompilerServerMain.scala +++ b/compiler/src/main/scala/edg/compiler/CompilerServerMain.scala @@ -1,51 +1,28 @@ package edg.compiler -import edg.util.{Errorable, StreamUtils} +import edg.util.{Errorable, QueueStream} import edgrpc.compiler.{compiler => edgcompiler} import edgrpc.compiler.compiler.{CompilerRequest, CompilerResult} import edgrpc.hdl.{hdl => edgrpc} import edg.wir.{DesignPath, IndirectDesignPath, Refinements} -import edgir.elem.elem -import edgir.ref.ref -import edgir.schema.schema -import java.io.{File, PrintWriter, StringWriter} +import java.io.{PrintWriter, StringWriter} -// a PythonInterface that uses the on-event hooks to forward stderr and stdout -// without this, the compiler can freeze on large stdout/stderr data, possibly because of queue sizing -class ForwardingPythonInterface(pythonPaths: Seq[String] = Seq()) - extends PythonInterface(pythonPaths = pythonPaths) { - def forwardProcessOutput(): Unit = { - StreamUtils.forAvailable(processOutputStream) { data => - System.out.print(new String(data)) - System.out.flush() - } - StreamUtils.forAvailable(processErrorStream) { data => - System.err.print(new String(data)) - System.err.flush() - } - } +/** A python interface that uses the host stdio - where the host process 'flips' role and serves as the HDL server while + * compilation is running + */ +class HostPythonInterface extends ProtobufInterface { + protected val stdoutStream = new QueueStream() + val outputStream = stdoutStream.getReader - override protected def onLibraryRequestComplete( - element: ref.LibraryPath, - result: Errorable[(schema.Library.NS.Val, Option[edgrpc.Refinements])] - ): Unit = { - forwardProcessOutput() - } + protected val outputDeserializer = + new ProtobufStreamDeserializer[edgrpc.HdlResponse](System.in, edgrpc.HdlResponse, stdoutStream) + protected val outputSerializer = new ProtobufStreamSerializer[edgrpc.HdlRequest](System.out) - override protected def onElaborateGeneratorRequestComplete( - element: ref.LibraryPath, - values: Map[ref.LocalPath, ExprValue], - result: Errorable[elem.HierarchyBlock] - ): Unit = { - forwardProcessOutput() - } + override def write(message: edgrpc.HdlRequest): Unit = outputSerializer.write(message) - override protected def onRunBackendComplete( - backend: ref.LibraryPath, - result: Errorable[Map[DesignPath, String]] - ): Unit = { - forwardProcessOutput() + override def read(): edgrpc.HdlResponse = { + outputDeserializer.read() } } @@ -94,34 +71,37 @@ object CompilerServerMain { } def main(args: Array[String]): Unit = { - val pyIf = new ForwardingPythonInterface() - (pyIf.getProtoVersion() match { - case Errorable.Success(pyVersion) if pyVersion == Compiler.kExpectedProtoVersion => None - case Errorable.Success(pyMismatchVersion) => Some(pyMismatchVersion.toString) - case Errorable.Error(errMsg) => Some(s"error $errMsg") - }).foreach { pyMismatchVersion => - System.err.println(f"WARNING: Python / compiler version mismatch, Python reported $pyMismatchVersion, " + - f"expected ${Compiler.kExpectedProtoVersion}.") - System.err.println(f"If you get unexpected errors or results, consider updating the Python library or compiler.") - Thread.sleep(kHdlVersionMismatchDelayMs) - } + val pyLib = new PythonInterfaceLibrary() // allow the library cache to persist across requests + while (true) { // handle multiple requests sequentially in the same process + val expectedMagicByte = System.in.read() + if (expectedMagicByte == -1) { + System.exit(0) // end of stream, shut it down + } + require(expectedMagicByte == ProtobufStdioSubprocess.kHeaderMagicByte) + val request = edgcompiler.CompilerRequest.parseDelimitedFrom(System.in) + + val protoInterface = new HostPythonInterface() + val compilerInterface = new PythonInterface(protoInterface) + (compilerInterface.getProtoVersion() match { + case Errorable.Success(pyVersion) if pyVersion == Compiler.kExpectedProtoVersion => None + case Errorable.Success(pyMismatchVersion) => Some(pyMismatchVersion.toString) + case Errorable.Error(errMsg) => Some(s"error $errMsg") + }).foreach { pyMismatchVersion => + System.err.println(f"WARNING: Python / compiler version mismatch, Python reported $pyMismatchVersion, " + + f"expected ${Compiler.kExpectedProtoVersion}.") + System.err.println( + f"If you get unexpected errors or results, consider updating the Python library or compiler." + ) + Thread.sleep(kHdlVersionMismatchDelayMs) + } - val pyLib = new PythonInterfaceLibrary() - pyLib.withPythonInterface(pyIf) { - while (true) { - val expectedMagicByte = System.in.read() - require(expectedMagicByte == ProtobufStdioSubprocess.kHeaderMagicByte || expectedMagicByte < 0) + pyLib.withPythonInterface(compilerInterface) { + val result = compile(request.get, pyLib) - val request = edgcompiler.CompilerRequest.parseDelimitedFrom(System.in) - val result = request match { - case Some(request) => - compile(request, pyLib) - case None => // end of stream - System.exit(0) - throw new NotImplementedError() // provides a return type, shouldn't ever happen - } + assert(protoInterface.outputStream.available() == 0, "unhandled in-band data from HDL compiler") - pyIf.forwardProcessOutput() // in case the hooks didn't catch everything + // this acts as a message indicating end of compilation + protoInterface.write(edgrpc.HdlRequest()) System.out.write(ProtobufStdioSubprocess.kHeaderMagicByte) result.writeDelimitedTo(System.out) diff --git a/compiler/src/main/scala/edg/compiler/PythonInterface.scala b/compiler/src/main/scala/edg/compiler/PythonInterface.scala index 0bda10751..9eab31aac 100644 --- a/compiler/src/main/scala/edg/compiler/PythonInterface.scala +++ b/compiler/src/main/scala/edg/compiler/PythonInterface.scala @@ -9,7 +9,7 @@ import edgir.ref.ref import edgir.schema.schema import edgrpc.hdl.{hdl => edgrpc} -import java.io.{File, InputStream} +import java.io.{File, InputStream, OutputStream} import scala.collection.mutable class ProtobufSubprocessException(msg: String) extends Exception(msg) @@ -18,113 +18,105 @@ object ProtobufStdioSubprocess { val kHeaderMagicByte = 0xfe } -class ProtobufStdioSubprocess[RequestType <: scalapb.GeneratedMessage, ResponseType <: scalapb.GeneratedMessage]( - responseType: scalapb.GeneratedMessageCompanion[ResponseType], - pythonPaths: Seq[String] = Seq(), - args: Seq[String] = Seq() +class ProtobufStreamDeserializer[MessageType <: scalapb.GeneratedMessage]( + stream: InputStream, // stream from the process + messageType: scalapb.GeneratedMessageCompanion[MessageType], + stdoutStream: OutputStream // where in-band non-protobuf data (eg, printfs) are written ) { - protected val process: Either[Process, Throwable] = - try { - val processBuilder = new ProcessBuilder(args: _*) - if (pythonPaths.nonEmpty) { - val env = processBuilder.environment() - val pythonPathString = pythonPaths.mkString(";") - Option(env.get("PYTHONPATH")) match { // merge existing PYTHONPATH if exists - case None => env.put("PYTHONPATH", pythonPathString) - case Some(envPythonPath) => env.put("PYTHONPATH", envPythonPath + ";" + pythonPathString) - } - } - Left(processBuilder.start()) - } catch { - case e: Throwable => Right(e) // if it fails store the exception to be thrown when we can + // deserializes and returns the next Proto message, writing any non-protobuf data to stdoutStream + def read(): MessageType = { + val lastByte = readStdout() + if (lastByte != ProtobufStdioSubprocess.kHeaderMagicByte) { + throw new ProtobufSubprocessException(s"unexpected end of stream, got $lastByte") } - - // this provides a consistent Stream interface for both stdout and stderr - // don't use PipedInputStream since it has a non-expanding buffer and is not single-thread safe - val outputStream = new QueueStream() - val errorStream: InputStream = process match { // the raw error stream from the process - case Left(process) => process.getErrorStream - case Right(_) => new QueueStream() // empty queue if the process never started + messageType.parseDelimitedFrom(stream).get } - protected def readStreamAvailable(stream: InputStream): String = { - var available = stream.available() - val outputBuilder = new mutable.StringBuilder() - while (available > 0) { - val array = new Array[Byte](available) - stream.read(array) - outputBuilder.append(new String(array)) - available = stream.available() + // writes non-protobuf data to stdoutStream, or when readAll is true dumps all remaining data in the stream + // returns the last byte read, including -1 if the end-of-stream was reached + def readStdout(readAll: Boolean = false): Integer = { + var nextByte = stream.read() + while (nextByte >= 0) { + if (nextByte == ProtobufStdioSubprocess.kHeaderMagicByte && !readAll) { + return nextByte + } else { + stdoutStream.write(nextByte) + } + nextByte = stream.read() } - outputBuilder.toString + return nextByte } +} - def write(message: RequestType): Unit = { - process match { - case Right(err) => - throw err - case Left(process) if !process.isAlive => - throw new ProtobufSubprocessException("process died, " + - s"buffered out=${readStreamAvailable(outputStream)}, err=${readStreamAvailable(errorStream)}") - case Left(process) => - process.getOutputStream.write(ProtobufStdioSubprocess.kHeaderMagicByte) - message.writeDelimitedTo(process.getOutputStream) - process.getOutputStream.flush() +class ProtobufStreamSerializer[MessageType <: scalapb.GeneratedMessage](stream: OutputStream) { + def write(message: MessageType): Unit = { + stream.write(ProtobufStdioSubprocess.kHeaderMagicByte) + message.writeDelimitedTo(stream) + stream.flush() + } +} + +trait ProtobufInterface { + def write(message: edgrpc.HdlRequest): Unit + def read(): edgrpc.HdlResponse +} + +class ProtobufStdioSubprocess( + interpreter: String = "python", + pythonPaths: Seq[String] = Seq() +) extends ProtobufInterface { + private val submoduleSearchPaths = if (pythonPaths.nonEmpty) pythonPaths else Seq(".") + private val isSubmoduled = + submoduleSearchPaths.map { searchPath => // check if submoduled, if so prepend the submodule name + new File(new File(searchPath), "PolymorphicBlocks/edg/hdl_server/__init__.py").exists() + }.exists(identity) + val packagePrefix = if (isSubmoduled) "PolymorphicBlocks." else "" + private val packageName = packagePrefix + "edg.hdl_server" + + protected val process: Process = { + val processBuilder = new ProcessBuilder(interpreter, "-u", "-m", packageName) + if (pythonPaths.nonEmpty) { + val env = processBuilder.environment() + val pythonPathString = pythonPaths.mkString(";") + Option(env.get("PYTHONPATH")) match { // merge existing PYTHONPATH if exists + case None => env.put("PYTHONPATH", pythonPathString) + case Some(envPythonPath) => env.put("PYTHONPATH", envPythonPath + ";" + pythonPathString) + } } + processBuilder.start() } - def read(): ResponseType = { - process match { - case Right(err) => - throw err - case Left(process) if !process.isAlive => - throw new ProtobufSubprocessException("process died, " + - s"buffered out=${readStreamAvailable(outputStream)}, err=${readStreamAvailable(errorStream)}") - case Left(process) => - var doneReadingStdout: Boolean = false - while (!doneReadingStdout) { - val nextByte = process.getInputStream.read() - if (nextByte == ProtobufStdioSubprocess.kHeaderMagicByte) { - doneReadingStdout = true - } else if (nextByte < 0) { - throw new ProtobufSubprocessException(s"unexpected end of stream, got $nextByte, " + - s"buffered out=${readStreamAvailable(outputStream)}, err=${readStreamAvailable(errorStream)}") - } else { - outputStream.write(nextByte) - } - } - responseType.parseDelimitedFrom(process.getInputStream).get + // this provides a consistent Stream interface for both stdout and stderr + // don't use PipedInputStream since it has a non-expanding buffer and is not single-thread safe + protected val stdoutStream = new QueueStream() + val outputStream = stdoutStream.getReader + val errorStream: InputStream = process.getErrorStream + + protected val outputDeserializer = + new ProtobufStreamDeserializer[edgrpc.HdlResponse](process.getInputStream, edgrpc.HdlResponse, stdoutStream) + protected val outputSerializer = new ProtobufStreamSerializer[edgrpc.HdlRequest](process.getOutputStream) + + override def write(message: edgrpc.HdlRequest): Unit = outputSerializer.write(message) + + override def read(): edgrpc.HdlResponse = { + if (!process.isAlive) { + throw new ProtobufSubprocessException("process died") } + outputDeserializer.read() } // Shuts down the stream and returns the exit value def shutdown(): Int = { - process match { - case Right(_) => -1 // give a generic failed value, otherwise doesn't need to do anything - case Left(process) => - process.getOutputStream.close() - var doneReadingStdout: Boolean = false - while (!doneReadingStdout) { - val nextByte = process.getInputStream.read() - require(nextByte != ProtobufStdioSubprocess.kHeaderMagicByte) - if (nextByte < 0) { - doneReadingStdout = true - } else { - outputStream.write(nextByte) - } - } - - process.waitFor() - process.exitValue() - } + process.getOutputStream.close() + process.waitFor() + outputDeserializer.readStdout() + process.exitValue() } // Forces a shutdown even if the process is busy def destroy(): Unit = { - process match { - case Right(_) => // ignored - case Left(process) => process.destroyForcibly() - } + process.destroyForcibly() + outputDeserializer.readStdout() } } @@ -133,51 +125,13 @@ class ProtobufStdioSubprocess[RequestType <: scalapb.GeneratedMessage, ResponseT * * This invokes "python -m edg.hdl_server", using either the local or global (pip) module as available. */ -class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = Seq()) { - private val submoduleSearchPaths = if (pythonPaths.nonEmpty) pythonPaths else Seq(".") - private val isSubmoduled = - submoduleSearchPaths.map { searchPath => // check if submoduled, if so prepend the submodule name - new File(new File(searchPath), "PolymorphicBlocks/edg/hdl_server/__init__.py").exists() - }.exists(identity) - val packagePrefix = if (isSubmoduled) "PolymorphicBlocks." else "" - private val packageName = packagePrefix + "edg.hdl_server" - - private val command = Seq(interpreter, "-u", "-m", packageName) - protected val process = new ProtobufStdioSubprocess[edgrpc.HdlRequest, edgrpc.HdlResponse]( - edgrpc.HdlResponse, - pythonPaths, - command - ) - val processOutputStream: InputStream = process.outputStream - val processErrorStream: InputStream = process.errorStream - - def shutdown(): Int = { - process.shutdown() - } - - def destroy(): Unit = { - process.destroy() - } - - // Hooks to implement when certain actions happen - protected def onLibraryRequest(element: ref.LibraryPath): Unit = {} - protected def onLibraryRequestComplete( - element: ref.LibraryPath, - result: Errorable[(schema.Library.NS.Val, Option[edgrpc.Refinements])] - ): Unit = {} - protected def onElaborateGeneratorRequest(element: ref.LibraryPath, values: Map[ref.LocalPath, ExprValue]): Unit = {} - protected def onElaborateGeneratorRequestComplete( - element: ref.LibraryPath, - values: Map[ref.LocalPath, ExprValue], - result: Errorable[elem.HierarchyBlock] - ): Unit = {} - +class PythonInterface(interface: ProtobufInterface) { def getProtoVersion(): Errorable[Int] = { val (reply, reqTime) = timeExec { - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.GetProtoVersion(0) // dummy argument )) - process.read() + interface.read() } reply.response match { case edgrpc.HdlResponse.Response.GetProtoVersion(result) => @@ -192,10 +146,10 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = def indexModule(module: String): Errorable[Seq[ref.LibraryPath]] = { val request = edgrpc.ModuleName(module) val (reply, reqTime) = timeExec { - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.IndexModule(value = request) )) - process.read() + interface.read() } reply.response match { case edgrpc.HdlResponse.Response.IndexModule(result) => @@ -207,6 +161,13 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = } } + // Hooks to implement when certain actions happen + protected def onLibraryRequest(element: ref.LibraryPath): Unit = {} + protected def onLibraryRequestComplete( + element: ref.LibraryPath, + result: Errorable[(schema.Library.NS.Val, Option[edgrpc.Refinements])] + ): Unit = {} + def libraryRequest(element: ref.LibraryPath): Errorable[(schema.Library.NS.Val, Option[edgrpc.Refinements])] = { onLibraryRequest(element) @@ -214,10 +175,10 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = element = Some(element) ) val (reply, reqTime) = timeExec { // TODO plumb refinements through - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.GetLibraryElement(value = request) )) - process.read() + interface.read() } val result = reply.response match { case edgrpc.HdlResponse.Response.GetLibraryElement(result) => @@ -231,6 +192,13 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = result } + protected def onElaborateGeneratorRequest(element: ref.LibraryPath, values: Map[ref.LocalPath, ExprValue]): Unit = {} + protected def onElaborateGeneratorRequestComplete( + element: ref.LibraryPath, + values: Map[ref.LocalPath, ExprValue], + result: Errorable[elem.HierarchyBlock] + ): Unit = {} + def elaborateGeneratorRequest( element: ref.LibraryPath, values: Map[ref.LocalPath, ExprValue] @@ -247,10 +215,10 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = }.toSeq ) val (reply, reqTime) = timeExec { - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.ElaborateGenerator(value = request) )) - process.read() + interface.read() } val result = reply.response match { case edgrpc.HdlResponse.Response.ElaborateGenerator(result) => @@ -290,10 +258,10 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = }.toSeq ) val (reply, reqTime) = timeExec { - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.RunRefinement(value = request) )) - process.read() + interface.read() } val result = reply.response match { case edgrpc.HdlResponse.Response.RunRefinement(result) => @@ -326,10 +294,10 @@ class PythonInterface(interpreter: String = "python", pythonPaths: Seq[String] = arguments = arguments ) val (reply, reqTime) = timeExec { - process.write(edgrpc.HdlRequest( + interface.write(edgrpc.HdlRequest( request = edgrpc.HdlRequest.Request.RunBackend(value = request) )) - process.read() + interface.read() } val result = reply.response match { case edgrpc.HdlResponse.Response.RunBackend(result) => diff --git a/compiler/src/main/scala/edg/util/QueueStream.scala b/compiler/src/main/scala/edg/util/QueueStream.scala index 933ad6b28..12240de7e 100644 --- a/compiler/src/main/scala/edg/util/QueueStream.scala +++ b/compiler/src/main/scala/edg/util/QueueStream.scala @@ -1,6 +1,6 @@ package edg.util -import java.io.InputStream +import java.io.{InputStream, OutputStream} import collection.mutable /** Why the heck are we writing another QueueStream when we have things like Apache QueueInputStream? @@ -11,12 +11,14 @@ import collection.mutable * * So here, yet another variation of Stream. Yay. */ -class QueueStream extends InputStream { +class QueueStream extends OutputStream { protected val queue = mutable.Queue[Byte]() - override def read(): Int = queue.dequeue() - override def available(): Int = queue.length + override def write(data: Int): Unit = queue.enqueue(data.toByte) - // don't want to bother writing a separate OutputStream version, so the write methods are just stuffed in here - def write(data: Int): Unit = queue.enqueue(data.toByte) + class Reader extends InputStream { + override def read(): Int = queue.dequeue() + override def available(): Int = queue.length + } + def getReader: Reader = new Reader } diff --git a/compiler/src/test/scala/edg/compiler/PythonInterfaceTest.scala b/compiler/src/test/scala/edg/compiler/PythonInterfaceTest.scala index 23bac74cb..69176becb 100644 --- a/compiler/src/test/scala/edg/compiler/PythonInterfaceTest.scala +++ b/compiler/src/test/scala/edg/compiler/PythonInterfaceTest.scala @@ -12,8 +12,9 @@ class PythonInterfaceTest extends AnyFlatSpec { val compiledDir = new File(getClass.getResource("").getPath) // above returns compiler/target/scala-2.xx/test-classes/edg/compiler, get the root repo dir val repoDir = compiledDir.getParentFile.getParentFile.getParentFile.getParentFile.getParentFile.getParentFile - val pyIf = new PythonInterface(pythonPaths = Seq(repoDir.getAbsolutePath)) + val pyProcess = new ProtobufStdioSubprocess(pythonPaths = Seq(repoDir.getAbsolutePath)) + val pyIf = new PythonInterface(pyProcess) pyIf.indexModule("edg.core").getClass should equal(classOf[Errorable.Success[Seq[LibraryPath]]]) - pyIf.shutdown() should equal(0) + pyProcess.shutdown() should equal(0) } } diff --git a/edg/core/ScalaCompilerInterface.py b/edg/core/ScalaCompilerInterface.py index 4bef89208..ff0327142 100644 --- a/edg/core/ScalaCompilerInterface.py +++ b/edg/core/ScalaCompilerInterface.py @@ -64,14 +64,11 @@ def append_values(self, values: List[Tuple[edgir.LocalPath, edgir.ValueLit]]): class ScalaCompilerInstance: - kDevRelpath = "../../compiler/target/scala-2.13/edg-compiler-assembly-0.1-SNAPSHOTjar" + kDevRelpath = "../../compiler/target/scala-2.13/edg-compiler-assembly-0.1-SNAPSHOT.jar" kPrecompiledRelpath = "resources/edg-compiler-precompiled.jar" - def __init__(self, *, suppress_stderr: bool = False): + def __init__(self): self.process: Optional[Any] = None - self.suppress_stderr = suppress_stderr - self.request_serializer: Optional[BufferSerializer[edgrpc.CompilerRequest]] = None - self.response_deserializer: Optional[BufferDeserializer[edgrpc.CompilerResult]] = None def check_started(self) -> None: if self.process is None: @@ -89,19 +86,18 @@ def check_started(self) -> None: ['java', '-jar', jar_path], stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE if self.suppress_stderr else None + stderr=subprocess.PIPE ) - assert self.process.stdin is not None - self.request_serializer = BufferSerializer[edgrpc.CompilerRequest](self.process.stdin) - assert self.process.stdout is not None - self.response_deserializer = BufferDeserializer(edgrpc.CompilerResult, self.process.stdout) - def compile(self, block: Type[Block], refinements: Refinements = Refinements(), *, ignore_errors: bool = False) -> CompiledDesign: + from ..hdl_server.__main__ import process_request self.check_started() - assert self.request_serializer is not None - assert self.response_deserializer is not None + + assert self.process is not None + assert self.process.stdin is not None + assert self.process.stdout is not None + request_serializer = BufferSerializer[edgrpc.CompilerRequest](self.process.stdin) block_obj = block() request = edgrpc.CompilerRequest( @@ -110,13 +106,31 @@ def compile(self, block: Type[Block], refinements: Refinements = Refinements(), ) if isinstance(block_obj, DesignTop): refinements = block_obj.refinements() + refinements - refinements.populate_proto(request.refinements) - self.request_serializer.write(request) - result = self.response_deserializer.read() - - sys.stdout.buffer.write(self.response_deserializer.read_stdout()) + # write the initial request to the compiler process + request_serializer.write(request) + + # until the compiler gives back the response, this acts as the HDL server, + # taking requests in the opposite direction + assert self.process.stdin is not None + assert self.process.stdout is not None + hdl_request_deserializer = BufferDeserializer(edgrpc.HdlRequest, self.process.stdout) + hdl_response_serializer = BufferSerializer[edgrpc.HdlResponse](self.process.stdin) + while True: + sys.stdout.buffer.write(hdl_request_deserializer.read_stdout()) + sys.stdout.buffer.flush() + hdl_request = hdl_request_deserializer.read() + assert hdl_request is not None + hdl_response = process_request(hdl_request) + if hdl_response is None: + break + hdl_response_serializer.write(hdl_response) + + response_deserializer = BufferDeserializer(edgrpc.CompilerResult, self.process.stdout) + result = response_deserializer.read() + + sys.stdout.buffer.write(response_deserializer.read_stdout()) sys.stdout.buffer.flush() assert result is not None @@ -130,8 +144,7 @@ def close(self): assert self.process is not None self.process.stdin.close() self.process.stdout.close() - if self.suppress_stderr: - self.process.stderr.close() + self.process.stderr.close() self.process.wait() diff --git a/edg/core/resources/edg-compiler-precompiled.jar b/edg/core/resources/edg-compiler-precompiled.jar index f68877f00..f4e972503 100644 Binary files a/edg/core/resources/edg-compiler-precompiled.jar and b/edg/core/resources/edg-compiler-precompiled.jar differ diff --git a/edg/core/test_generator.py b/edg/core/test_generator.py index e9ae4f2ba..1d8504af2 100644 --- a/edg/core/test_generator.py +++ b/edg/core/test_generator.py @@ -1,7 +1,9 @@ import unittest +from os import devnull +from contextlib import redirect_stderr from . import * -from .ScalaCompilerInterface import ScalaCompiler, ScalaCompilerInstance +from .ScalaCompilerInterface import ScalaCompiler class TestGeneratorAssign(Block): @@ -205,8 +207,6 @@ def helperfn() -> None: class GeneratorFailureTestCase(unittest.TestCase): def test_metadata(self) -> None: - # if we don't suppress the output, the error from the generator propagates to the test console - compiler = ScalaCompilerInstance(suppress_stderr=True) - with self.assertRaises(CompilerCheckError) as context: - compiler.compile(TestGeneratorFailure) - compiler.close() # if we don't close it, we get a ResourceWarning + with self.assertRaises(CompilerCheckError), \ + open(devnull, 'w') as fnull, redirect_stderr(fnull): # suppress generator error + self.compiled = ScalaCompiler.compile(TestGeneratorFailure) diff --git a/edg/hdl_server/__main__.py b/edg/hdl_server/__main__.py index d1b80219f..e0dcdbd01 100644 --- a/edg/hdl_server/__main__.py +++ b/edg/hdl_server/__main__.py @@ -2,7 +2,7 @@ import inspect import sys from types import ModuleType -from typing import Set, Type, Tuple, TypeVar, cast +from typing import Set, Type, Tuple, TypeVar, cast, Optional from .. import edgir from .. import edgrpc @@ -10,7 +10,7 @@ from ..core.Core import NonLibraryProperty -EDG_PROTO_VERSION = 5 +EDG_PROTO_VERSION = 6 class LibraryElementIndexer: @@ -82,6 +82,71 @@ def class_from_library(elt: edgir.LibraryPath, expected_superclass: Type[Library return cls +def process_request(request: edgrpc.HdlRequest) -> Optional[edgrpc.HdlResponse]: + response = edgrpc.HdlResponse() + try: + if request.HasField('index_module'): + module = importlib.import_module(request.index_module.name) + library = LibraryElementIndexer() + indexed = [edgir.LibraryPath(target=edgir.LocalStep(name=indexed._static_def_name())) + for indexed in library.index_module(module)] + response.index_module.indexed.extend(indexed) + elif request.HasField('get_library_element'): + cls = class_from_library(request.get_library_element.element, + LibraryElement) # type: ignore + obj, obj_proto = elaborate_class(cls) + + response.get_library_element.element.CopyFrom(obj_proto) + if isinstance(obj, DesignTop): + obj.refinements().populate_proto(response.get_library_element.refinements) + elif request.HasField('elaborate_generator'): + generator_type = class_from_library(request.elaborate_generator.element, + GeneratorBlock) + generator_obj = generator_type() + + response.elaborate_generator.generated.CopyFrom(builder.elaborate_toplevel( + generator_obj, + is_generator=True, + generate_values=[(value.path, value.value) for value in request.elaborate_generator.values])) + elif request.HasField('run_refinement'): + refinement_pass_class = class_from_library(request.run_refinement.refinement_pass, + BaseRefinementPass) # type: ignore + refinement_pass = refinement_pass_class() + + refinement_results = refinement_pass.run( + CompiledDesign.from_request(request.run_refinement.design, request.run_refinement.solvedValues)) + response.run_refinement.SetInParent() + for path, refinement_result in refinement_results: + new_value = response.run_refinement.newValues.add() + new_value.path.CopyFrom(path) + new_value.value.CopyFrom(refinement_result) + elif request.HasField('run_backend'): + backend_class = class_from_library(request.run_backend.backend, + BaseBackend) # type: ignore + backend = backend_class() + + backend_results = backend.run( + CompiledDesign.from_request(request.run_backend.design, request.run_backend.solvedValues), + dict(request.run_backend.arguments)) + response.run_backend.SetInParent() + for path, backend_result in backend_results: + response_result = response.run_backend.results.add() + response_result.path.CopyFrom(path) + response_result.text = backend_result + elif request.HasField('get_proto_version'): + response.get_proto_version = EDG_PROTO_VERSION + else: + return None + except BaseException as e: + import traceback + # exception formatting from https://stackoverflow.com/questions/4564559/get-exception-description-and-stack-trace-which-caused-an-exception-all-as-a-st + response = edgrpc.HdlResponse() + response.error.error = repr(e) + response.error.traceback = "".join(traceback.TracebackException.from_exception(e).format()) + # also print it, to preserve the usual behavior of errors in Python + traceback.print_exc() + return response + def run_server(): stdin_deserializer = BufferDeserializer(edgrpc.HdlRequest, sys.stdin.buffer) stdout_serializer = BufferSerializer[edgrpc.HdlResponse](sys.stdout.buffer) @@ -91,69 +156,11 @@ def run_server(): if request is None: # end of stream sys.exit(0) - response = edgrpc.HdlResponse() - try: - if request.HasField('index_module'): - module = importlib.import_module(request.index_module.name) - library = LibraryElementIndexer() - indexed = [edgir.LibraryPath(target=edgir.LocalStep(name=indexed._static_def_name())) - for indexed in library.index_module(module)] - response.index_module.indexed.extend(indexed) - elif request.HasField('get_library_element'): - cls = class_from_library(request.get_library_element.element, - LibraryElement) # type: ignore - obj, obj_proto = elaborate_class(cls) - - response.get_library_element.element.CopyFrom(obj_proto) - if isinstance(obj, DesignTop): - obj.refinements().populate_proto(response.get_library_element.refinements) - elif request.HasField('elaborate_generator'): - generator_type = class_from_library(request.elaborate_generator.element, - GeneratorBlock) - generator_obj = generator_type() - - response.elaborate_generator.generated.CopyFrom(builder.elaborate_toplevel( - generator_obj, - is_generator=True, - generate_values=[(value.path, value.value) for value in request.elaborate_generator.values])) - elif request.HasField('run_refinement'): - refinement_pass_class = class_from_library(request.run_refinement.refinement_pass, - BaseRefinementPass) # type: ignore - refinement_pass = refinement_pass_class() - - refinement_results = refinement_pass.run( - CompiledDesign.from_request(request.run_refinement.design, request.run_refinement.solvedValues)) - response.run_refinement.SetInParent() - for path, refinement_result in refinement_results: - new_value = response.run_refinement.newValues.add() - new_value.path.CopyFrom(path) - new_value.value.CopyFrom(refinement_result) - elif request.HasField('run_backend'): - backend_class = class_from_library(request.run_backend.backend, - BaseBackend) # type: ignore - backend = backend_class() - - backend_results = backend.run( - CompiledDesign.from_request(request.run_backend.design, request.run_backend.solvedValues), - dict(request.run_backend.arguments)) - response.run_backend.SetInParent() - for path, backend_result in backend_results: - response_result = response.run_backend.results.add() - response_result.path.CopyFrom(path) - response_result.text = backend_result - elif request.HasField('get_proto_version'): - response.get_proto_version = EDG_PROTO_VERSION - else: - raise RuntimeError(f"Unknown request {request}") - except BaseException as e: - import traceback - # exception formatting from https://stackoverflow.com/questions/4564559/get-exception-description-and-stack-trace-which-caused-an-exception-all-as-a-st - response.error.error = repr(e) - response.error.traceback = "".join(traceback.TracebackException.from_exception(e).format()) - # also print it, to preserve the usual behavior of errors in Python - traceback.print_exc() - - sys.stdout.buffer.write(stdin_deserializer.read_stdout()) + response = process_request(request) + if response is None: + raise RuntimeError(f"Unknown request {request}") + + sys.stdout.buffer.write(stdin_deserializer.read_stdout()) # forward prints and stuff stdout_serializer.write(response)