This commit is contained in:
George Hotz
2025-10-11 13:36:19 +08:00
parent dd2ff2ddb9
commit b14da7f9d4
2 changed files with 17 additions and 7 deletions

View File

@@ -21,10 +21,10 @@ class Bottleneck:
class ResNet50:
def __init__(self, num_classes=1000):
self.conv1, self.bn1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False), nn.BatchNorm2d(64)
self.layer1 = self._make_layer(64, 64, 3, 1) # 256 out
self.layer2 = self._make_layer(256, 128, 4, 2) # 512 out
self.layer3 = self._make_layer(512, 256, 6, 2) # 1024 out
self.layer4 = self._make_layer(1024,512, 3, 2) # 2048 out
self.layer1 = self._make_layer(64, 64, 3, 1)
self.layer2 = self._make_layer(256, 128, 4, 2)
self.layer3 = self._make_layer(512, 256, 6, 2)
self.layer4 = self._make_layer(1024,512, 3, 2)
self.fc = nn.Linear(2048, num_classes)
def _make_layer(self, in_c, mid_c, blocks, stride):
@@ -35,8 +35,8 @@ class ResNet50:
def __call__(self, x:Tensor) -> Tensor:
x = self.bn1(self.conv1(x)).relu()
x = x.max_pool2d()
x = self.layer1(x); x = self.layer2(x); x = self.layer3(x); x = self.layer4(x)
x = x.mean((2, 3)) # global average pool
x = x.sequential([self.layer1, self.layer2, self.layer3, self.layer4])
x = x.mean((2, 3))
return self.fc(x)
if __name__ == "__main__":

View File

@@ -352,4 +352,14 @@ def gguf_load(tensor: Tensor) -> tuple[dict, dict[str, Tensor]]:
@accept_filename
def png_load(t:Tensor) -> Tensor:
assert t[0:8].tolist() == [0x89,0x50,0x4E,0x47,0x0D,0x0A,0x1A,0x0A], "not a PNG"
f = io.BufferedReader(TensorIO(t))
assert f.read(8) == b'\x89PNG\r\n\x1a\n', "not a PNG"
while (slen:=f.read(4)):
len, typ = struct.unpack(">I", slen)[0], f.read(4)
dat = f.read(len)
if typ == b'IHDR':
width, height, depth, color_type, compression, filter_method, interlace = struct.unpack(">IIBBBBB", dat)
print(width, height, depth, color_type)
print(len, typ)
f.seek(4, 1)