IMAGE=1 creates "dynamic" images (#13769)

* remove image from BufferSpec

* cl tiny_gemm (64) works

* mypy

* padding

* openpilot CL

* reshape properly

* remove extra qcom checks

* pad output

* mypy

* update compile test

* move undo

* TestImageCopy valid images

* TestImageRealization valid images

* TestImageDType valid images

* cleanups

* test_renderer_failures

* ruff

* mypy

* simplify ops_qcom

* bump step time

* Revert "bump step time"

This reverts commit 75a037c7d0.

* "dynamic textures" are optional

* a start

* IMAGE=1 works, no FLOAT16

* fast but wrong

* mypy

* some fixes

* better

* works

* refactor

* oops
Christopher Milan, 2026-01-02 13:22:39 -08:00 (committed by GitHub)
parent 61dc70f1a8, commit 9dc524536f
12 changed files with 133 additions and 74 deletions

View File

@@ -394,7 +394,7 @@ jobs:
       llvm: 'true'
   - name: Test openpilot model kernel count and gate usage
     run: |
-      ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
+      ALLOWED_KERNEL_COUNT=125 ALLOWED_READ_IMAGE=1389 ALLOWED_GATED_READ_IMAGE=101 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
   - name: Test openpilot CL compile fp16
     run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
   - name: Test openpilot CL compile fp32 (test correctness)

View File

@@ -10,25 +10,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageCopy(unittest.TestCase):
-  def test_image_copyout_1x1(self, img_type=dtypes.imagef):
-    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
+  def test_image_copyout_1x8(self, img_type=dtypes.imagef):
+    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
+    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32))
   @unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half")
-  def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh)
-  def test_image_numpy_1x1(self, img_type=dtypes.imagef):
-    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
-    np.testing.assert_equal(it.numpy(), np.arange(4))
-  def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh)
-  def test_image_copyout_2x3(self):
-    it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
+  def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh)
+  def test_image_numpy_1x8(self, img_type=dtypes.imagef):
+    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
+    np.testing.assert_equal(it.numpy(), np.arange(32))
+  def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh)
+  def test_image_copyout_2x4(self):
+    it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
+    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4))
   def test_image_roundtrip(self):
     sz = (4,2,4)
@@ -105,9 +105,9 @@ class TestImageDType(unittest.TestCase):
     __validate(dtypes.imagef((1, 1)), 0x40)
   def test_image_and_back(self):
-    data = Tensor.randn(9*27*4).realize()
+    data = Tensor.randn(9*32*4).realize()
     tst = data.numpy()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     assert isinstance(it.uop.base.realized.dtype, ImageDType)
     np.testing.assert_equal(tst, it.numpy())
@@ -127,13 +127,13 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_equal(tst, it.numpy())
   def test_shrink_load_float(self):
-    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
   def test_mul_stays_image(self):
     # NOTE: contiguous is needed otherwise this folds
-    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).contiguous().realize()
     out = (it*2).realize()
     assert isinstance(out.uop.base.realized.dtype, ImageDType)
@@ -143,7 +143,7 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6)
   def test_shrink_max(self):
-    it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
@@ -162,19 +162,19 @@ class TestImageDType(unittest.TestCase):
     assert it.uop.base.realized._buf == b1
   def test_no_lru_alloc(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
+    it = data.reshape(9,32,4).pad_to(10, None, None).cast(dtypes.imagef((10,32,4))).contiguous().realize()
     assert it.uop.base.realized._buf != b1
   def test_no_lru_alloc_dtype(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.cast(dtypes.imageh((9,27,4))).realize()
+    it = data.cast(dtypes.imageh((9,32,4))).realize()
     assert it.uop.base.realized._buf != b1
   # issue caused by: don't realize image to image casts. this is part of a larger problem
@@ -202,36 +202,36 @@ class TestImageDType(unittest.TestCase):
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageRealization(unittest.TestCase):
   def test_image_dtype_expand(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous().realize()
     self.assertEqual(it_expanded.dtype, dtypes.float32)
   def test_image_dtype_expand_and_back(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4))
     it2 = it_expanded.sum(3).realize()
-    self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(it2.dtype, dtypes.imagef((9,32,4)))
   def test_image_alu_children(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous()
     alu1 = it_expanded+1
     alu2 = it_expanded.sum(3)
     it_expanded.realize()
     # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
-    self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu1.dtype, dtypes.imagef((9,32,4)))
     alu1.realize()
     self.assertEqual(alu1.dtype, dtypes.float32)
     # alu2 is back in image because it fits the dtype again
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))
     alu2.realize()
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))
 if __name__ == '__main__':
   unittest.main()
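Note on the shape changes above: the dynamic-image path added later in this commit (`make_images`) only promotes a buffer whose byte size is a multiple of 64, and the QCOM pitch rules likewise want 64-byte-aligned rows, so the tests move from shapes like (1,1,4) and (9,27,4) to (1,8,4) and (9,32,4). A quick arithmetic check, assuming float32 RGBA (16 bytes per texel):

    # bytes per row of an (h, w, 4) imagef
    row_bytes = lambda w: w * 4 * 4
    assert row_bytes(1) == 16 and row_bytes(8) == 128           # 1x1 -> 1x8
    assert row_bytes(3) == 48 and row_bytes(4) == 64            # 2x3 -> 2x4
    assert row_bytes(27) % 64 != 0 and row_bytes(32) % 64 == 0  # 9x27 -> 9x32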

View File

@@ -15,7 +15,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
 from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
   ReduceContext, correct_load_store, pm_render, pm_add_loads
-from tinygrad.codegen.opt.postrange import apply_opts
+from tinygrad.codegen.opt.postrange import apply_opts, make_images
 from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@@ -53,6 +53,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
   # split store range (only on CPU for now)
   sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")
+
+  # create image buffers
+  sink = make_images(sink, ren)
   # do postrange optimization, BEAM or hand_coded_optimizations
   sink = apply_opts(sink, ren)
@@ -133,7 +136,7 @@ def do_linearize(prg:UOp, sink:UOp) -> UOp:
 def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
   src = ctx.render(list(lin.src))
-  return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
+  return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),), arg=ctx.aux(list(lin.src)) if ctx.has_aux else prg.arg)
 def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
   if ctx.compiler is None: return None
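Placement note: `make_images` runs after store-range splitting and before `apply_opts`, so image-typed DEFINE_GLOBALs already exist when BEAM or hand-coded optimization (and later the renderer) see the AST, and `do_render` can then snapshot per-kernel info into the program arg through `ctx.aux` for renderers that set `has_aux`.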

View File

@@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes, ImageDType
 from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
-from tinygrad.helpers import ALLOW_TF32, count
+from tinygrad.helpers import IMAGE, ALLOW_TF32, count
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
@@ -349,3 +349,24 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
   if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice):
     k = hand_coded_optimizations(k)
   return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
+
+# create image buffers
+def make_images(ast:UOp, ren:Renderer) -> UOp:
+  if IMAGE == 1 and ren.device in {"QCOM", "CL"}:
+    dg_types: dict = {}
+    def make_image(ctx, dg):
+      if (dt:=dg.dtype).base is dtypes.float and not isinstance(dt, ImageDType) and dt.size < 65536 and dt.nbytes() % 64 == 0:
+        ctx[dg.arg] = dt
+        return dg.replace(dtype=dtypes.imagef((1, dt.size // 4, 4), dt.nbytes()))
+    ast = graph_rewrite(ast, PatternMatcher([(UPat(Ops.DEFINE_GLOBAL, name="dg"), make_image)]), ctx=dg_types, name="create image buffers")
+    # undo unfoldable stores
+    def undo_image_store(ctx, st, idx, dg):
+      if dg.arg in ctx and not any(c.op is Ops.RANGE and (c.vmax+1)%4 == 0 for c in idx.src[1].get_idx().split_uop(Ops.ADD)):
+        return st.replace(src=(idx.replace(src=(dg.replace(dtype=ctx[dg.arg]),)+idx.src[1:]),)+st.src[1:])
+    ast = graph_rewrite(ast, PatternMatcher([
+      (UPat(Ops.DEFINE_GLOBAL, name="dg").index(UPat(), name="idx").store(UPat(), name="st"), undo_image_store)
+    ]), ctx=dg_types, name="remove unfoldable image stores")
+  return ast
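The net effect of IMAGE=1 with this function: nothing is allocated as an image anymore; any plain float32 global that fits one texture row and is 64-byte sized gets retyped at codegen time as a one-row RGBA image over the same bytes (stores whose innermost index doesn't step in multiples of 4 are then reverted, since they can't fold into a float4 texel write). A minimal restatement of the eligibility rule; `eligible` and `as_image` are illustrative names, not tinygrad API:

    from tinygrad.dtype import dtypes, ImageDType

    def eligible(dt):  # mirrors the check in make_image above
      return dt.base is dtypes.float and not isinstance(dt, ImageDType) \
         and dt.size < 65536 and dt.nbytes() % 64 == 0

    def as_image(dt):  # one-row RGBA float image, pitch pinned to the buffer size
      return dtypes.imagef((1, dt.size // 4, 4), dt.nbytes())

    dt = dtypes.float.ptr(1024)                  # 1024 floats = 4096 bytes
    assert eligible(dt) and as_image(dt).shape == (1, 256, 4)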

View File

@@ -94,12 +94,14 @@ class PtrDType(DType):
 @dataclass(frozen=True, eq=False)
 class ImageDType(PtrDType):
   shape: tuple[int, ...] = () # shape of the Image
+  _pitch: int = -1
   def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
     assert addrspace == AddrSpace.GLOBAL, "images can't be local"
     return self
   def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
   @property
   def pitch(self):
+    if self._pitch != -1: return self._pitch
     imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
     pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
     align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
@@ -181,9 +183,9 @@ class dtypes:
   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
+  def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
+  def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
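The new `_pitch` field lets callers pin the row pitch instead of deriving it from the Adreno-style formula above; `make_images` uses this so a one-row image's pitch is exactly the flat buffer's byte size, i.e. the image aliases the buffer with no row padding. For example:

    from tinygrad.dtype import dtypes

    img = dtypes.imagef((1, 256, 4), 4096)  # 256 RGBA float texels = 4096 bytes/row
    assert img.pitch == 4096                # pinned; skips the computed alignment
    auto = dtypes.imagef((9, 32, 4))        # pitch=-1 falls through to the formula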

View File

@@ -45,7 +45,7 @@ class CompiledRunner(Runner):
     self.p:ProgramSpec = p
     assert self.p.lib is not None
     if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg
+    self._prg = Device[p.device].runtime(p.function_name, self.p.lib, *p.aux) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)
   def __reduce__(self): return self.__class__, (self.p,)

View File

@@ -168,6 +168,7 @@ class ContextVar:
     ContextVar._cache[key] = self
     self.value, self.key = getenv(key, default_value), key
   def __bool__(self): return bool(self.value)
+  def __eq__(self, x): return self.value == x
   def __ge__(self, x): return self.value >= x
   def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
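`__eq__` is needed because `make_images` tests `IMAGE == 1` exactly; a ContextVar previously only supported truthiness and ordering comparisons, so an equality check would have fallen back to object identity and always been False:

    from tinygrad.helpers import IMAGE
    if IMAGE == 1: ...  # exact match on the env value, distinct from IMAGE >= 2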

View File

@@ -66,6 +66,7 @@ class ProgramSpec:
   ast:UOp # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
   lib:bytes|None=None
+  aux:list=field(default_factory=list)
   # filled in from uops (via from_uop)
   global_size:list[int]=field(default_factory=lambda: [1,1,1])
@@ -123,7 +124,7 @@ class ProgramSpec:
     # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
     if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())
-    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size,
+    return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, list(prg.arg) if prg.arg else [], global_size, local_size,
                        sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))
 class Renderer:
@@ -134,6 +135,7 @@ class Renderer:
   has_local: bool = True
   has_threads: bool = False
   has_shared: bool = True
+  has_aux: bool = False # additional program info, eg. image shapes
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
@@ -146,3 +148,4 @@ class Renderer:
   def __reduce__(self): return self.__class__, ()
   def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
+  def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux")
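The aux plumbing end to end: a renderer opts in with `has_aux`, `do_render` (earlier in this commit) stores `ctx.aux(uops)` as the PRG arg, `ProgramSpec.from_uop` copies that into `ProgramSpec.aux`, and `CompiledRunner` splats it into the device runtime constructor as extra positional args. A hypothetical renderer showing only the contract (`MyAuxRenderer` is illustrative, not part of this PR):

    from tinygrad.uop.ops import Ops

    class MyAuxRenderer:
      has_aux = True
      def aux(self, uops):
        # like OpenCLRenderer in this commit: one extra runtime arg, the buffer dtypes
        return (tuple(u.dtype for u in uops if u.op is Ops.DEFINE_GLOBAL),)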

View File

@@ -286,6 +286,7 @@ class ClangJITRenderer(ClangRenderer):
 class OpenCLRenderer(CStyleLanguage):
   device = "CL"
+  has_aux = True
   # language options
   kernel_typedef = "__kernel void"
@@ -314,6 +315,8 @@ class OpenCLRenderer(CStyleLanguage):
     if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
+
+  def aux(self, uops:list[UOp]): return (tuple(u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL),)
 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
   tensor_cores = tc.intel

View File

@@ -5,6 +5,7 @@ from tinygrad.runtime.autogen import opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing
 from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
 from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError, CompilerPair, CompilerSet
+from tinygrad.dtype import ImageDType
 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -33,8 +34,8 @@ class CLCompiler(Compiler):
     return bytes(binary)
 class CLProgram:
-  def __init__(self, device:CLDevice, name:str, lib:bytes):
-    self.dev, self.name, self.lib = device, name, lib
+  def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]):
+    self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes
     self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)),
                                                         to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(),
                                                         errcode_ret := ctypes.c_int32()), errcode_ret)
@@ -50,7 +51,12 @@ class CLProgram:
   def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None,
                vals:tuple[int, ...]=(), wait=False) -> float|None:
-    for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
+    for i,(b,_) in enumerate(bufs):
+      if isinstance(dt:=self.buf_dtypes[i], ImageDType):
+        fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
+        desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b)
+        b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
+      check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
     event = cl.cl_event() if wait else None
@@ -66,27 +72,15 @@ class CLProgram:
 class CLAllocator(LRUAllocator['CLDevice']):
   def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]:
-    if options.image is not None:
-      return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,
-                                         cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
-                                         options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
     return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
   @suppress_finalizing
   def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
   def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview):
-    if dest[1].image is not None:
-      check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
-                                   (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
-    else:
-      if mv_address(src) % 16: src = memoryview(bytearray(src))
-      check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
+    if mv_address(src) % 16: src = memoryview(bytearray(src))
+    check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
     self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
   def _copyout(self, dest:memoryview, src:tuple[ctypes._CData, BufferSpec]):
-    if src[1].image is not None:
-      check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
-                                  (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
-    else:
-      check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
+    check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
     self.dev.synchronize()
 class CLDevice(Compiled):
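So images are no longer a kind of allocation on CL: every buffer is a plain cl_mem, copies always take the buffer path, and when a kernel argument is typed as an ImageDType a 2D image view is created over the buffer at launch (OpenCL's image-from-buffer form, where `cl_image_desc.buffer` makes the image alias existing memory). Distilled from the `__call__` change above; error checking elided and `image_view` is an illustrative name:

    import ctypes
    from tinygrad.runtime.autogen import opencl as cl

    def image_view(ctx, buf, dt):  # dt is the argument's ImageDType
      fmt = cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[dt.itemsize])
      desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=buf)
      return cl.clCreateImage(ctx, cl.CL_MEM_READ_WRITE, fmt, desc, None, ctypes.c_int32())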

View File

@@ -11,6 +11,7 @@ from tinygrad.renderer.cstyle import QCOMRenderer
 from tinygrad.renderer.nir import IR3Renderer
 from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
 from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC
+from tinygrad.dtype import ImageDType
 from tinygrad.runtime.support.system import System
 if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -192,8 +193,9 @@ class QCOMArgsState(HCQArgsState):
     super().__init__(buf, prg, bufs, vals=vals)
     ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)
-    ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None]
-    ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:])
+    ubos = [b for i,b in enumerate(bufs) if not isinstance(prg.buf_dtypes[i], ImageDType)]
+    uavs = [(i,b) for i,b in enumerate(bufs) if isinstance(prg.buf_dtypes[i], ImageDType)]
+    ibos, texs = uavs[:prg.ibo_cnt], uavs[prg.ibo_cnt:]
     for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
     if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
@@ -205,19 +207,19 @@ class QCOMArgsState(HCQArgsState):
     for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)])
     def _tex(b, ibo=False):
-      fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
+      fmt = mesa.FMT6_32_32_32_32_FLOAT if (img:=b[1].image or prg.buf_dtypes[b[0]]).itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
       return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt),
-              qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]),
-              qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr),
+              qreg.a6xx_tex_const_1(width=img.shape[1], height=img.shape[0]),
+              qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=img.pitch, pitchalign=ctz(img.pitch)-6), 0, *data64_le(b[1].va_addr),
               qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0]
     self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off)
     self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off)
 class QCOMProgram(HCQProgram):
-  def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
+  def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]):
     self.dev: QCOMDevice = dev
-    self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer)
+    self.buf_dtypes, self.name, self.lib, self.NIR = buf_dtypes, name, lib, isinstance(dev.renderer, IR3Renderer)
     if self.NIR:
       from tinygrad.runtime.support.compiler_mesa import IR3Compiler
@@ -313,7 +315,7 @@ class QCOMTextureInfo:
 class QCOMAllocator(HCQAllocatorBase):
   def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer:
     # Recalculate real size for texture
-    if opts.image is not None: size = opts.image.pitch* opts.image.shape[0]
+    if opts.image is not None: size = opts.image.pitch * opts.image.shape[0]
     return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image)
   def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
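Same idea on QCOM: UBO versus texture/IBO classification no longer depends on how a buffer was allocated (`b.image`) but on the kernel's argument dtypes (`prg.buf_dtypes`, delivered through the new aux path), and `_tex` falls back to the dtype when the buffer was allocated flat, which is what lets a plain buffer be bound as a texture for these dynamic images.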

View File

@@ -3885,7 +3885,7 @@ class Tensor(OpMixin):
     return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
   def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
-    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+    base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4)
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
@@ -3898,6 +3898,20 @@ class Tensor(OpMixin):
       w = w.pad_to(None, None, cin, None, None)
       x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)
+
+    # hacks for pitch alignment
+    assert isinstance(ix, int) and isinstance(H, int)
+    added_width = 0
+    if (ix*groups*cin) % (64 // dtsz):
+      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
+      ix = ix + added_width
+      x = x.pad_to(None, None, None, ix)
+    added_weight = 0
+    if (H*W*cin) % (64 // dtsz):
+      added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H
+      H = H + added_weight
+      w = w.pad_to(None, None, None, H, None)
+
     # hack for non multiples of 4 on rcout
     added_output_channels = 0
     if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
@@ -3917,13 +3931,18 @@ class Tensor(OpMixin):
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
     x, w = x.contiguous(), w.contiguous()
+    if added_weight: w, H = w[:, :-added_weight, ...], H - added_weight
     # expand out
     rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1)
     group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1)
-    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
+    # undo pitch alignment hack
+    if added_width: x = x[:, :, :-added_width, ...]
     # prepare input
     x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
     x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *group_shape, 1, 1, rcin_hi, rcin_lo, H, W)
@@ -3931,9 +3950,20 @@ class Tensor(OpMixin):
     # prepare weights
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W))
+
+    added_ox = 0
+    assert isinstance(ox, int) and isinstance(cout, int)
+    if (ox * cout) % (64 // dtsz):
+      added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox
+      ox = ox + added_ox
+      x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None)
+
     # the conv!
     ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)
+
+    if added_ox:
+      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :-added_ox, ...]
+      ox = ox - added_ox
+
     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
       ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
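The pitch-alignment hack pads the innermost image row up to a whole number of 64-byte units before the cast to an image, then slices the padding back off afterwards. A worked example for float32 (dtsz=4, so 64 bytes is 16 floats); the concrete numbers are made up for illustration:

    import math
    from tinygrad.helpers import round_up

    dtsz, ix, groups, cin = 4, 5, 1, 3   # row would be ix*groups*cin = 15 floats
    if (ix*groups*cin) % (64 // dtsz):
      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
      ix += added_width                  # 5 -> 16, so the row is 48 floats = 192 bytes
    assert (ix * groups * cin * dtsz) % 64 == 0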