mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-10 22:38:01 -05:00
Add Forward method to SHARKRunner and fix examples. (#1756)
This commit is contained in:
@@ -43,9 +43,7 @@ if __name__ == "__main__":
|
||||
minilm_mlir, func_name = mlir_importer.import_mlir(
|
||||
is_dynamic=False, tracing_required=True
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
minilm_mlir, func_name, mlir_dialect="linalg"
|
||||
)
|
||||
shark_module = SharkInference(minilm_mlir)
|
||||
shark_module.compile()
|
||||
token_logits = torch.tensor(shark_module.forward(inputs))
|
||||
mask_id = torch.where(
|
||||
|
||||
@@ -141,6 +141,10 @@ class SharkInference:
|
||||
def __call__(self, function_name: str, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run(function_name, inputs, send_to_host)
|
||||
|
||||
# forward function.
|
||||
def forward(self, inputs: tuple, send_to_host=True):
|
||||
return self.shark_runner.run("forward", inputs, send_to_host)
|
||||
|
||||
# Get all function names defined within the compiled module.
|
||||
def get_functions_in_module(self):
|
||||
return self.shark_runner.get_functions_in_module()
|
||||
|
||||
Reference in New Issue
Block a user