mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add the shark compile downstream due to https://github.com/pytorch/pytorch/pull/104185#issuecomment-1615110613 (#1615)
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
from torch._dynamo import register_backend
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_backend
|
||||
def shark(model, inputs, *, options):
|
||||
try:
|
||||
from shark.dynamo_backend.utils import SharkBackend
|
||||
except ImportError:
|
||||
log.exception(
|
||||
"Unable to import SHARK - High Performance Machine Learning Distribution"
|
||||
"Please install the right version of SHARK that matches the PyTorch version being used. "
|
||||
"Refer to https://github.com/nod-ai/SHARK/ for details."
|
||||
)
|
||||
raise
|
||||
return SharkBackend(model, inputs, options)
|
||||
|
||||
|
||||
def has_shark():
|
||||
try:
|
||||
importlib.import_module("shark")
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@@ -1,70 +1,25 @@
|
||||
import torch
|
||||
import torch_mlir
|
||||
import torch._dynamo as torchdynamo
|
||||
from shark.sharkdynamo.utils import make_shark_compiler
|
||||
import shark
|
||||
|
||||
|
||||
import warnings, logging
|
||||
|
||||
warnings.simplefilter("ignore")
|
||||
torchdynamo.config.log_level = logging.ERROR
|
||||
def foo(x, a):
|
||||
if x.shape[0] > 3:
|
||||
return x + a
|
||||
else:
|
||||
return x + 3
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
shark_options = {"device": "cpu"}
|
||||
compiled = torch.compile(foo, backend="shark", options=shark_options)
|
||||
|
||||
input = torch.ones(4)
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(t):
|
||||
return 2 * t
|
||||
x = compiled(input, input)
|
||||
|
||||
|
||||
example_input = torch.rand((2, 3))
|
||||
x = foo(example_input)
|
||||
print(x)
|
||||
|
||||
input = torch.ones(3)
|
||||
|
||||
torchdynamo.reset()
|
||||
x = compiled(input, input)
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=False)
|
||||
)
|
||||
def foo(a, b):
|
||||
x = a / (a + 1)
|
||||
if b.sum() < 0:
|
||||
b = b * -1
|
||||
return x * b
|
||||
|
||||
|
||||
print(foo(torch.rand((2, 3)), -torch.rand((2, 3))))
|
||||
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def foo(a):
|
||||
for i in range(10):
|
||||
a += 1.0
|
||||
return a
|
||||
|
||||
|
||||
print(foo(torch.rand((1, 2))))
|
||||
|
||||
torchdynamo.reset()
|
||||
|
||||
|
||||
@torchdynamo.optimize(
|
||||
make_shark_compiler(use_tracing=False, device="cuda", verbose=True)
|
||||
)
|
||||
def test_unsupported_types(t, y):
|
||||
return t, 2 * y
|
||||
|
||||
|
||||
str_input = "hello"
|
||||
tensor_input = torch.randn(2)
|
||||
print(test_unsupported_types(str_input, tensor_input))
|
||||
print(x)
|
||||
|
||||
Reference in New Issue
Block a user