Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00

commit 97103831c5 (parent 2571a1eb47), committed by GitHub

This reverts commit 2571a1eb47.
.github/workflows/benchmark.yml (vendored, 2 changes)
@@ -646,7 +646,7 @@ jobs:
     - name: DEBUG=2 IMAGE=1 openpilot compile3 0.10.1 driving_vision
       run: PYTHONPATH="." DEBUG=2 DEV=QCOM FLOAT16=1 IMAGE=1 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
     - name: openpilot compile3 0.10.1 driving_vision
-      run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=19 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
+      run: BENCHMARK_LOG=openpilot_0_10_1_vision PYTHONPATH="." ASSERT_MIN_STEP_TIME=17 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_vision.onnx
     - name: openpilot compile3 0.10.1 driving_policy
       run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=4 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
     - name: openpilot compile3 0.10.1 dmonitoring
.github/workflows/test.yml (vendored, 2 changes)
@@ -395,7 +395,7 @@ jobs:
         llvm: 'true'
     - name: Test openpilot model kernel count and gate usage
       run: |
-        ALLOWED_KERNEL_COUNT=125 ALLOWED_READ_IMAGE=1389 ALLOWED_GATED_READ_IMAGE=101 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
+        ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
     - name: Test openpilot CL compile fp16
       run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
     - name: Test openpilot CL compile fp32 (test correctness)
@@ -11,25 +11,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
 
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageCopy(unittest.TestCase):
-  def test_image_copyout_1x8(self, img_type=dtypes.imagef):
-    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
+  def test_image_copyout_1x1(self, img_type=dtypes.imagef):
+    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32))
+    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half")
-  def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh)
+  def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh)
 
-  def test_image_numpy_1x8(self, img_type=dtypes.imagef):
-    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
-    np.testing.assert_equal(it.numpy(), np.arange(32))
-  def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh)
+  def test_image_numpy_1x1(self, img_type=dtypes.imagef):
+    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
+    np.testing.assert_equal(it.numpy(), np.arange(4))
+  def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh)
 
-  def test_image_copyout_2x4(self):
-    it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize()
+  def test_image_copyout_2x3(self):
+    it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4))
+    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
 
   def test_image_roundtrip(self):
     sz = (4,2,4)
@@ -46,9 +46,9 @@ class TestImageCopy(unittest.TestCase):
 
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageDType(unittest.TestCase):
   def test_image_and_back(self):
-    data = Tensor.randn(9*32*4).realize()
+    data = Tensor.randn(9*27*4).realize()
     tst = data.numpy()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
     assert isinstance(it.uop.base.realized.dtype, ImageDType)
     np.testing.assert_equal(tst, it.numpy())
@@ -68,13 +68,13 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_equal(tst, it.numpy())
 
   def test_shrink_load_float(self):
-    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
+    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(imgv[0:2], it[0:2].numpy())
 
   def test_mul_stays_image(self):
     # NOTE: contiguous is needed otherwise this folds
-    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).contiguous().realize()
+    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
     out = (it*2).realize()
     assert isinstance(out.uop.base.realized.dtype, ImageDType)
@@ -84,7 +84,7 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6)
 
   def test_shrink_max(self):
-    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
+    it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
@@ -103,19 +103,19 @@ class TestImageDType(unittest.TestCase):
     assert it.uop.base.realized._buf == b1
 
   def test_no_lru_alloc(self):
-    data = Tensor.randn(9*32*4).realize()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    data = Tensor.randn(9*27*4).realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.reshape(9,32,4).pad_to(10, None, None).cast(dtypes.imagef((10,32,4))).contiguous().realize()
+    it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
     assert it.uop.base.realized._buf != b1
 
   def test_no_lru_alloc_dtype(self):
-    data = Tensor.randn(9*32*4).realize()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    data = Tensor.randn(9*27*4).realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.cast(dtypes.imageh((9,32,4))).realize()
+    it = data.cast(dtypes.imageh((9,27,4))).realize()
     assert it.uop.base.realized._buf != b1
 
   # issue caused by: don't realize image to image casts. this is part of a larger problem
@@ -143,36 +143,36 @@ class TestImageDType(unittest.TestCase):
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageRealization(unittest.TestCase):
   def test_image_dtype_expand(self):
-    data = Tensor.randn(9*32*4).realize()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
-    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous().realize()
+    data = Tensor.randn(9*27*4).realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
+    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
     self.assertEqual(it_expanded.dtype, dtypes.float32)
 
   def test_image_dtype_expand_and_back(self):
-    data = Tensor.randn(9*32*4).realize()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
-    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4))
+    data = Tensor.randn(9*27*4).realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
+    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
     it2 = it_expanded.sum(3).realize()
-    self.assertEqual(it2.dtype, dtypes.imagef((9,32,4)))
+    self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))
 
   def test_image_alu_children(self):
-    data = Tensor.randn(9*32*4).realize()
-    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
-    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous()
+    data = Tensor.randn(9*27*4).realize()
+    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
+    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
     alu1 = it_expanded+1
     alu2 = it_expanded.sum(3)
     it_expanded.realize()
     # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
-    self.assertEqual(alu1.dtype, dtypes.imagef((9,32,4)))
+    self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
     alu1.realize()
     self.assertEqual(alu1.dtype, dtypes.float32)
     # alu2 is back in image because it fits the dtype again
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
     alu2.realize()
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
 
 if __name__ == '__main__':
   unittest.main()
@@ -24,9 +24,8 @@ def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
              initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
   inbufs = [x.uop.base.buffer for x in inputs]
   src = Device[Device.DEFAULT].renderer.render(uops)
-  aux = Device[Device.DEFAULT].renderer.aux(uops) if Device[Device.DEFAULT].renderer.has_aux else {}
   ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
-                                  src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size, aux=aux))
+                                  src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
   ei.exec(outbufs+inbufs)
   return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
@@ -27,10 +27,9 @@ def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
 def _uops_to_prg(uops_list):
   uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer)
   src = Device[Device.DEFAULT].renderer.render(uops)
-  aux = Device[Device.DEFAULT].renderer.aux(uops) if Device[Device.DEFAULT].renderer.has_aux else {}
   has_local = Device[Device.DEFAULT].renderer.has_local
   return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
-                                    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None, aux=aux))
+                                    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
 
 def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))
@@ -6,7 +6,7 @@ import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re
 from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored
 from tinygrad.helpers import Context, CCACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup, ContextVar
 from tinygrad.helpers import unwrap_class_type, suppress_finalizing, select_first_inited, VIZ, CPU_LLVM, CPU_LVP, NV_PTX, CUDA_PTX, NV_NAK
-from tinygrad.dtype import ImageDType, PtrDType, DType, dtypes, _to_np_dtype
+from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer
 
 # **************** Device ****************
@@ -70,6 +70,7 @@ class ProfileGraphEvent(ProfileEvent): ents:list[ProfileGraphEntry]; deps:list[l
 @dataclass(frozen=True, eq=True)
 class BufferSpec:
   # TODO: move device, size, dtype here?
+  image: ImageDType|None = None
   uncached: bool = False
   cpu_access: bool = False
   host: bool = False
@@ -93,7 +94,8 @@ class Buffer:
   profile_events:list[ProfileEvent] = []
   def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options:BufferSpec|None=None, initial_value:bytes|None=None,
                uop_refcount=0, base:Buffer|None=None, offset:int=0, preallocate=False):
-    assert isinstance(dtype, DType) and (isinstance(dtype, ImageDType) or not isinstance(dtype, PtrDType))
+    if isinstance(dtype, ImageDType): options = BufferSpec(image=dtype) # TODO: image hack shouldn't be here. where should it be?
+    else: assert isinstance(dtype, DType) and not isinstance(dtype, PtrDType)
     self.device, self.size, self.dtype, self.options, self.offset, self.allocated_views = device, size, dtype, options, offset, 0
     if base is None:
       assert offset == 0, "base buffers can't have offset"
@@ -171,7 +173,7 @@ class Buffer:
     return self.allocator._as_dmaref(self._buf)
   def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
     # zero copy with as_buffer (disabled by default due to use after free)
-    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and self.options is None:
+    if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, '_as_buffer') and (self.options is None or self.options.image is None):
       return self.allocator._as_buffer(self._buf)
     assert not force_zero_copy, "force zero copy was passed, but copy is required"
     return self.copyout(memoryview(bytearray(self.nbytes)))
@@ -48,7 +48,7 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
 
   return ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, renderer.device, ast, uops,
                      global_size=[1,1,1] if renderer.has_local or renderer.has_threads else None,
-                     local_size=[1,1,1] if renderer.has_local else None, aux=renderer.aux(uops) if renderer.has_aux else {})
+                     local_size=[1,1,1] if renderer.has_local else None)
 
 # **************** Runners ****************
@@ -86,7 +86,7 @@ class CompiledRunner(Runner):
       with cpu_profile(TracingKey(f"compile {p.name}", (p.function_name,)), "TINY"):
         self.lib = Device[p.device].compiler.compile_cached(p.src)
       if DEBUG >= 7: Device[p.device].compiler.disassemble(self.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.lib, **p.aux) if prg is None else prg
+    self._prg = Device[p.device].runtime(p.function_name, self.lib) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)
 
   def __reduce__(self): return self.__class__, (self.p, self.lib)
@@ -64,7 +64,6 @@ class ProgramSpec:
   device:str
   ast:UOp # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
-  aux:dict=field(default_factory=dict)
 
   # filled in from uops (if we have uops)
   global_size:list[int]|None=None
@@ -122,7 +121,6 @@
   has_local: bool = True
   has_threads: bool = False
   has_shared: bool = True
-  has_aux: bool = False # additional program info, eg. image shapes
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
   local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3) # TODO: Ops.SPECIAL int32 indexes right now
@@ -134,4 +132,3 @@
 
   def __reduce__(self): return self.__class__, ()
   def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
-  def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux")
@@ -280,7 +280,6 @@ class ClangRenderer(CStyleLanguage):
 
 class OpenCLRenderer(CStyleLanguage):
   device = "CL"
-  has_aux = True
 
   # language options
   kernel_typedef = "__kernel void"
@@ -309,8 +308,6 @@ class OpenCLRenderer(CStyleLanguage):
     if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
-  def aux(self, uops:list[UOp]): return {"buf_dtypes": [u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL]}
-
 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
   tensor_cores = tc.intel
@@ -5,7 +5,6 @@ from tinygrad.runtime.autogen import opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing
 from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
 from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError, CompilerPair, CompilerSet
-from tinygrad.dtype import ImageDType
 
 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -34,8 +33,8 @@ class CLCompiler(Compiler):
     return bytes(binary)
 
 class CLProgram:
-  def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]):
-    self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes
+  def __init__(self, device:CLDevice, name:str, lib:bytes):
+    self.dev, self.name, self.lib = device, name, lib
     self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)),
                                                         to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(),
                                                         errcode_ret := ctypes.c_int32()), errcode_ret)
@@ -51,12 +50,7 @@ class CLProgram:
 
   def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None,
                vals:tuple[int, ...]=(), wait=False) -> float|None:
-    for i,(b,_) in enumerate(bufs):
-      if isinstance(dt:=self.buf_dtypes[i], ImageDType):
-        fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
-        desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b)
-        b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
-      check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
+    for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
     event = cl.cl_event() if wait else None
@@ -72,15 +66,27 @@ class CLProgram:
 
 class CLAllocator(LRUAllocator['CLDevice']):
   def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]:
+    if options.image is not None:
+      return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,
+                                         cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
+                                         options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
     return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
   @suppress_finalizing
   def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
   def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview):
-    if mv_address(src) % 16: src = memoryview(bytearray(src))
-    check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
+    if dest[1].image is not None:
+      check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
+                                   (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
+    else:
+      if mv_address(src) % 16: src = memoryview(bytearray(src))
+      check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
     self.dev.pending_copyin.append(src) # NOTE: these can't be freed until the GPU actually executes this command
   def _copyout(self, dest:memoryview, src:tuple[ctypes._CData, BufferSpec]):
-    check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
+    if src[1].image is not None:
+      check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
+                                  (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
+    else:
+      check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
     self.dev.synchronize()
 
 class CLDevice(Compiled):
@@ -13,7 +13,6 @@ from tinygrad.renderer.nir import IR3Renderer
 from tinygrad.runtime.support.compiler_mesa import IR3Compiler
 from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
 from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
-from tinygrad.dtype import ImageDType
 from tinygrad.runtime.support.system import System
 
 if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -29,7 +28,6 @@ def _qreg_exec(__reg, __val=0, **kwargs):
 qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'})
 
 def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
-def ctz(v): return (v & -v).bit_length() - 1
 
 def parity(val: int):
   for i in range(4,1,-1): val ^= val >> (1 << i)
@@ -204,8 +202,8 @@ class QCOMArgsState(HCQArgsState):
 
     if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
     for i, b in enumerate(bufs):
-      if (ti:=prg.tex_infos[i]) is not None:
-        obj = ti.desc if prg.buf_info[i].type is BUFTYPE_TEX else ti.ibo
+      if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
+        obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo
         to_mv(self.buf.va_addr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
       self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=self.buf_info[i].offset+(0 if self.buf_info[i].type is BUFTYPE_BUF else 16))
@@ -227,24 +225,12 @@ class IR3ArgsState(HCQArgsState):
     self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off)
 
 class QCOMProgram(HCQProgram):
-  def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]):
-    self.tex_infos:list[QCOMTextureInfo|None] = []
-    for dtype in buf_dtypes:
-      if isinstance(dtype, ImageDType):
-        imgw, imgh = dtype.shape[1], dtype.shape[0]
-        stride = imgw * 4 * dtype.itemsize
-        assert stride % 64 == 0
-        tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if dtype.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
-        desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
-                qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=stride, pitchalign=ctz(stride)-6), 0, 0, 0,
-                qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
-        self.tex_infos.append(QCOMTextureInfo(stride, stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]]))
-      else: self.tex_infos.append(None)
-
+  def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
     self.dev: QCOMDevice = dev
     self.name, self.lib, self.NIR = name, lib, isinstance(dev.compiler, IR3Compiler)
 
     if self.NIR:
       from tinygrad.runtime.autogen import mesa
       v, cs, self.imm_vals, self.image = IR3Compiler.unpack_lib(lib)
       self.prg_offset, self.brnchstck, self.image_size, self.pvtmem, self.shmem = 0, v.branchstack, v.info.size, v.pvtmem_size, v.shared_size
       self.wgsz = alloc.offset_vec4 * 4 + 8 if (alloc:=cs.allocs.consts[mesa.IR3_CONST_ALLOC_DRIVER_PARAMS]).size_vec4 else 0xfc
@@ -338,17 +324,43 @@ class QCOMTextureInfo:
 
 class QCOMAllocator(HCQAllocatorBase):
   def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
-    return self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
-
-  def _do_copy(self, src_addr, dest_addr, src_size, prof_text):
-    with cpu_profile(prof_text, self.dev.device, is_copy=True): ctypes.memmove(dest_addr, src_addr, src_size)
-
-  def _copyin(self, dest:HCQBuffer, src:memoryview): self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, f"TINY -> {self.dev.device}")
+    # Recalculate real size for texture
+    if options.image is not None:
+      imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
+      pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
+      align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
+
+      granularity = 128 if options.image.itemsize == 4 else 256
+      pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
+      pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
+      size = pitch * imgh
+
+    buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
+
+    if options.image is not None:
+      tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
+      desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
+              qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
+              *data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
+
+      buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
+    return buf
+
+  def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
+    with cpu_profile(prof_text, self.dev.device, is_copy=True):
+      while src_off < src_size:
+        ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size)
+        src_off, dest_off = src_off+src_stride, dest_off+dest_stride
+
+  def _copyin(self, dest:HCQBuffer, src:memoryview):
+    stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
+    self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}")
 
   def _copyout(self, dest:memoryview, src:HCQBuffer):
     self.dev.synchronize()
-    self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, f"{self.dev.device} -> TINY")
+    stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
+    self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY")
 
   def _as_buffer(self, src:HCQBuffer) -> memoryview:
    self.dev.synchronize()
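For reference, the pitch arithmetic in the QCOMAllocator._alloc hunk above can be read as a standalone sketch. The qcom_image_pitch wrapper and the (9,27,4) example values are illustrative additions, not part of this commit; round_up and next_power2 mirror the helpers used in the hunk.

import math

def round_up(num, amt): return (num + amt - 1) // amt * amt            # same shape as tinygrad.helpers.round_up
def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()   # as defined in the ops_qcom hunk above

def qcom_image_pitch(imgh, imgw, itemsize):
  # a row holds imgw 4-channel texels at itemsize bytes per channel; the hardware
  # wants each row padded out to an aligned pitch, so the backing buffer is pitch*imgh bytes
  itemsize_log = int(math.log2(itemsize))
  pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
  align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
  granularity = 128 if itemsize == 4 else 256
  pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
  real_stride = imgw * 4 * itemsize
  pitch = round_up(real_stride, 1 << pitchalign) + pitch_add
  return pitch, real_stride, pitch * imgh

# e.g. an imagef((9,27,4)) like the tests above use: 432 data bytes per row, padded to a 512-byte pitch
print(qcom_image_pitch(9, 27, 4))  # -> (512, 432, 4608)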
@@ -3879,7 +3879,7 @@ class Tensor(OpMixin):
     return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
 
   def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
-    base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4)
+    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
 
     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
@@ -3892,20 +3892,6 @@ class Tensor(OpMixin):
       w = w.pad_to(None, None, cin, None, None)
       x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)
 
-    # hacks for pitch alignment
-    assert isinstance(ix, int) and isinstance(H, int)
-    added_width = 0
-    if (ix*groups*cin) % (64 // dtsz):
-      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
-      ix = ix + added_width
-      x = x.pad_to(None, None, None, ix)
-
-    added_weight = 0
-    if (H*W*cin) % (64 // dtsz):
-      added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H
-      H = H + added_weight
-      w = w.pad_to(None, None, None, H, None)
-
     # hack for non multiples of 4 on rcout
     added_output_channels = 0
     if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
@@ -3925,18 +3911,13 @@ class Tensor(OpMixin):
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
     x, w = x.contiguous(), w.contiguous()
 
-    if added_weight: w, H = w[:, :-added_weight, ...], H - added_weight
-
     # expand out
     rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1)
     group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1)
-    x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
+    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)
 
-    # undo pitch alignment hack
-    if added_width: x = x[:, :, :-added_width, ...]
-
     # prepare input
     x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)# -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
     x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *group_shape, 1, 1, rcin_hi, rcin_lo, H, W)
@@ -3944,20 +3925,9 @@ class Tensor(OpMixin):
     # prepare weights
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W))
 
-    added_ox = 0
-    assert isinstance(ox, int) and isinstance(cout, int)
-    if (ox * cout) % (64 // dtsz):
-      added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox
-      ox = ox + added_ox
-      x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None)
-
     # the conv!
     ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)
 
-    if added_ox:
-      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :-added_ox, ...]
-      ox = ox - added_ox
-
     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
       ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]
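For context, the width-padding hack that the image_conv2d hunks above take out (the "-" lines) amounts to the following arithmetic. This is a rough sketch; the padded_width helper and the example numbers are illustrative, not from the tree.

import math

def round_up(num, amt): return (num + amt - 1) // amt * amt  # same shape as tinygrad.helpers.round_up

def padded_width(ix, groups, cin, dtsz):
  # pad the image width so ix*groups*cin is a multiple of 64//dtsz elements,
  # i.e. each image row is a multiple of 64 bytes at dtsz bytes per channel
  if (ix * groups * cin) % (64 // dtsz) == 0: return ix, 0
  added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
  return ix + added_width, added_width

# e.g. FLOAT16 (dtsz=2), groups*cin=6, ix=100: 100*6=600 is not a multiple of 32,
# so the width is padded from 100 to 112 (112*6=672, which is a multiple of 32)
print(padded_width(100, 1, 6, 2))  # -> (112, 12)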