mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
very minor change
This commit is contained in:
@@ -41,7 +41,6 @@ class MBConvBlock:
|
||||
x = x.conv2d(self._depthwise_conv, stride=self.strides, groups=self._depthwise_conv.shape[0])
|
||||
x = self._bn1(x).swish()
|
||||
|
||||
# has_se
|
||||
if self.has_se:
|
||||
x_squeezed = x.avg_pool2d(kernel_size=x.shape[2:4])
|
||||
x_squeezed = x_squeezed.conv2d(self._se_reduce, self._se_reduce_bias).swish()
|
||||
|
||||
@@ -83,7 +83,7 @@ class Bottleneck:
|
||||
out = self.bn1(self.conv1(x)).relu()
|
||||
out = self.bn2(self.conv2(out)).relu()
|
||||
out = self.bn3(self.conv3(out))
|
||||
out = out + self.downsample(x)
|
||||
out = out + x.sequential(self.downsample)
|
||||
out = out.relu()
|
||||
return out
|
||||
|
||||
|
||||
@@ -20,9 +20,10 @@ class ViT:
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
ce = self.cls.add(Tensor.zeros(x.shape[0],1,1))
|
||||
pe = self.patch_embed(x)
|
||||
x = self.cls.add(Tensor.zeros(pe.shape[0],1,1)).cat(pe, dim=1) + self.pos_embedding
|
||||
x = x.sequential(self.tbs)
|
||||
x = ce.cat(pe, dim=1)
|
||||
x = x.add(self.pos_embedding).sequential(self.tbs)
|
||||
x = x.layernorm().linear(*self.encoder_norm)
|
||||
return x[:, 0].linear(*self.head)
|
||||
|
||||
@@ -40,14 +41,14 @@ class ViT:
|
||||
m.embedding[0].assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
|
||||
m.embedding[1].assign(dat['embedding/bias'])
|
||||
|
||||
m.encoder_norm[0].assign(dat['Transformer/encoder_norm/scale'])
|
||||
m.encoder_norm[1].assign(dat['Transformer/encoder_norm/bias'])
|
||||
m.cls.assign(dat['cls'])
|
||||
|
||||
m.head[0].assign(dat['head/kernel'])
|
||||
m.head[1].assign(dat['head/bias'])
|
||||
|
||||
m.cls.assign(dat['cls'])
|
||||
m.pos_embedding.assign(dat['Transformer/posembed_input/pos_embedding'])
|
||||
m.encoder_norm[0].assign(dat['Transformer/encoder_norm/scale'])
|
||||
m.encoder_norm[1].assign(dat['Transformer/encoder_norm/bias'])
|
||||
|
||||
for i in range(12):
|
||||
m.tbs[i].query[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(192, 192))
|
||||
|
||||
Reference in New Issue
Block a user