Skip to content

Commit

Permalink
Use SageMaker SDK's own string serializer/deserializer in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
faph committed Feb 16, 2024
1 parent e4f2351 commit 3d082e6
Showing 1 changed file with 6 additions and 22 deletions.
28 changes: 6 additions & 22 deletions tests/test_inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

from typing import Tuple

import botocore.response
import pytest
import sagemaker.deserializers
import sagemaker.serializers

import inference_server
Expand Down Expand Up @@ -83,26 +81,12 @@ def test_invocations():

def test_prediction_custom_serializer():
"""Test the default plugin again, now using high-level testing.predict"""

class Serializer(sagemaker.serializers.BaseSerializer):
@property
def CONTENT_TYPE(self) -> str:
return "application/octet-stream"

def serialize(self, data: str) -> bytes:
return data.encode() # Simple str to bytes serializer

class Deserializer(sagemaker.deserializers.BaseDeserializer):
@property
def ACCEPT(self) -> Tuple[str]:
return ("application/json",)

def deserialize(self, stream: botocore.response.StreamingBody, content_type: str) -> str:
assert content_type in self.ACCEPT
return stream.read().decode() # Simple bytes to str deserializer

input_data = "What's the shipping forecast for tomorrow" # Simply pass a string
prediction = inference_server.testing.predict(data=input_data, serializer=Serializer(), deserializer=Deserializer())
prediction = inference_server.testing.predict(
data=input_data,
serializer=sagemaker.serializers.StringSerializer(),
deserializer=sagemaker.deserializers.StringDeserializer(),
)
assert prediction == input_data # Receive a string


Expand Down

0 comments on commit 3d082e6

Please sign in to comment.