From 6a19aff0b644b60a7d8e69630fabbcc438de21e3 Mon Sep 17 00:00:00 2001 From: Yanping Huang Date: Tue, 13 Aug 2024 11:21:13 -0700 Subject: [PATCH] Internal changes PiperOrigin-RevId: 662587900 Change-Id: If0b3204723a13b59d514c2f36161a4e928f661af --- saxml/server/pax/vision/servable_vision_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/saxml/server/pax/vision/servable_vision_model.py b/saxml/server/pax/vision/servable_vision_model.py index 82081d0..b88f6d1 100644 --- a/saxml/server/pax/vision/servable_vision_model.py +++ b/saxml/server/pax/vision/servable_vision_model.py @@ -199,7 +199,7 @@ class TokenToVideoHParams(servable_model_params.ServableMethodParams): num_tokens_per_frame: The number of tokens per frame. """ - image_postprocessor: Optional[Callable[[tf.Tensor], str]] = None + image_postprocessor: Optional[Callable[[tf.Tensor], bytes]] = None model_method_name: Optional[str] = None num_tokens_per_frame: int = 256 @@ -1001,9 +1001,9 @@ def post_processing(self, compute_outputs: NestedNpTensor) -> List[Any]: videos = compute_outputs['video'] # [batch, t, h, w, c] batched_video_bytes = [] for video in videos: - video_bytes: list[str] = [] + video_bytes: list[bytes] = [] for image_frame in video: - # [h, w, c] -> string + # [h, w, c] -> bytes image_bytes = self._image_postprocessor(image_frame) video_bytes.append(image_bytes) batched_video_bytes.append(video_bytes) @@ -1143,7 +1143,7 @@ def init_method( 'Must specify `model_method_name` in VideoToTokenHParams.' ) # TODO(huangyp): Use model-specific dummy input. - image_bytes = tf.image.encode_jpeg(np.ones((256, 256, 3), dtype=np.uint8)) + image_bytes = tf.image.encode_png(np.ones((256, 256, 3), dtype=np.uint8)) dummy_input = {'image_frames': [image_bytes]} return VideoToToken( model,