In MiniLM JAX example do not hardcode device (#1385)

* In MiniLM JAX example do not hardcode device

* In MiniLM JAX example don't use bytecode MLIR

---------

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
This commit is contained in:
Boian Petkantchin
2023-05-03 10:34:42 -07:00
committed by GitHub
parent 4cfba153d2
commit eba4d06405
2 changed files with 3 additions and 7 deletions

View File

@@ -39,18 +39,14 @@ def get_sample_input():
 def export_to_mlir(sample_input: NumpyTree):
     model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
     model_mlir = jax.jit(model).lower(**sample_input).compiler_ir()
-    byte_stream = io.BytesIO()
-    model_mlir.operation.write_bytecode(file=byte_stream)
-    return byte_stream.getvalue()
+    return str(model_mlir).encode()
 sample_input = get_sample_input()
 mlir = export_to_mlir(sample_input)
 # Compile and load module.
-shark_inference = SharkInference(
-    mlir_module=mlir, mlir_dialect="mhlo", device="cpu"
-)
+shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
 shark_inference.compile()

View File

View File

@@ -1,5 +1,5 @@
 flax
-jax
+jax[cpu]
 nodai-SHARK
 transformers
 torch