mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
work
This commit is contained in:
@@ -21,10 +21,10 @@ class Bottleneck:
|
||||
class ResNet50:
|
||||
def __init__(self, num_classes=1000):
|
||||
self.conv1, self.bn1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False), nn.BatchNorm2d(64)
|
||||
self.layer1 = self._make_layer(64, 64, 3, 1) # 256 out
|
||||
self.layer2 = self._make_layer(256, 128, 4, 2) # 512 out
|
||||
self.layer3 = self._make_layer(512, 256, 6, 2) # 1024 out
|
||||
self.layer4 = self._make_layer(1024,512, 3, 2) # 2048 out
|
||||
self.layer1 = self._make_layer(64, 64, 3, 1)
|
||||
self.layer2 = self._make_layer(256, 128, 4, 2)
|
||||
self.layer3 = self._make_layer(512, 256, 6, 2)
|
||||
self.layer4 = self._make_layer(1024,512, 3, 2)
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
|
||||
def _make_layer(self, in_c, mid_c, blocks, stride):
|
||||
@@ -35,8 +35,8 @@ class ResNet50:
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x = self.bn1(self.conv1(x)).relu()
|
||||
x = x.max_pool2d()
|
||||
x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
|
||||
x = x.mean((2, 3)) # global average pool
|
||||
x = x.sequential([self.layer1, self.layer2, self.layer3, self.layer4])
|
||||
x = x.mean((2, 3))
|
||||
return self.fc(x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -352,4 +352,14 @@ def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
|
||||
|
||||
@accept_filename
|
||||
def png_load(t:Tensor) -> Tensor:
|
||||
assert t[0:8].tolist() == [0x89,0x50,0x4E,0x47,0x0D,0x0A,0x1A,0x0A], "not a PNG"
|
||||
f = io.BufferedReader(TensorIO(t))
|
||||
assert f.read(8) == b'\x89PNG\r\n\x1a\n', "not a PNG"
|
||||
while (slen:=f.read(4)):
|
||||
len, typ = struct.unpack(">I", slen)[0], f.read(4)
|
||||
dat = f.read(len)
|
||||
if typ == b'IHDR':
|
||||
width, height, depth, color_type, compression, filter_method, interlace = struct.unpack(">IIBBBBB", dat)
|
||||
print(width, height, depth, color_type)
|
||||
|
||||
print(len, typ)
|
||||
f.seek(4, 1)
|
||||
|
||||
Reference in New Issue
Block a user