hotfix: resnet to obj.device

This commit is contained in:
George Hotz
2024-09-06 13:06:02 +08:00
parent 9d72119a0c
commit 8f6d0485e7

View File

@@ -144,7 +144,7 @@ class ResNet:
continue # Skip FC if transfer learning
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
obj.assign(dat.to(None).reshape(obj.shape))
obj.assign(dat.to(obj.device).reshape(obj.shape))
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)