diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 1d8bbc182e..764739306e 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -548,6 +548,8 @@ jobs: run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: DEBUG=2 IMAGE=1 openpilot compile3 0.10.1 driving_vision run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx + - name: IMAGE=1 openpilot compile3 0.10.1 driving_vision + run: BENCHMARK_LOG=image_1_openpilot_0_10_1_vision PYTHONPATH="." DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.10.1 driving_vision run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.10.1 driving_policy diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0400dee2b5..37df709378 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -394,7 +394,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + ALLOWED_KERNEL_COUNT=125 ALLOWED_READ_IMAGE=1389 ALLOWED_GATED_READ_IMAGE=101 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp16 run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp32 (test correctness) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 245e671db2..b21fa5b79e 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -10,25 +10,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL") @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageCopy(unittest.TestCase): - def test_image_copyout_1x1(self, img_type=dtypes.imagef): - it = Tensor.arange(4).cast(img_type((1,1,4))).realize() + def test_image_copyout_1x8(self, img_type=dtypes.imagef): + it = Tensor.arange(32).cast(img_type((1,8,4))).realize() buf = it.uop.buffer out = buf.as_buffer() - np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4)) + np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32)) @unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half") - def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh) + def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh) - def test_image_numpy_1x1(self, img_type=dtypes.imagef): - it = Tensor.arange(4).cast(img_type((1,1,4))).realize() - np.testing.assert_equal(it.numpy(), np.arange(4)) - def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh) + def test_image_numpy_1x8(self, img_type=dtypes.imagef): + it = Tensor.arange(32).cast(img_type((1,8,4))).realize() + np.testing.assert_equal(it.numpy(), np.arange(32)) + def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh) - def test_image_copyout_2x3(self): - it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize() + def test_image_copyout_2x4(self): + it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize() buf = it.uop.buffer out = buf.as_buffer() - np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4)) + np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4)) def test_image_roundtrip(self): sz = (4,2,4) @@ -105,9 +105,9 @@ class TestImageDType(unittest.TestCase): __validate(dtypes.imagef((1, 1)), 0x40) def test_image_and_back(self): - data = Tensor.randn(9*27*4).realize() + data = Tensor.randn(9*32*4).realize() tst = data.numpy() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() assert isinstance(it.uop.base.realized.dtype, ImageDType) np.testing.assert_equal(tst, it.numpy()) @@ -127,13 +127,13 @@ class TestImageDType(unittest.TestCase): np.testing.assert_equal(tst, it.numpy()) def test_shrink_load_float(self): - it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize() + it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize() imgv = it.numpy() np.testing.assert_equal(imgv[0:2], it[0:2].numpy()) def test_mul_stays_image(self): # NOTE: contiguous is needed otherwise this folds - it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize() + it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).contiguous().realize() out = (it*2).realize() assert isinstance(out.uop.base.realized.dtype, ImageDType) @@ -143,7 +143,7 @@ class TestImageDType(unittest.TestCase): np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6) def test_shrink_max(self): - it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize() + it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize() imgv = it.numpy() np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy()) @@ -162,19 +162,19 @@ class TestImageDType(unittest.TestCase): assert it.uop.base.realized._buf == b1 def test_no_lru_alloc(self): - data = Tensor.randn(9*27*4).realize() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + data = Tensor.randn(9*32*4).realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() b1 = it.uop.base.realized._buf del it - it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize() + it = data.reshape(9,32,4).pad_to(10, None, None).cast(dtypes.imagef((10,32,4))).contiguous().realize() assert it.uop.base.realized._buf != b1 def test_no_lru_alloc_dtype(self): - data = Tensor.randn(9*27*4).realize() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + data = Tensor.randn(9*32*4).realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() b1 = it.uop.base.realized._buf del it - it = data.cast(dtypes.imageh((9,27,4))).realize() + it = data.cast(dtypes.imageh((9,32,4))).realize() assert it.uop.base.realized._buf != b1 # issue caused by: don't realize image to image casts. this is part of a larger problem @@ -202,36 +202,36 @@ class TestImageDType(unittest.TestCase): @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageRealization(unittest.TestCase): def test_image_dtype_expand(self): - data = Tensor.randn(9*27*4).realize() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) - it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize() + data = Tensor.randn(9*32*4).realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) + it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous().realize() self.assertEqual(it_expanded.dtype, dtypes.float32) def test_image_dtype_expand_and_back(self): - data = Tensor.randn(9*27*4).realize() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) - it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)) + data = Tensor.randn(9*32*4).realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) + it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)) it2 = it_expanded.sum(3).realize() - self.assertEqual(it2.dtype, dtypes.imagef((9,27,4))) + self.assertEqual(it2.dtype, dtypes.imagef((9,32,4))) def test_image_alu_children(self): - data = Tensor.randn(9*27*4).realize() - it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) - it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous() + data = Tensor.randn(9*32*4).realize() + it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) + it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous() alu1 = it_expanded+1 alu2 = it_expanded.sum(3) it_expanded.realize() # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image - self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4))) + self.assertEqual(alu1.dtype, dtypes.imagef((9,32,4))) alu1.realize() self.assertEqual(alu1.dtype, dtypes.float32) # alu2 is back in image because it fits the dtype again - self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4))) alu2.realize() - self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4))) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index 4f6903debc..8fe05bb4a7 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -15,7 +15,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \ ReduceContext, correct_load_store, pm_render, pm_add_loads -from tinygrad.codegen.opt.postrange import apply_opts +from tinygrad.codegen.opt.postrange import apply_opts, make_images from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize @@ -53,6 +53,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) - # split store range (only on CPU for now) sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges") + # create image buffers + sink = make_images(sink, ren) + # do postrange optimization, BEAM or hand_coded_optimizations sink = apply_opts(sink, ren) @@ -133,7 +136,7 @@ def do_linearize(prg:UOp, sink:UOp) -> UOp: def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp: src = ctx.render(list(lin.src)) - return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),)) + return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),), arg=ctx.aux(list(lin.src)) if ctx.has_aux else prg.arg) def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None: if ctx.compiler is None: return None diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index fd86308a95..e0ea718385 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos from tinygrad.device import Buffer from tinygrad.dtype import dtypes, ImageDType from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten -from tinygrad.helpers import ALLOW_TF32, count +from tinygrad.helpers import IMAGE, ALLOW_TF32, count from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check from tinygrad.codegen.simplify import pm_flatten_range from tinygrad.renderer import Renderer @@ -349,3 +349,24 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp: if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice): k = hand_coded_optimizations(k) return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None) + +# create image buffers +def make_images(ast:UOp, ren:Renderer) -> UOp: + if IMAGE == 1 and ren.device in {"QCOM", "CL"}: + dg_types: dict = {} + def make_image(ctx, dg): + if (dt:=dg.dtype).base is dtypes.float and not isinstance(dt, ImageDType) and dt.size < 65536 and dt.nbytes() % 64 == 0: + ctx[dg.arg] = dt + return dg.replace(dtype=dtypes.imagef((1, dt.size // 4, 4), dt.nbytes())) + + ast = graph_rewrite(ast, PatternMatcher([(UPat(Ops.DEFINE_GLOBAL, name="dg"), make_image)]), ctx=dg_types, name="create image buffers") + + # undo unfoldable stores + def undo_image_store(ctx, st, idx, dg): + if dg.arg in ctx and not any(c.op is Ops.RANGE and (c.vmax+1)%4 == 0 for c in idx.src[1].get_idx().split_uop(Ops.ADD)): + return st.replace(src=(idx.replace(src=(dg.replace(dtype=ctx[dg.arg]),)+idx.src[1:]),)+st.src[1:]) + + ast = graph_rewrite(ast, PatternMatcher([ + (UPat(Ops.DEFINE_GLOBAL, name="dg").index(UPat(), name="idx").store(UPat(), name="st"), undo_image_store) + ]), ctx=dg_types, name="remove unfoldable image stores") + return ast diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 817685fb14..e539c61738 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -94,12 +94,14 @@ class PtrDType(DType): @dataclass(frozen=True, eq=False) class ImageDType(PtrDType): shape: tuple[int, ...] = () # shape of the Image + _pitch: int = -1 def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType: assert addrspace == AddrSpace.GLOBAL, "images can't be local" return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '') @property def pitch(self): + if self._pitch != -1: return self._pitch imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize)) pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6 align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2)) @@ -181,9 +183,9 @@ class dtypes: # NOTE: these are image dtypes @staticmethod - def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp) + def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) @staticmethod - def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp) + def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch) default_float: ClassVar[DType] = float32 default_int: ClassVar[DType] = int32 diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index de89a33e02..a5f6e6e3ed 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -45,7 +45,7 @@ class CompiledRunner(Runner): self.p:ProgramSpec = p assert self.p.lib is not None if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib) - self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg + self._prg = Device[p.device].runtime(p.function_name, self.p.lib, *p.aux) if prg is None else prg super().__init__(p.name, p.device, p.estimates) def __reduce__(self): return self.__class__, (self.p,) diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index a730237764..b6acf778f4 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -168,6 +168,7 @@ class ContextVar: ContextVar._cache[key] = self self.value, self.key = getenv(key, default_value), key def __bool__(self): return bool(self.value) + def __eq__(self, x): return self.value == x def __ge__(self, x): return self.value >= x def __gt__(self, x): return self.value > x def __lt__(self, x): return self.value < x diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 0d69e1df7a..428cbe7c23 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -66,6 +66,7 @@ class ProgramSpec: ast:UOp # save the base ast (this is method cache key) uops:list[UOp]|None=None lib:bytes|None=None + aux:list=field(default_factory=list) # filled in from uops (via from_uop) global_size:list[int]=field(default_factory=lambda: [1,1,1]) @@ -123,7 +124,7 @@ class ProgramSpec: # TODO: this cast is wrong, u.src[0].ssimplify() can be sint if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify()) - return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size, + return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, list(prg.arg) if prg.arg else [], global_size, local_size, sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins))) class Renderer: @@ -134,6 +135,7 @@ class Renderer: has_local: bool = True has_threads: bool = False has_shared: bool = True + has_aux: bool = False # additional program info, eg. image shapes # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now @@ -146,3 +148,4 @@ class Renderer: def __reduce__(self): return self.__class__, () def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer") + def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index fd078eb535..6106efbf15 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -286,6 +286,7 @@ class ClangJITRenderer(ClangRenderer): class OpenCLRenderer(CStyleLanguage): device = "CL" + has_aux = True # language options kernel_typedef = "__kernel void" @@ -314,6 +315,8 @@ class OpenCLRenderer(CStyleLanguage): if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or [])) return super().render_kernel(function_name, kernel, bufs, uops, prefix) + def aux(self, uops:list[UOp]): return (tuple(u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL),) + class IntelRenderer(OpenCLRenderer): device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void" tensor_cores = tc.intel diff --git a/tinygrad/runtime/ops_cl.py b/tinygrad/runtime/ops_cl.py index 64ef605467..d6f7c9c4d4 100644 --- a/tinygrad/runtime/ops_cl.py +++ b/tinygrad/runtime/ops_cl.py @@ -5,6 +5,7 @@ from tinygrad.runtime.autogen import opencl as cl from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError, CompilerPair, CompilerSet +from tinygrad.dtype import ImageDType # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something OSX_TIMING_RATIO = (125/3) if OSX else 1.0 @@ -33,8 +34,8 @@ class CLCompiler(Compiler): return bytes(binary) class CLProgram: - def __init__(self, device:CLDevice, name:str, lib:bytes): - self.dev, self.name, self.lib = device, name, lib + def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]): + self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)), to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(), errcode_ret := ctypes.c_int32()), errcode_ret) @@ -50,7 +51,12 @@ class CLProgram: def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None, vals:tuple[int, ...]=(), wait=False) -> float|None: - for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))) + for i,(b,_) in enumerate(bufs): + if isinstance(dt:=self.buf_dtypes[i], ImageDType): + fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize]) + desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b) + b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status) + check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))) for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))) if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size))) event = cl.cl_event() if wait else None @@ -66,27 +72,15 @@ class CLProgram: class CLAllocator(LRUAllocator['CLDevice']): def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]: - if options.image is not None: - return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE, - cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]), - options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options) return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options) @suppress_finalizing def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0])) def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview): - if dest[1].image is not None: - check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0), - (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None)) - else: - if mv_address(src) % 16: src = memoryview(bytearray(src)) - check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None)) + if mv_address(src) % 16: src = memoryview(bytearray(src)) + check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None)) self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command def _copyout(self, dest:memoryview, src:tuple[ctypes._CData, BufferSpec]): - if src[1].image is not None: - check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0), - (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None)) - else: - check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None)) + check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None)) self.dev.synchronize() class CLDevice(Compiled): diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index deec87f582..fee8b7eb58 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -11,6 +11,7 @@ from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.renderer.nir import IR3Renderer from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC +from tinygrad.dtype import ImageDType from tinygrad.runtime.support.system import System if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import @@ -192,8 +193,9 @@ class QCOMArgsState(HCQArgsState): super().__init__(buf, prg, bufs, vals=vals) ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size) - ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None] - ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:]) + ubos = [b for i,b in enumerate(bufs) if not isinstance(prg.buf_dtypes[i], ImageDType)] + uavs = [(i,b) for i,b in enumerate(bufs) if isinstance(prg.buf_dtypes[i], ImageDType)] + ibos, texs = uavs[:prg.ibo_cnt], uavs[prg.ibo_cnt:] for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little') if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) @@ -205,19 +207,19 @@ class QCOMArgsState(HCQArgsState): for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)]) def _tex(b, ibo=False): - fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT + fmt = mesa.FMT6_32_32_32_32_FLOAT if (img:=b[1].image or prg.buf_dtypes[b[0]]).itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt), - qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]), - qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr), + qreg.a6xx_tex_const_1(width=img.shape[1], height=img.shape[0]), + qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=img.pitch, pitchalign=ctz(img.pitch)-6), 0, *data64_le(b[1].va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0] self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off) self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off) class QCOMProgram(HCQProgram): - def __init__(self, dev: QCOMDevice, name: str, lib: bytes): + def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]): self.dev: QCOMDevice = dev - self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer) + self.buf_dtypes, self.name, self.lib, self.NIR = buf_dtypes, name, lib, isinstance(dev.renderer, IR3Renderer) if self.NIR: from tinygrad.runtime.support.compiler_mesa import IR3Compiler @@ -313,7 +315,7 @@ class QCOMTextureInfo: class QCOMAllocator(HCQAllocatorBase): def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer: # Recalculate real size for texture - if opts.image is not None: size = opts.image.pitch* opts.image.shape[0] + if opts.image is not None: size = opts.image.pitch * opts.image.shape[0] return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image) def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index fe9bba840e..6ef78c50f4 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3885,7 +3885,7 @@ class Tensor(OpMixin): return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: - base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef + base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4) (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W) @@ -3898,6 +3898,20 @@ class Tensor(OpMixin): w = w.pad_to(None, None, cin, None, None) x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix) + # hacks for pitch alignment + assert isinstance(ix, int) and isinstance(H, int) + added_width = 0 + if (ix*groups*cin) % (64 // dtsz): + added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix + ix = ix + added_width + x = x.pad_to(None, None, None, ix) + + added_weight = 0 + if (H*W*cin) % (64 // dtsz): + added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H + H = H + added_weight + w = w.pad_to(None, None, None, H, None) + # hack for non multiples of 4 on rcout added_output_channels = 0 if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0): @@ -3917,13 +3931,18 @@ class Tensor(OpMixin): if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) x, w = x.contiguous(), w.contiguous() + if added_weight: w, H = w[:, :-added_weight, ...], H - added_weight + # expand out rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1) group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1) - x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) + x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo) if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo) else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4) + # undo pitch alignment hack + if added_width: x = x[:, :, :-added_width, ...] + # prepare input x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *group_shape, 1, 1, rcin_hi, rcin_lo, H, W) @@ -3931,9 +3950,20 @@ class Tensor(OpMixin): # prepare weights w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W)) + added_ox = 0 + assert isinstance(ox, int) and isinstance(cout, int) + if (ox * cout) % (64 // dtsz): + added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox + ox = ox + added_ox + x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None) + # the conv! ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype) + if added_ox: + ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :-added_ox, ...] + ox = ox - added_ox + # undo hack for non multiples of 4 on C.rcout if added_output_channels != 0: ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]