mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
hotfix: resnet to obj.device
This commit is contained in:
@@ -144,7 +144,7 @@ class ResNet:
|
||||
continue # Skip FC if transfer learning
|
||||
|
||||
if 'bn' not in k and 'downsample' not in k: assert obj.shape == dat.shape, (k, obj.shape, dat.shape)
|
||||
obj.assign(dat.to(None).reshape(obj.shape))
|
||||
obj.assign(dat.to(obj.device).reshape(obj.shape))
|
||||
|
||||
ResNet18 = lambda num_classes=1000: ResNet(18, num_classes=num_classes)
|
||||
ResNet34 = lambda num_classes=1000: ResNet(34, num_classes=num_classes)
|
||||
|
||||
Reference in New Issue
Block a user