SHARK-Studio/amdshark/examples/amdshark_training/neural_net_training.py
pdhirajkumarprasad 6d80b43b6b Migration to AMDShark
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-20 12:46:36 +05:30


import torch
from torch.nn.utils import _stateless

from amdshark.amdshark_trainer import AMDSharkTrainer


# A small two-layer MLP used as the training example.
class Foo(torch.nn.Module):
    def __init__(self):
        super(Foo, self).__init__()
        self.l1 = torch.nn.Linear(10, 16)
        self.relu = torch.nn.ReLU()
        self.l2 = torch.nn.Linear(16, 2)

    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out


mod = Foo()
inp = (torch.randn(10, 10),)


def get_sorted_params(named_params):
    return [i[1] for i in sorted(named_params.items())]


def forward(params, buffers, args):
    # Run the module statelessly with the given parameters and buffers,
    # reduce the output to a scalar loss, and backpropagate.
    params_and_buffers = {**params, **buffers}
    _stateless.functional_call(
        mod, params_and_buffers, args, {}
    ).sum().backward()
    optim = torch.optim.SGD(get_sorted_params(params), lr=0.01)
    optim.step()
    return params, buffers


# fx_graph = forward(dict(mod.named_parameters()), dict(mod.named_buffers()), inp)

amdshark_module = AMDSharkTrainer(mod, inp)
# Pass the training function when compiling a torch model.
amdshark_module.compile(training_fn=forward)
amdshark_module.train(num_iters=10)
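
The example above expresses the stateless forward/backward step through the private torch.nn.utils._stateless module. For comparison, below is a minimal sketch of the same training step written against the public torch.func.functional_call API (available in PyTorch 2.x). The layer sizes, loss, and SGD settings mirror the example; the model, optimizer, and train_step names here are illustrative and none of this is part of the AMDShark API.

import torch
from torch.func import functional_call

# Same layer sizes as the Foo module above.
model = torch.nn.Sequential(
    torch.nn.Linear(10, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 2),
)

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
optim = torch.optim.SGD(list(params.values()), lr=0.01)


def train_step(x):
    # Call the module statelessly with explicit parameters and buffers,
    # reduce the output to a scalar loss, and backpropagate.
    loss = functional_call(model, {**params, **buffers}, (x,)).sum()
    optim.zero_grad()
    loss.backward()
    optim.step()
    return loss


for _ in range(10):
    train_step(torch.randn(10, 10))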