Mirror of https://github.com/nod-ai/AMD-SHARK-Studio.git
In MiniLM JAX example do not hardcode device (#1385)
* In MiniLM JAX example do not hardcode device
* In MiniLM JAX example don't use bytecode MLIR

Co-authored-by: Boian Petkantchin <boian@nod-labs.com>
Committed by GitHub
parent 4cfba153d2, commit eba4d06405
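For context, a minimal sketch of the textual-MLIR export pattern the commit switches to. Here `f` and the input shape are stand-ins for the MiniLM model and its tokenized sample input; only jax itself is assumed:

    # Sketch only: f stands in for the Flax MiniLM model in the real example.
    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.tanh(x) * 2.0

    # Lower the jitted function for a concrete input and fetch the MLIR
    # module JAX produced for it.
    model_mlir = jax.jit(f).lower(jnp.ones((4,))).compiler_ir()

    # str() renders the module as human-readable MLIR assembly; encoding it
    # yields portable bytes, unlike write_bytecode(), whose output is tied
    # to the MLIR bytecode version of the producing jaxlib.
    mlir_bytes = str(model_mlir).encode()
    print(mlir_bytes[:60])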
@@ -39,18 +39,14 @@ def get_sample_input():
 def export_to_mlir(sample_input: NumpyTree):
     model = FlaxAutoModel.from_pretrained("microsoft/MiniLM-L12-H384-uncased")
     model_mlir = jax.jit(model).lower(**sample_input).compiler_ir()
-    byte_stream = io.BytesIO()
-    model_mlir.operation.write_bytecode(file=byte_stream)
-    return byte_stream.getvalue()
+    return str(model_mlir).encode()


 sample_input = get_sample_input()
 mlir = export_to_mlir(sample_input)

 # Compile and load module.
-shark_inference = SharkInference(
-    mlir_module=mlir, mlir_dialect="mhlo", device="cpu"
-)
+shark_inference = SharkInference(mlir_module=mlir, mlir_dialect="mhlo")
 shark_inference.compile()

 # Run main function.
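With the hardcoded device argument gone, the example runs on SharkInference's default device. A hedged sketch of how a caller could still opt into a specific device at run time; SHARK_DEVICE is a hypothetical environment variable chosen for illustration, not a SHARK convention, and `mlir` is the bytes returned by export_to_mlir above:

    # Hypothetical run-time device selection; SHARK_DEVICE is illustrative
    # and not part of SHARK itself.
    import os

    from shark.shark_inference import SharkInference

    kwargs = {"mlir_module": mlir, "mlir_dialect": "mhlo"}
    device = os.environ.get("SHARK_DEVICE")  # e.g. "cpu", "cuda", "vulkan"
    if device is not None:
        kwargs["device"] = device  # pass a device only when explicitly requested

    shark_inference = SharkInference(**kwargs)
    shark_inference.compile()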
@@ -1,5 +1,5 @@
 flax
-jax
+jax[cpu]
 nodai-SHARK
 transformers
 torch
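For context: the bare jax package on PyPI is only the Python frontend; the [cpu] extra additionally pulls in a matching jaxlib CPU wheel, so a fresh environment (e.g. after pip install "jax[cpu]") can actually lower and run the model.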