diff --git a/shark/examples/shark_inference/minilm_jax.py b/shark/examples/shark_inference/minilm_jax.py
index c0597688..ee4e80c4 100644
--- a/shark/examples/shark_inference/minilm_jax.py
+++ b/shark/examples/shark_inference/minilm_jax.py
@@ -39,18 +39,14 @@ def get_sample_input():
 def export_to_mlir(sample_input: NumpyTree):
     model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
     model_mlir = jax.jit(model).lower(**sample_input).compiler_ir()
-    byte_stream = io.BytesIO()
-    model_mlir.operation.write_bytecode(file=byte_stream)
-    return byte_stream.getvalue()
+    return str(model_mlir).encode()
 
 
 sample_input = get_sample_input()
 mlir = export_to_mlir(sample_input)
 
 # Compile and load module.
-shark_inference = SharkInference(
-    mlir_module=mlir, mlir_dialect="mhlo", device="cpu"
-)
+shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
 shark_inference.compile()
 
 # Run main function.
diff --git a/shark/examples/shark_inference/minilm_jax_requirements.txt b/shark/examples/shark_inference/minilm_jax_requirements.txt
index e6966303..9a2543cc 100644
--- a/shark/examples/shark_inference/minilm_jax_requirements.txt
+++ b/shark/examples/shark_inference/minilm_jax_requirements.txt
@@ -1,5 +1,5 @@
 flax
-jax
+jax[cpu]
 nodai-SHARK
 transformers
 torch