gelu -> quick_gelu in hlb_cifar (#3147)

89 -> 86 seconds, with the same eval accuracy.
This commit is contained in:
chenyu
2024-01-16 02:03:37 -05:00
committed by GitHub
parent ec5a212b0a
commit b9d470577c

View File

@@ -49,13 +49,13 @@ class ConvGroup:
x = x.float()
x = self.norm1(x)
x = x.cast(dtypes.default_float)
x = x.gelu()
x = x.quick_gelu()
residual = x
x = self.conv2(x)
x = x.float()
x = self.norm2(x)
x = x.cast(dtypes.default_float)
x = x.gelu()
x = x.quick_gelu()
return x + residual
@@ -64,7 +64,7 @@ class SpeedyResNet:
self.whitening = W
self.net = [
nn.Conv2d(12, 32, kernel_size=1, bias=False),
lambda x: x.gelu(),
lambda x: x.quick_gelu(),
ConvGroup(32, 64),
ConvGroup(64, 256),
ConvGroup(256, 512),