Bug fix related to torch-shark trainer.

This commit is contained in:
Prashant Kumar
2022-05-27 09:18:31 +00:00
parent e5517f63f5
commit 8294dc3f20

View File

@@ -36,7 +36,8 @@ class SharkTrainer:
from_aot: bool = True,
):
self.model = model
self.input = input
# Change tuple to list.
self.input = [x for x in input]
self.dynamic = dynamic
self.from_aot = from_aot
self.jit_trace = jit_trace
@@ -55,7 +56,9 @@ class SharkTrainer:
# Training function is needed in the case of torch_fn.
def compile(self, training_fn=None):
if self.frontend in ["torch", "pytorch"]:
aot_module = MakeFxModule(self.model, self.input, training_fn)
aot_module = MakeFxModule(self.model,
tuple(self.input),
custom_inference_fn=training_fn)
aot_module.generate_graph()
# Returns the backward graph.
training_graph = aot_module.training_graph
@@ -114,9 +117,9 @@ class SharkTrainer:
def train(self, num_iters=1):
if self.frontend in ["torch", "pytorch"]:
return self._train_torch(self, num_iters)
return self._train_torch(num_iters)
elif self.frontend in ["tf", "tensorflow"]:
return self._train_tf(self, num_iters)
return self._train_tf(num_iters)
else:
print("Unknown frontend")
return