mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
gelu -> quick_gelu in hlb_cifar (#3147)
89 -> 86 seconds, same eval acc
This commit is contained in:
@@ -49,13 +49,13 @@ class ConvGroup:
|
||||
x = x.float()
|
||||
x = self.norm1(x)
|
||||
x = x.cast(dtypes.default_float)
|
||||
x = x.gelu()
|
||||
x = x.quick_gelu()
|
||||
residual = x
|
||||
x = self.conv2(x)
|
||||
x = x.float()
|
||||
x = self.norm2(x)
|
||||
x = x.cast(dtypes.default_float)
|
||||
x = x.gelu()
|
||||
x = x.quick_gelu()
|
||||
|
||||
return x + residual
|
||||
|
||||
@@ -64,7 +64,7 @@ class SpeedyResNet:
|
||||
self.whitening = W
|
||||
self.net = [
|
||||
nn.Conv2d(12, 32, kernel_size=1, bias=False),
|
||||
lambda x: x.gelu(),
|
||||
lambda x: x.quick_gelu(),
|
||||
ConvGroup(32, 64),
|
||||
ConvGroup(64, 256),
|
||||
ConvGroup(256, 512),
|
||||
|
||||
Reference in New Issue
Block a user