mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Bug fix related to torch-shark trainer.
This commit is contained in:
@@ -36,7 +36,8 @@ class SharkTrainer:
|
||||
from_aot: bool = True,
|
||||
):
|
||||
self.model = model
|
||||
self.input = input
|
||||
# Change tuple to list.
|
||||
self.input = [x for x in input]
|
||||
self.dynamic = dynamic
|
||||
self.from_aot = from_aot
|
||||
self.jit_trace = jit_trace
|
||||
@@ -55,7 +56,9 @@ class SharkTrainer:
|
||||
# Training function is needed in the case of torch_fn.
|
||||
def compile(self, training_fn=None):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
aot_module = MakeFxModule(self.model, self.input, training_fn)
|
||||
aot_module = MakeFxModule(self.model,
|
||||
tuple(self.input),
|
||||
custom_inference_fn=training_fn)
|
||||
aot_module.generate_graph()
|
||||
# Returns the backward graph.
|
||||
training_graph = aot_module.training_graph
|
||||
@@ -114,9 +117,9 @@ class SharkTrainer:
|
||||
|
||||
def train(self, num_iters=1):
|
||||
if self.frontend in ["torch", "pytorch"]:
|
||||
return self._train_torch(self, num_iters)
|
||||
return self._train_torch(num_iters)
|
||||
elif self.frontend in ["tf", "tensorflow"]:
|
||||
return self._train_tf(self, num_iters)
|
||||
return self._train_tf(num_iters)
|
||||
else:
|
||||
print("Unknown frontend")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user