From f5a3b33d33a8a441aa12338655b4b55d1d86ee5d Mon Sep 17 00:00:00 2001 From: George Hotz Date: Tue, 28 Oct 2025 17:12:22 +0800 Subject: [PATCH] add fun with nhwc convs --- test/external/external_nhwc_conv.py | 38 +++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 test/external/external_nhwc_conv.py diff --git a/test/external/external_nhwc_conv.py b/test/external/external_nhwc_conv.py new file mode 100644 index 0000000000..541de1154b --- /dev/null +++ b/test/external/external_nhwc_conv.py @@ -0,0 +1,38 @@ +from tinygrad import Tensor, nn, Context, GlobalCounters + +if __name__ == "__main__": + conv = nn.Conv2d(64, 128, 3) + img = Tensor.randn((1,64,128,128)) + with Context(DEBUG=0, BEAM=0): + Tensor.realize(img, conv.weight, conv.bias) + + tst = conv(img).permute(0,2,3,1).realize() + print(tst.shape) + + print("NEW") + img_perm = img.permute(0,2,3,1).contiguous() + print(img_perm.shape) + pp = img_perm.permute(0,3,1,2)._pool((3,3)).permute(0,2,3,4,5,1) + + def hwio(pp, conv): + pp = pp.unsqueeze(-1) + weight = conv.weight.permute(2,3,1,0).contiguous() + print(pp.shape, weight.shape, (pp*weight).shape) + return (pp * weight).sum([-4,-3, -2]) + + def ohwi(pp, conv): + pp = pp.unsqueeze(-4) + weight = conv.weight.permute(0,2,3,1).contiguous() + print(pp.shape, weight.shape, (pp*weight).shape) + return (pp * weight).sum([-3,-2,-1]) + + for f in [hwio, ohwi]: + GlobalCounters.reset() + print("\n**************", f.__name__, "**************") + out = f(pp, conv) + out.realize() + print(out.shape) + + with Context(DEBUG=0, BEAM=0): + err = (tst-out).square() + print(err.mean().item(), err.max().item())