diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 37bd6411d1..4d00ea1190 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -646,7 +646,7 @@ jobs: - name: DEBUG=2 IMAGE=1 openpilot compile3 0.10.1 driving_vision run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.10.1 driving_vision - run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=19 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx + run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx - name: openpilot compile3 0.10.1 driving_policy run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.10.1 dmonitoring diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bb4784dc52..8be43cae8b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -395,7 +395,7 @@ jobs: llvm: 'true' - name: Test openpilot model kernel count and gate usage run: | - ALLOWED_KERNEL_COUNT=125 ALLOWED_READ_IMAGE=1389 ALLOWED_GATED_READ_IMAGE=101 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 + ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp16 run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916 - name: Test openpilot CL compile fp32 (test correctness) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 6f95518a2f..da1f3aeeea 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -11,25 +11,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL") @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageCopy(unittest.TestCase): - def test_image_copyout_1x8(self, img_type=dtypes.imagef): - it = Tensor.arange(32).cast(img_type((1,8,4))).realize() + def test_image_copyout_1x1(self, img_type=dtypes.imagef): + it = Tensor.arange(4).cast(img_type((1,1,4))).realize() buf = it.uop.buffer out = buf.as_buffer() - np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32)) + np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4)) @unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half") - def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh) + def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh) - def test_image_numpy_1x8(self, img_type=dtypes.imagef): - it = Tensor.arange(32).cast(img_type((1,8,4))).realize() - np.testing.assert_equal(it.numpy(), np.arange(32)) - def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh) + def test_image_numpy_1x1(self, img_type=dtypes.imagef): + it = Tensor.arange(4).cast(img_type((1,1,4))).realize() + np.testing.assert_equal(it.numpy(), np.arange(4)) + def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh) - def test_image_copyout_2x4(self): - it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize() + def test_image_copyout_2x3(self): + it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize() buf = it.uop.buffer out = buf.as_buffer() - np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4)) + np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4)) def test_image_roundtrip(self): sz = (4,2,4) @@ -46,9 +46,9 @@ class TestImageCopy(unittest.TestCase): @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageDType(unittest.TestCase): def test_image_and_back(self): - data = Tensor.randn(9*32*4).realize() + data = Tensor.randn(9*27*4).realize() tst = data.numpy() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() assert isinstance(it.uop.base.realized.dtype, ImageDType) np.testing.assert_equal(tst, it.numpy()) @@ -68,13 +68,13 @@ class TestImageDType(unittest.TestCase): np.testing.assert_equal(tst, it.numpy()) def test_shrink_load_float(self): - it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize() + it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize() imgv = it.numpy() np.testing.assert_equal(imgv[0:2], it[0:2].numpy()) def test_mul_stays_image(self): # NOTE: contiguous is needed otherwise this folds - it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).contiguous().realize() + it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize() out = (it*2).realize() assert isinstance(out.uop.base.realized.dtype, ImageDType) @@ -84,7 +84,7 @@ class TestImageDType(unittest.TestCase): np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6) def test_shrink_max(self): - it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize() + it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize() imgv = it.numpy() np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy()) @@ -103,19 +103,19 @@ class TestImageDType(unittest.TestCase): assert it.uop.base.realized._buf == b1 def test_no_lru_alloc(self): - data = Tensor.randn(9*32*4).realize() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() b1 = it.uop.base.realized._buf del it - it = data.reshape(9,32,4).pad_to(10, None, None).cast(dtypes.imagef((10,32,4))).contiguous().realize() + it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize() assert it.uop.base.realized._buf != b1 def test_no_lru_alloc_dtype(self): - data = Tensor.randn(9*32*4).realize() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() b1 = it.uop.base.realized._buf del it - it = data.cast(dtypes.imageh((9,32,4))).realize() + it = data.cast(dtypes.imageh((9,27,4))).realize() assert it.uop.base.realized._buf != b1 # issue caused by: don't realize image to image casts. this is part of a larger problem @@ -143,36 +143,36 @@ class TestImageDType(unittest.TestCase): @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageRealization(unittest.TestCase): def test_image_dtype_expand(self): - data = Tensor.randn(9*32*4).realize() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) - it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous().realize() + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize() self.assertEqual(it_expanded.dtype, dtypes.float32) def test_image_dtype_expand_and_back(self): - data = Tensor.randn(9*32*4).realize() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) - it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)) + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)) it2 = it_expanded.sum(3).realize() - self.assertEqual(it2.dtype, dtypes.imagef((9,32,4))) + self.assertEqual(it2.dtype, dtypes.imagef((9,27,4))) def test_image_alu_children(self): - data = Tensor.randn(9*32*4).realize() - it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize() - self.assertEqual(it.dtype, dtypes.imagef((9,32,4))) - it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous() + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous() alu1 = it_expanded+1 alu2 = it_expanded.sum(3) it_expanded.realize() # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image - self.assertEqual(alu1.dtype, dtypes.imagef((9,32,4))) + self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4))) alu1.realize() self.assertEqual(alu1.dtype, dtypes.float32) # alu2 is back in image because it fits the dtype again - self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4))) + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) alu2.realize() - self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4))) + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) if __name__ == '__main__': unittest.main() diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 9e7b52d78d..4baebae6b7 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -24,9 +24,8 @@ def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None): initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE] inbufs = [x.uop.base.buffer for x in inputs] src = Device[Device.DEFAULT].renderer.render(uops) - aux = Device[Device.DEFAULT].renderer.aux(uops) if Device[Device.DEFAULT].renderer.has_aux else {} ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", - src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size, aux=aux)) + src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size)) ei.exec(outbufs+inbufs) return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs] diff --git a/test/test_uops.py b/test/test_uops.py index baaf646087..a92939d7c0 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -27,10 +27,9 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]: def _uops_to_prg(uops_list): uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer) src = Device[Device.DEFAULT].renderer.render(uops) - aux = Device[Device.DEFAULT].renderer.aux(uops) if Device[Device.DEFAULT].renderer.has_aux else {} has_local = Device[Device.DEFAULT].renderer.has_local return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops, - global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None, aux=aux)) + global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None)) def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp: uops.append(UOp(uop, dtype, tuple(src), arg)) diff --git a/tinygrad/device.py b/tinygrad/device.py index 322567801e..f8e5b5e915 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK -from tinygrad.dtype import ImageDType, PtrDType, DType, dtypes, _to_np_dtype +from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer # **************** Device **************** @@ -70,6 +70,7 @@ class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[l @dataclass(frozen=True, eq=True) class BufferSpec: # TODO: move device, size, dtype here? + image: ImageDType|None = None uncached: bool = False cpu_access: bool = False host: bool = False @@ -93,7 +94,8 @@ class Buffer: profile_events:list[ProfileEvent] = [] def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:BufferSpec|None=None, initial_value:bytes|None=None, uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False): - assert isinstance(dtype, DType) and (isinstance(dtype, ImageDType) or not isinstance(dtype, PtrDType)) + if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be? + else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType) self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0 if base is None: assert offset == 0, "base buffers can't have offset" @@ -171,7 +173,7 @@ class Buffer: return self.allocator._as_dmaref(self._buf) def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview: # zero copy with as_buffer (disabled by default due to use after free) - if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and self.options is None: + if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None): return self.allocator._as_buffer(self._buf) assert not force_zero_copy, "force zero copy was passed, but copy is required" return self.copyout(memoryview(bytearray(self.nbytes))) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 016bce48dd..a2edd2ad62 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -48,7 +48,7 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops, global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None, - local_size=[1,1,1] if renderer.has_local else None, aux=renderer.aux(uops) if renderer.has_aux else {}) + local_size=[1,1,1] if renderer.has_local else None) # **************** Runners **************** @@ -86,7 +86,7 @@ class CompiledRunner(Runner): with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"): self.lib = Device[p.device].compiler.compile_cached(p.src) if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib) - self._prg = Device[p.device].runtime(p.function_name, self.lib, **p.aux) if prg is None else prg + self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg super().__init__(p.name, p.device, p.estimates) def __reduce__(self): return self.__class__, (self.p, self.lib) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 71a588b594..c63dbff3df 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -64,7 +64,6 @@ class ProgramSpec: device:str ast:UOp # save the base ast (this is method cache key) uops:list[UOp]|None=None - aux:dict=field(default_factory=dict) # filled in from uops (if we have uops) global_size:list[int]|None=None @@ -122,7 +121,6 @@ class Renderer: has_local: bool = True has_threads: bool = False has_shared: bool = True - has_aux: bool = False # additional program info, eg. image shapes # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now @@ -134,4 +132,3 @@ class Renderer: def __reduce__(self): return self.__class__, () def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer") - def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index e68bdb8d8d..cdd1b3c89b 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -280,7 +280,6 @@ class ClangRenderer(CStyleLanguage): class OpenCLRenderer(CStyleLanguage): device = "CL" - has_aux = True # language options kernel_typedef = "__kernel void" @@ -309,8 +308,6 @@ class OpenCLRenderer(CStyleLanguage): if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or [])) return super().render_kernel(function_name, kernel, bufs, uops, prefix) - def aux(self, uops:list[UOp]): return {"buf_dtypes": [u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL]} - class IntelRenderer(OpenCLRenderer): device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void" tensor_cores = tc.intel diff --git a/tinygrad/runtime/ops_cl.py b/tinygrad/runtime/ops_cl.py index d6f7c9c4d4..64ef605467 100644 --- a/tinygrad/runtime/ops_cl.py +++ b/tinygrad/runtime/ops_cl.py @@ -5,7 +5,6 @@ from tinygrad.runtime.autogen import opencl as cl from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError, CompilerPair, CompilerSet -from tinygrad.dtype import ImageDType # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something OSX_TIMING_RATIO = (125/3) if OSX else 1.0 @@ -34,8 +33,8 @@ class CLCompiler(Compiler): return bytes(binary) class CLProgram: - def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]): - self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes + def __init__(self, device:CLDevice, name:str, lib:bytes): + self.dev, self.name, self.lib = device, name, lib self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)), to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(), errcode_ret := ctypes.c_int32()), errcode_ret) @@ -51,12 +50,7 @@ class CLProgram: def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None, vals:tuple[int, ...]=(), wait=False) -> float|None: - for i,(b,_) in enumerate(bufs): - if isinstance(dt:=self.buf_dtypes[i], ImageDType): - fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize]) - desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b) - b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status) - check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))) + for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b))) for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v)))) if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size))) event = cl.cl_event() if wait else None @@ -72,15 +66,27 @@ class CLProgram: class CLAllocator(LRUAllocator['CLDevice']): def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]: + if options.image is not None: + return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE, + cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]), + options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options) return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options) @suppress_finalizing def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0])) def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview): - if mv_address(src) % 16: src = memoryview(bytearray(src)) - check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None)) + if dest[1].image is not None: + check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0), + (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None)) + else: + if mv_address(src) % 16: src = memoryview(bytearray(src)) + check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None)) self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command def _copyout(self, dest:memoryview, src:tuple[ctypes._CData, BufferSpec]): - check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None)) + if src[1].image is not None: + check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0), + (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None)) + else: + check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None)) self.dev.synchronize() class CLDevice(Compiled): diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 03b06ecfae..b1284bbd1e 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -13,7 +13,6 @@ from tinygrad.renderer.nir import IR3Renderer from tinygrad.runtime.support.compiler_mesa import IR3Compiler from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC -from tinygrad.dtype import ImageDType from tinygrad.runtime.support.system import System if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import @@ -29,7 +28,6 @@ def _qreg_exec(__reg, __val=0, **kwargs): qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'}) def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length() -def ctz(v): return (v & -v).bit_length() - 1 def parity(val: int): for i in range(4,1,-1): val ^= val >> (1 << i) @@ -204,8 +202,8 @@ class QCOMArgsState(HCQArgsState): if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) for i, b in enumerate(bufs): - if (ti:=prg.tex_infos[i]) is not None: - obj = ti.desc if prg.buf_info[i].type is BUFTYPE_TEX else ti.ibo + if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}: + obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj) self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16)) @@ -227,24 +225,12 @@ class IR3ArgsState(HCQArgsState): self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off) class QCOMProgram(HCQProgram): - def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]): - self.tex_infos:list[QCOMTextureInfo|None] = [] - for dtype in buf_dtypes: - if isinstance(dtype, ImageDType): - imgw, imgh = dtype.shape[1], dtype.shape[0] - stride = imgw * 4 * dtype.itemsize - assert stride % 64 == 0 - tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if dtype.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT - desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh), - qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=stride, pitchalign=ctz(stride)-6), 0, 0, 0, - qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)] - self.tex_infos.append(QCOMTextureInfo(stride, stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])) - else: self.tex_infos.append(None) - + def __init__(self, dev: QCOMDevice, name: str, lib: bytes): self.dev: QCOMDevice = dev self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler) if self.NIR: + from tinygrad.runtime.autogen import mesa v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib) self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc @@ -338,17 +324,43 @@ class QCOMTextureInfo: class QCOMAllocator(HCQAllocatorBase): def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: - return self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size) + # Recalculate real size for texture + if options.image is not None: + imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize)) + pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6 + align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2)) - def _do_copy(self, src_addr, dest_addr, src_size, prof_text): - with cpu_profile(prof_text, self.dev.device, is_copy=True): ctypes.memmove(dest_addr, src_addr, src_size) + granularity = 128 if options.image.itemsize == 4 else 256 + pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0 + pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add + size = pitch * imgh - def _copyin(self, dest:HCQBuffer, src:memoryview): self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, f"TINY -> {self.dev.device}") + buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size) + + if options.image is not None: + tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT + desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh), + qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0, + *data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)] + + buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]]) + return buf + + def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0): + with cpu_profile(prof_text, self.dev.device, is_copy=True): + while src_off < src_size: + ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size) + src_off, dest_off = src_off+src_stride, dest_off+dest_stride + + def _copyin(self, dest:HCQBuffer, src:memoryview): + stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch) + self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}") def _copyout(self, dest:memoryview, src:HCQBuffer): self.dev.synchronize() - self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, f"{self.dev.device} -> TINY") + stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch) + self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY") def _as_buffer(self, src:HCQBuffer) -> memoryview: self.dev.synchronize() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 07ae071559..369675f4f7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3879,7 +3879,7 @@ class Tensor(OpMixin): return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: - base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4) + base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W) @@ -3892,20 +3892,6 @@ class Tensor(OpMixin): w = w.pad_to(None, None, cin, None, None) x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix) - # hacks for pitch alignment - assert isinstance(ix, int) and isinstance(H, int) - added_width = 0 - if (ix*groups*cin) % (64 // dtsz): - added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix - ix = ix + added_width - x = x.pad_to(None, None, None, ix) - - added_weight = 0 - if (H*W*cin) % (64 // dtsz): - added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H - H = H + added_weight - w = w.pad_to(None, None, None, H, None) - # hack for non multiples of 4 on rcout added_output_channels = 0 if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0): @@ -3925,18 +3911,13 @@ class Tensor(OpMixin): if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4))) x, w = x.contiguous(), w.contiguous() - if added_weight: w, H = w[:, :-added_weight, ...], H - added_weight - # expand out rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1) group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1) - x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo) + x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo) if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo) else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4) - # undo pitch alignment hack - if added_width: x = x[:, :, :-added_width, ...] - # prepare input x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *group_shape, 1, 1, rcin_hi, rcin_lo, H, W) @@ -3944,20 +3925,9 @@ class Tensor(OpMixin): # prepare weights w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W)) - added_ox = 0 - assert isinstance(ox, int) and isinstance(cout, int) - if (ox * cout) % (64 // dtsz): - added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox - ox = ox + added_ox - x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None) - # the conv! ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype) - if added_ox: - ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :-added_ox, ...] - ox = ox - added_ox - # undo hack for non multiples of 4 on C.rcout if added_output_channels != 0: ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]