From 5ba611787d99efa59828e6d21566f7832b73b5a5 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 15 May 2024 10:50:25 -0700
Subject: [PATCH] move image into tensor.py. delete features (#4603)

* move image into tensor.py

* change setup.py

* openpilot tests need pythonpath now
---
 .github/workflows/test.yml                    | 10 +-
 examples/hlb_cifar10.py                       |  2 +-
 setup.py                                      |  5 +-
 ...xternal_benchmark_multitensor_allreduce.py |  2 +-
 test/test_image_dtype.py                      |  2 +-
 test/test_multitensor.py                      |  2 +-
 tinygrad/codegen/linearizer.py                |  7 +-
 tinygrad/features/image.py                    | 93 ------------------
 tinygrad/lazy.py                              |  4 +-
 tinygrad/{features => }/multi.py              |  2 +-
 tinygrad/nn/state.py                          |  2 +-
 tinygrad/tensor.py                            | 88 +++++++++++++++++-
 12 files changed, 106 insertions(+), 113 deletions(-)
 delete mode 100644 tinygrad/features/image.py
 rename tinygrad/{features => }/multi.py (98%)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 87b4dc370a..d95571760d 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -167,17 +167,17 @@ jobs:
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model compile and size
       run: |
-        DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
-        #python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
+        PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+        python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot alt model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'openpilot' }}
       name: Test openpilot fastvits model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'onnx' }}
       name: Test ONNX (GPU)
       run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index 759d042ec9..fffbee0bcb 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.nn.state import get_state_dict, get_parameters
 from tinygrad.nn import optim
 from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer
 
 BS, STEPS = getenv("BS", 512), getenv("STEPS", 1000)
 EVAL_BS = getenv("EVAL_BS", BS)
diff --git a/setup.py b/setup.py
index b54e19bcf0..368121e560 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@ setup(name='tinygrad',
       long_description=long_description,
       long_description_content_type='text/markdown',
       packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.engine',
-                  'tinygrad.runtime', 'tinygrad.runtime.driver', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.features'],
+                  'tinygrad.runtime', 'tinygrad.runtime.driver', 'tinygrad.runtime.graph', 'tinygrad.shape'],
       classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
@@ -59,7 +59,8 @@ setup(name='tinygrad',
           "mkdocs-material",
           "mkdocstrings[python]",
           "markdown-callouts",
-          "markdown-exec[ansi]"
+          "markdown-exec[ansi]",
+          "black"
         ],
         'testing_tf': [
           "tensorflow==2.15.1",
diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py
index 3340af50be..99a43d740a 100644
--- a/test/external/external_benchmark_multitensor_allreduce.py
+++ b/test/external/external_benchmark_multitensor_allreduce.py
@@ -2,7 +2,7 @@ import time
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit
 from tinygrad.lazy import LazyBuffer
 from tinygrad.ops import ReduceOps
-from tinygrad.features.multi import MultiLazyBuffer, all_reduce
+from tinygrad.multi import MultiLazyBuffer, all_reduce
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
 from tinygrad.helpers import getenv, Context, RING
diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py
index 4f9dccedaa..75ecdec20d 100644
--- a/test/test_image_dtype.py
+++ b/test/test_image_dtype.py
@@ -2,7 +2,7 @@ import unittest
 import numpy as np
 from tinygrad import Device, dtypes, Tensor, Variable
 from tinygrad.dtype import ImageDType
-from tinygrad.features.image import to_image_idx
+from tinygrad.codegen.linearizer import to_image_idx
 
 @unittest.skipIf(Device.DEFAULT != "GPU", "only images on GPU")
 class TestImageDType(unittest.TestCase):
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 6cc9b99e7a..cbfde10501 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -6,7 +6,7 @@ from tinygrad.helpers import CI, prod, Context
 from tinygrad.nn.state import get_parameters, get_state_dict
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner
-from tinygrad.features.multi import all_reduce, MultiLazyBuffer
+from tinygrad.multi import all_reduce, MultiLazyBuffer
 from random import randint
 import numpy as np
 from hypothesis import given, strategies as strat, settings
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index a132e08f42..1da2677912 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -9,7 +9,6 @@ from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Con
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
 from tinygrad.codegen.kernel import LocalBuffer, Kernel
-from tinygrad.features.image import to_image_idx
 from tinygrad.renderer import Program
 from tinygrad.codegen.uops import UOps, UOp, UOpGraph
 
@@ -32,6 +31,12 @@ def expand_idxs(nodes:Sequence[Node]) -> Tuple[Union[Variable, NumNode], ...]:
 def iter_idxs(idxs:Tuple[Union[Variable, NumNode], ...]) -> Iterator[Tuple[int,...]]:
   yield from (x[::-1] for x in itertools.product(*[[x for x in range(v.min, v.max + 1)] for v in idxs[::-1]]))
 
+def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
+  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
+  # TODO: bring back the valid removal logic (correct!)
+  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
+  return (idx, idy), valid
+
 # expand a Node into List[Node] that enumerates the underlying Variables from min to max
 # expand increments earlier variables faster than later variables (as specified in the argument)
 @functools.lru_cache(maxsize=None)
diff --git a/tinygrad/features/image.py b/tinygrad/features/image.py
deleted file mode 100644
index 1d071bfc59..0000000000
--- a/tinygrad/features/image.py
+++ /dev/null
@@ -1,93 +0,0 @@
-from typing import Tuple
-from tinygrad.helpers import prod, IMAGE, getenv, DEBUG
-from tinygrad.dtype import dtypes
-
-# *** image Tensor function replacements ***
-
-def image_dot(self, w, acc_dtype=None):
-  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
-  n1, n2 = len(self.shape), len(w.shape)
-  assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
-  assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"  # noqa: E501
-  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
-  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
-
-  # NOTE: with NHWC we can remove the transposes
-  # bs x groups*cin x H x W
-  cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
-  # groups*cout x cin x H, W
-  cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
-  return image_conv2d(cx, cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
-
-def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
-  base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
-
-  (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
-  x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
-
-  # hack for non multiples of 4 on cin
-  if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
-    x = x.reshape(bs, groups, cin, iy, ix)   # do this always?
-    added_input_channels = 4 - (cin % 4)
-    w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
-    x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
-    cin = cin + added_input_channels
-    x = x.reshape(bs, groups*cin, iy, ix)
-
-  # hack for non multiples of 4 on rcout
-  added_output_channels = 0
-  if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
-    added_output_channels = 4 - (rcout % 4)
-    rcout += added_output_channels
-    cout = groups * rcout
-    w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
-
-  # packed (note: flipping bs and iy would make the auto-padding work)
-  x = x.permute(0,2,3,1)
-  cin_last = iy == 1 and ix == 1
-  if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
-  elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
-  else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
-
-  # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
-  if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
-  x, w = x.contiguous(), w.contiguous()
-
-  # expand out
-  rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
-  cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
-  x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
-  if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
-  else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
-
-  # padding
-  padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
-  x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
-
-  # prepare input
-  x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
-  x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
-
-  # prepare weights
-  w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
-
-  # the conv!
-  ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
-
-  # undo hack for non multiples of 4 on C.rcout
-  if added_output_channels != 0:
-    ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
-    cout = groups * (rcout - added_output_channels)
-
-  # NCHW output
-  ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
-  return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
-
-# *** images have weird indexing requirements ***
-
-from tinygrad.shape.symbolic import Node
-def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
-  idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
-  # TODO: bring back the valid removal logic (correct!)
-  if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy, valid)
-  return (idx, idy), valid
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 431fe62efc..b02d371103 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -153,8 +153,8 @@ class LazyBuffer:
       if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG: return self.base.srcs[0]
     if op in BinaryOps: x, y = self, in_srcs[0]
     if op is BinaryOps.ADD:
-      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
-      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
+      if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x  # pylint: disable=possibly-used-before-assignment
+      if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y  # pylint: disable=possibly-used-before-assignment
     if op is BinaryOps.SUB and y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
     if op is BinaryOps.MUL:
       if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
diff --git a/tinygrad/features/multi.py b/tinygrad/multi.py
similarity index 98%
rename from tinygrad/features/multi.py
rename to tinygrad/multi.py
index 24eb9cbe76..663d461673 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/multi.py
@@ -85,7 +85,7 @@ class MultiLazyBuffer:
     return functools.reduce(lambda x,y: x.e(BinaryOps.ADD, y), llbs)
 
   # passthroughs
-  def is_realized(self) -> bool: return all([lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True])
+  def is_realized(self) -> bool: return all(lb.base.realized is not None for lb, r in zip(self.lbs, self.real) if r is True)
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:ConstType) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py
index 28b9fe5d2e..57148768a6 100644
--- a/tinygrad/nn/state.py
+++ b/tinygrad/nn/state.py
@@ -5,7 +5,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters
 from tinygrad.shape.view import strides_for_shape
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer
 
 safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dtypes.int16, "U16":dtypes.uint16, "I32":dtypes.int, "U32":dtypes.uint,
                "I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 682feac9e3..ab0b784007 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -10,7 +10,7 @@ from tinygrad.dtype import DType, dtypes, ImageDType, ConstType, least_upper_flo
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, argsort, IMAGE, DEBUG, WINO, THREEFRY
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
-from tinygrad.features.multi import MultiLazyBuffer
+from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
 from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Device
@@ -1374,14 +1374,94 @@ class Tensor:
   def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
   def size(self, dim=None) -> Union[sint, Tuple[sint, ...]]: return self.shape if dim is None else self.shape[dim]
 
+  # *** image Tensor function replacements ***
+
+  def image_dot(self, w:Tensor, acc_dtype=None):
+    # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
+    n1, n2 = len(self.shape), len(w.shape)
+    assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
+    assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})"  # noqa: E501
+    bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
+    out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, )
+
+    # NOTE: with NHWC we can remove the transposes
+    # bs x groups*cin x H x W
+    cx = self.transpose(self.ndim-1, self.ndim-2).reshape((bs//groups, groups*cin, -1, 1))
+    # groups*cout x cin x H, W
+    cw = w.transpose(w.ndim-1, w.ndim-2).reshape((groups*cout, cin, 1, 1))
+    return cx.image_conv2d(cw, groups=groups, acc_dtype=acc_dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
+
+  def image_conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype=None):
+    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+
+    (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
+    x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
+
+    # hack for non multiples of 4 on cin
+    if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
+      x = x.reshape(bs, groups, cin, iy, ix)   # do this always?
+      added_input_channels = 4 - (cin % 4)
+      w = w.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(w.ndim)))
+      x = x.pad(tuple((0, added_input_channels) if i == 2 else None for i in range(x.ndim)))
+      cin = cin + added_input_channels
+      x = x.reshape(bs, groups*cin, iy, ix)
+
+    # hack for non multiples of 4 on rcout
+    added_output_channels = 0
+    if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
+      added_output_channels = 4 - (rcout % 4)
+      rcout += added_output_channels
+      cout = groups * rcout
+      w = w.pad(tuple((0, added_output_channels) if i == 1 else None for i in range(w.ndim)))
+
+    # packed (note: flipping bs and iy would make the auto-padding work)
+    x = x.permute(0,2,3,1)
+    cin_last = iy == 1 and ix == 1
+    if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
+    elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
+    else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
+
+    # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
+    if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
+    x, w = x.contiguous(), w.contiguous()
+
+    # expand out
+    rcin_hi, rcin_lo = cin//4 if cin >= 4 else 1, 4 if cin >= 4 else 1
+    cout_expand = [groups//4 if cin == 1 else groups, 4 if cin == 1 else 1, rcout//4 if rcout >= 4 else 1, 4 if rcout >= 4 else 1]
+    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
+    else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
+
+    # padding
+    padding_ = [padding]*4 if isinstance(padding, int) else (padding if len(padding) == 4 else [padding[1], padding[1], padding[0], padding[0]])
+    x = x.slice((None, (-padding_[2], x.shape[1]+padding_[3]), (-padding_[0], x.shape[2]+padding_[1]), None, None, None))
+
+    # prepare input
+    x = x.permute(0,3,4,5,1,2)._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
+    x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W)
+
+    # prepare weights
+    w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *cout_expand, rcin_hi, rcin_lo, H, W))
+
+    # the conv!
+    ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), acc_dtype=acc_dtype)
+
+    # undo hack for non multiples of 4 on C.rcout
+    if added_output_channels != 0:
+      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
+      cout = groups * (rcout - added_output_channels)
+
+    # NCHW output
+    ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
+    return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
+
 # register functions to move between devices
 for device in Device._devices: setattr(Tensor, f"{device.lower()}", functools.partialmethod(Tensor.to, device))
 
 if IMAGE:
   # if IMAGE>0 we install these replacement functions in Tensor (hack!)
-  from tinygrad.features.image import image_conv2d, image_dot
-  setattr(Tensor, "conv2d", image_conv2d)
-  setattr(Tensor, "dot", image_dot)
+  setattr(Tensor, "conv2d", Tensor.image_conv2d)
+  setattr(Tensor, "dot", Tensor.image_dot)
 
 # TODO: eventually remove this
 def custom_random(out:Buffer):
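
For context, a minimal usage sketch of the relocated APIs (not part of the patch): the tensor shapes, the IMAGE=2 setting, and the GPU-backend assumption below are illustrative; the example only exercises what this diff moves, namely the tinygrad.multi module and the Tensor-level image conv2d.

import os
os.environ["IMAGE"] = "2"                    # must be set before importing tinygrad so the conv2d/dot override installs
from tinygrad import Tensor
from tinygrad.multi import MultiLazyBuffer   # new import path; was tinygrad.features.multi (shown for the path only)

x = Tensor.rand(1, 12, 64, 128)              # NCHW input; cin=12 is a multiple of 4, so no channel-padding hack triggers
w = Tensor.rand(32, 12, 3, 3)
out = x.conv2d(w, padding=1)                 # with IMAGE>0 this dispatches to Tensor.image_conv2d
print(out.shape)                             # (1, 32, 64, 128), assuming a GPU backend with image dtype support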