From eba4d06405b2901feebed307ca7c022a3c4b13bb Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Wed, 3 May 2023 10:34:42 -0700
Subject: [PATCH] In MiniLM JAX example do not hardcode device (#1385)

* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin
---
 shark/examples/shark_inference/minilm_jax.py              | 8 ++------
 .../examples/shark_inference/minilm_jax_requirements.txt  | 2 +-
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/shark/examples/shark_inference/minilm_jax.py b/shark/examples/shark_inference/minilm_jax.py
index c0597688..ee4e80c4 100644
--- a/shark/examples/shark_inference/minilm_jax.py
+++ b/shark/examples/shark_inference/minilm_jax.py
@@ -39,18 +39,14 @@ def get_sample_input():
 def export_to_mlir(sample_input: NumpyTree):
     model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
     model_mlir = jax.jit(model).lower(**sample_input).compiler_ir()
-    byte_stream = io.BytesIO()
-    model_mlir.operation.write_bytecode(file=byte_stream)
-    return byte_stream.getvalue()
+    return str(model_mlir).encode()
 
 
 sample_input = get_sample_input()
 mlir = export_to_mlir(sample_input)
 
 # Compile and load module.
-shark_inference = SharkInference(
-    mlir_module=mlir, mlir_dialect="mhlo", device="cpu"
-)
+shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
 shark_inference.compile()
 
 # Run main function.
diff --git a/shark/examples/shark_inference/minilm_jax_requirements.txt b/shark/examples/shark_inference/minilm_jax_requirements.txt
index e6966303..9a2543cc 100644
--- a/shark/examples/shark_inference/minilm_jax_requirements.txt
+++ b/shark/examples/shark_inference/minilm_jax_requirements.txt
@@ -1,5 +1,5 @@
 flax
-jax
+jax[cpu]
 nodai-SHARK
 transformers
 torch
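
For reference, a minimal standalone sketch of how the example behaves after this patch: the lowered JAX program is kept as textual MLIR rather than bytecode, and SharkInference is left to pick its default device, while the same "device" keyword the example used to hardcode remains available as an explicit override. The tokenizer-built sample input and the input sentence are illustrative (the real example builds its input in get_sample_input()), and the import path shark.shark_inference is assumed to match the example's imports, which this hunk does not show.

    import jax
    from transformers import AutoTokenizer, FlaxAutoModel

    # Assumed import path; the patch context does not show the example's imports.
    from shark.shark_inference import SharkInference

    # Illustrative sample input; the real example builds this in get_sample_input().
    tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
    sample_input = dict(tokenizer("Hello, world!", return_tensors="np"))

    # Lower the Flax model through jax.jit and keep the MLIR as text
    # (str(...).encode()), as the patch does, instead of writing MLIR bytecode.
    model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
    mlir = str(jax.jit(model).lower(**sample_input).compiler_ir()).encode()

    # Let SharkInference choose the device instead of hardcoding it ...
    shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")

    # ... though an explicit device can still be passed via the keyword the
    # example previously hardcoded, e.g.:
    #   SharkInference(mlir_module=mlir, mlir_dialect="mhlo", device="cpu")

    shark_inference.compile()

Leaving the device unset is what lets the same example script run on whatever backend the local SHARK installation selects, which is also why the requirements pin jax[cpu] only as the host-side tracing dependency rather than as the execution target.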