# Mirror of https://github.com/nod-ai/SHARK-Studio.git
# (synced 2026-01-09 22:07:55 -05:00).
# File listing metadata: 45 lines, 1.1 KiB, Python.
import torch
|
|
from torch.nn.utils import _stateless
|
|
from amdshark.amdshark_trainer import AMDSharkTrainer
|
|
|
|
|
|
class Foo(torch.nn.Module):
    """Two-layer perceptron: Linear(10, 16) -> ReLU -> Linear(16, 2)."""

    def __init__(self):
        super().__init__()
        # The attribute names (l1, relu, l2) become the prefixes of the
        # entries reported by named_parameters(), so they are kept as-is.
        self.l1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(16, 2)

    def forward(self, x):
        """Apply ``l1``, then ``relu``, then ``l2`` to ``x`` and return it."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)
|
|
|
|
|
|
# Module instance shared by the rest of this script.
mod = Foo()
# Example inputs: a one-element tuple holding a random 10x10 tensor, whose
# last dimension matches the 10 input features of Foo.l1.
inp = (torch.randn(10, 10),)
|
|
|
|
|
|
def get_sorted_params(named_params):
    """Return the values of ``named_params`` ordered by their keys.

    Args:
        named_params: Mapping from name to value; the names must be
            mutually comparable so they can be sorted.

    Returns:
        A new list holding the mapping's values, arranged in ascending
        order of the corresponding names.
    """
    ordered_names = sorted(named_params)
    return [named_params[name] for name in ordered_names]
|
|
|
|
|
|
def forward(params, buffers, args):
    """Run one forward/backward pass of ``mod`` followed by an SGD step.

    Args:
        params: Mapping of parameter name to tensor; merged with
            ``buffers`` and used in place of the module's own state.
        buffers: Mapping of buffer name to tensor.
        args: Positional inputs forwarded to the module call.

    Returns:
        The ``(params, buffers)`` objects that were passed in, after the
        optimizer step has been applied to the parameters.

    NOTE(review): ``_stateless`` is a private torch module; confirm it is
    still available in the torch version this script is run against.
    """
    merged_state = {**params, **buffers}
    # Call the module-level ``mod`` with the supplied state, passing no
    # keyword arguments.
    output = _stateless.functional_call(mod, merged_state, args, {})
    # The summed output is the scalar that gets differentiated.
    loss = output.sum()
    loss.backward()
    # A fresh SGD optimizer is built on every call, over the parameters in
    # name-sorted order.
    optimizer = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optimizer.step()
    return params, buffers
|
|
|
|
|
|
# Direct call of the training function, left disabled:
# fx_graph = forward(dict(mod.named_parameters()), dict(mod.named_buffers()), inp)

# Build the trainer from the module and its example inputs.
amdshark_module = AMDSharkTrainer(mod, inp)
# Pass the training function in case of torch
amdshark_module.compile(training_fn=forward)
# Run ten training iterations.
amdshark_module.train(num_iters=10)
|