Files
AMD-SHARK-Studio/shark/examples/shark_dynamo/basic_examples.py

26 lines
346 B
Python

import torch
import shark
def foo(x, a):
    """Add ``a`` to ``x`` for long inputs, or the constant 3 for short ones.

    The branch is picked from ``x.shape[0]``. More than three entries along
    the first dimension selects ``x + a``. Three or fewer selects ``x + 3``,
    and ``a`` is then unused. Branching on a shape gives this small example
    input-dependent control flow for ``torch.compile`` to handle.
    """
    is_long = x.shape[0] > 3
    return x + a if is_long else x + 3
# Options passed to the "shark" torch.compile backend; only the target device
# is set here. NOTE(review): the backend is presumably registered as a side
# effect of `import shark` — confirm.
shark_options = {"device": "cpu"}
compiled = torch.compile(foo, backend="shark", options=shark_options)

# Run the compiled function with two input lengths so both branches of `foo`
# execute: length 4 takes the `x + a` path, length 3 takes the `x + 3` path.
# The input is named `example` rather than `input` to avoid shadowing the
# builtin `input`.
for length in (4, 3):
    example = torch.ones(length)
    x = compiled(example, example)
    print(x)