IMAGE=1 creates "dynamic" images (#13769)
* remove image from BufferSpec
* cl tiny_gemm (64) works
* mypy
* padding
* openpilot CL
* reshape properly
* remove extra qcom checks
* pad output
* mypy
* update compile test
* move undo
* TestImageCopy valid images
* TestImageRealization valid images
* TestImageDType valid images
* cleanups
* test_renderer_failures
* ruff
* mypy
* simplify ops_qcom
* bump step time
* Revert "bump step time"
This reverts commit 75a037c7d0.
* "dynamic textures" are optional
* a start
* IMAGE=1 works, no FLOAT16
* fast but wrong
* mypy
* some fixes
* better
* works
* refactor
* oops
parent 61dc70f1a8
commit 9dc524536f

.github/workflows/test.yml
@@ -394,7 +394,7 @@ jobs:
       llvm: 'true'
   - name: Test openpilot model kernel count and gate usage
     run: |
-      ALLOWED_KERNEL_COUNT=123 ALLOWED_READ_IMAGE=1397 ALLOWED_GATED_READ_IMAGE=94 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
+      ALLOWED_KERNEL_COUNT=125 ALLOWED_READ_IMAGE=1389 ALLOWED_GATED_READ_IMAGE=101 FLOAT16=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
   - name: Test openpilot CL compile fp16
     run: FLOAT16=1 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/cf6376aa9a090f0da26c280ef69eabf9bbdd51d1faac9ed392919c3db69be916
   - name: Test openpilot CL compile fp32 (test correctness)
@@ -10,25 +10,25 @@ IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")

 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageCopy(unittest.TestCase):
-  def test_image_copyout_1x1(self, img_type=dtypes.imagef):
-    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
+  def test_image_copyout_1x8(self, img_type=dtypes.imagef):
+    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(4))
+    np.testing.assert_equal(out.cast(it.dtype.fmt).tolist(), np.arange(32))

   @unittest.skipUnless(is_dtype_supported(dtypes.half, device="PYTHON"), "need half")
-  def test_imageh_copyout_1x1(self): self.test_image_copyout_1x1(img_type=dtypes.imageh)
+  def test_imageh_copyout_1x8(self): self.test_image_copyout_1x8(img_type=dtypes.imageh)

-  def test_image_numpy_1x1(self, img_type=dtypes.imagef):
-    it = Tensor.arange(4).cast(img_type((1,1,4))).realize()
-    np.testing.assert_equal(it.numpy(), np.arange(4))
-  def test_imageh_numpy_1x1(self): self.test_image_numpy_1x1(img_type=dtypes.imageh)
+  def test_image_numpy_1x8(self, img_type=dtypes.imagef):
+    it = Tensor.arange(32).cast(img_type((1,8,4))).realize()
+    np.testing.assert_equal(it.numpy(), np.arange(32))
+  def test_imageh_numpy_1x8(self): self.test_image_numpy_1x8(img_type=dtypes.imageh)

-  def test_image_copyout_2x3(self):
-    it = Tensor.arange(2*3*4).cast(dtypes.imagef((2,3,4))).realize()
+  def test_image_copyout_2x4(self):
+    it = Tensor.arange(2*4*4).cast(dtypes.imagef((2,4,4))).realize()
     buf = it.uop.buffer
     out = buf.as_buffer()
-    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*3*4))
+    np.testing.assert_equal(out.cast('f').tolist(), np.arange(2*4*4))

   def test_image_roundtrip(self):
     sz = (4,2,4)
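Note on the reshaped tests: the 1x1 and 2x3 image shapes were dropped because a dynamic image must cover a whole number of 64-byte blocks, mirroring the `dt.nbytes() % 64 == 0` guard that `make_images` introduces below. A rough sanity check of that arithmetic (hypothetical helper, not part of the diff):

    # hypothetical helper mirroring the make_images eligibility rule below:
    # float32 base, fewer than 65536 elements, total bytes a multiple of 64
    def fits_dynamic_image(numel: int, itemsize: int = 4) -> bool:
      return numel < 65536 and (numel * itemsize) % 64 == 0

    assert not fits_dynamic_image(1*1*4)          # 16 bytes: old 1x1 shape fails
    assert fits_dynamic_image(1*8*4)              # 128 bytes: new 1x8 shape passes
    assert fits_dynamic_image(1*8*4, itemsize=2)  # 64 bytes: imageh 1x8 also passes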
@@ -105,9 +105,9 @@ class TestImageDType(unittest.TestCase):
     __validate(dtypes.imagef((1, 1)), 0x40)

   def test_image_and_back(self):
-    data = Tensor.randn(9*27*4).realize()
+    data = Tensor.randn(9*32*4).realize()
     tst = data.numpy()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     assert isinstance(it.uop.base.realized.dtype, ImageDType)
     np.testing.assert_equal(tst, it.numpy())
@@ -127,13 +127,13 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_equal(tst, it.numpy())

   def test_shrink_load_float(self):
-    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(imgv[0:2], it[0:2].numpy())

   def test_mul_stays_image(self):
     # NOTE: contiguous is needed otherwise this folds
-    it = Tensor.randn(4).cast(dtypes.imagef((1,1,4))).contiguous().realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).contiguous().realize()
     out = (it*2).realize()
     assert isinstance(out.uop.base.realized.dtype, ImageDType)
@@ -143,7 +143,7 @@ class TestImageDType(unittest.TestCase):
     np.testing.assert_allclose(np.sum(itn), it.sum().numpy(), rtol=1e-6)

   def test_shrink_max(self):
-    it = Tensor.randn(8).cast(dtypes.imagef((1,2,4))).realize()
+    it = Tensor.randn(16).cast(dtypes.imagef((1,4,4))).realize()
     imgv = it.numpy()
     np.testing.assert_equal(np.maximum(imgv[0:3], 0), it[0:3].relu().numpy())
@@ -162,19 +162,19 @@ class TestImageDType(unittest.TestCase):
     assert it.uop.base.realized._buf == b1

   def test_no_lru_alloc(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.cast(dtypes.imagef((10,27,4))).contiguous().realize()
+    it = data.reshape(9,32,4).pad_to(10, None, None).cast(dtypes.imagef((10,32,4))).contiguous().realize()
     assert it.uop.base.realized._buf != b1

   def test_no_lru_alloc_dtype(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
     b1 = it.uop.base.realized._buf
     del it
-    it = data.cast(dtypes.imageh((9,27,4))).realize()
+    it = data.cast(dtypes.imageh((9,32,4))).realize()
     assert it.uop.base.realized._buf != b1

   # issue caused by: don't realize image to image casts. this is part of a larger problem
@@ -202,36 +202,36 @@ class TestImageDType(unittest.TestCase):
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageRealization(unittest.TestCase):
   def test_image_dtype_expand(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous().realize()
     self.assertEqual(it_expanded.dtype, dtypes.float32)

   def test_image_dtype_expand_and_back(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4))
     it2 = it_expanded.sum(3).realize()
-    self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(it2.dtype, dtypes.imagef((9,32,4)))

   def test_image_alu_children(self):
-    data = Tensor.randn(9*27*4).realize()
-    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
-    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
-    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
+    data = Tensor.randn(9*32*4).realize()
+    it = data.cast(dtypes.imagef((9,32,4))).contiguous().realize()
+    self.assertEqual(it.dtype, dtypes.imagef((9,32,4)))
+    it_expanded = it.reshape((9,32,4,1)).expand((9,32,4,4)).contiguous()
     alu1 = it_expanded+1
     alu2 = it_expanded.sum(3)
     it_expanded.realize()
     # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
-    self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu1.dtype, dtypes.imagef((9,32,4)))
     alu1.realize()
     self.assertEqual(alu1.dtype, dtypes.float32)
     # alu2 is back in image because it fits the dtype again
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))
     alu2.realize()
-    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
+    self.assertEqual(alu2.dtype, dtypes.imagef((9,32,4)))

 if __name__ == '__main__':
   unittest.main()
@@ -15,7 +15,7 @@ from tinygrad.uop.decompositions import get_late_rewrite_patterns
 from tinygrad.codegen.late.expander import expander, pm_pre_expander, pm_group_for_reduce
 from tinygrad.codegen.late.devectorizer import load_store_folding, load_store_indexing, devectorize, pm_reduce, \
   ReduceContext, correct_load_store, pm_render, pm_add_loads
-from tinygrad.codegen.opt.postrange import apply_opts
+from tinygrad.codegen.opt.postrange import apply_opts, make_images
 from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse, pm_split_store
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
@@ -53,6 +53,9 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
   # split store range (only on CPU for now)
   sink = graph_rewrite(sink, pm_split_store, ctx=ren.device, name="cut store ranges")

+  # create image buffers
+  sink = make_images(sink, ren)
+
   # do postrange optimization, BEAM or hand_coded_optimizations
   sink = apply_opts(sink, ren)
@@ -133,7 +136,7 @@ def do_linearize(prg:UOp, sink:UOp) -> UOp:

 def do_render(ctx:Renderer, prg:UOp, lin:UOp) -> UOp:
   src = ctx.render(list(lin.src))
-  return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),))
+  return prg.replace(src=prg.src + (UOp(Ops.SOURCE, arg=src),), arg=ctx.aux(list(lin.src)) if ctx.has_aux else prg.arg)

 def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
   if ctx.compiler is None: return None
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes, ImageDType
 from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
-from tinygrad.helpers import ALLOW_TF32, count
+from tinygrad.helpers import IMAGE, ALLOW_TF32, count
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
 from tinygrad.codegen.simplify import pm_flatten_range
 from tinygrad.renderer import Renderer
@@ -349,3 +349,24 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
   if not any(u.op is Ops.BUFFERIZE for u in ast.backward_slice):
     k = hand_coded_optimizations(k)
   return k.get_optimized_ast(name_override=ast.arg.name if ast.arg is not None and ast.arg.name != "test" else None)
+
+# create image buffers
+def make_images(ast:UOp, ren:Renderer) -> UOp:
+  if IMAGE == 1 and ren.device in {"QCOM", "CL"}:
+    dg_types: dict = {}
+    def make_image(ctx, dg):
+      if (dt:=dg.dtype).base is dtypes.float and not isinstance(dt, ImageDType) and dt.size < 65536 and dt.nbytes() % 64 == 0:
+        ctx[dg.arg] = dt
+        return dg.replace(dtype=dtypes.imagef((1, dt.size // 4, 4), dt.nbytes()))
+
+    ast = graph_rewrite(ast, PatternMatcher([(UPat(Ops.DEFINE_GLOBAL, name="dg"), make_image)]), ctx=dg_types, name="create image buffers")
+
+    # undo unfoldable stores
+    def undo_image_store(ctx, st, idx, dg):
+      if dg.arg in ctx and not any(c.op is Ops.RANGE and (c.vmax+1)%4 == 0 for c in idx.src[1].get_idx().split_uop(Ops.ADD)):
+        return st.replace(src=(idx.replace(src=(dg.replace(dtype=ctx[dg.arg]),)+idx.src[1:]),)+st.src[1:])
+
+    ast = graph_rewrite(ast, PatternMatcher([
+      (UPat(Ops.DEFINE_GLOBAL, name="dg").index(UPat(), name="idx").store(UPat(), name="st"), undo_image_store)
+    ]), ctx=dg_types, name="remove unfoldable image stores")
+  return ast
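What `make_images` above does, in short: with IMAGE=1 on QCOM/CL it rewrites every eligible float buffer argument into a one-row RGBA image whose pitch equals the buffer's byte size, then reverts any store whose index can't be folded into 4-wide texels. A standalone sketch of just the dtype swap (the real code rewrites `Ops.DEFINE_GLOBAL` uops through a `PatternMatcher`):

    from tinygrad.dtype import dtypes, ImageDType

    # sketch of the eligibility test and dtype swap from make_images above
    def to_image_dtype(dt):
      # eligible: float32 base, not already an image, < 65536 elements, 64-byte multiple
      if dt.base is dtypes.float and not isinstance(dt, ImageDType) \
          and dt.size < 65536 and dt.nbytes() % 64 == 0:
        # one-row image: dt.size // 4 RGBA texels wide, pitch = original byte size
        return dtypes.imagef((1, dt.size // 4, 4), dt.nbytes())
      return dt  # unchanged: kept as a plain buffer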
@@ -94,12 +94,14 @@ class PtrDType(DType):
 @dataclass(frozen=True, eq=False)
 class ImageDType(PtrDType):
   shape: tuple[int, ...] = ()   # shape of the Image
+  _pitch: int = -1
   def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
     assert addrspace == AddrSpace.GLOBAL, "images can't be local"
     return self
   def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
   @property
   def pitch(self):
+    if self._pitch != -1: return self._pitch
     imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
     pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
     align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
@@ -181,9 +183,9 @@ class dtypes:

   # NOTE: these are image dtypes
   @staticmethod
-  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
+  def imageh(shp, pitch=-1): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)
   @staticmethod
-  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp)
+  def imagef(shp, pitch=-1): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, AddrSpace.GLOBAL, 1, prod(shp), shp, pitch)

   default_float: ClassVar[DType] = float32
   default_int: ClassVar[DType] = int32
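Worked example of the new `_pitch` override (values chosen for illustration): a pitch passed to `imagef` short-circuits the computed QCOM pitch alignment, which is how the dynamic images above pin their pitch to the backing buffer's byte size:

    from tinygrad.dtype import dtypes

    img = dtypes.imagef((1, 16, 4))        # no pitch given: _pitch stays -1, pitch is computed
    dyn = dtypes.imagef((1, 16, 4), 256)   # pitch pinned to 256 bytes (16 texels * 4ch * 4B)
    assert dyn.pitch == 256                # returned from _pitch, bypassing the alignment math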
@@ -45,7 +45,7 @@ class CompiledRunner(Runner):
     self.p:ProgramSpec = p
     assert self.p.lib is not None
     if DEBUG >= 7: Device[p.device].compiler.disassemble(self.p.lib)
-    self._prg = Device[p.device].runtime(p.function_name, self.p.lib) if prg is None else prg
+    self._prg = Device[p.device].runtime(p.function_name, self.p.lib, *p.aux) if prg is None else prg
     super().__init__(p.name, p.device, p.estimates)

   def __reduce__(self): return self.__class__, (self.p,)
@@ -168,6 +168,7 @@ class ContextVar:
     ContextVar._cache[key] = self
     self.value, self.key = getenv(key, default_value), key
   def __bool__(self): return bool(self.value)
   def __eq__(self, x): return self.value == x
   def __ge__(self, x): return self.value >= x
+  def __gt__(self, x): return self.value > x
   def __lt__(self, x): return self.value < x
@@ -66,6 +66,7 @@ class ProgramSpec:
   ast:UOp   # save the base ast (this is method cache key)
   uops:list[UOp]|None=None
   lib:bytes|None=None
+  aux:list=field(default_factory=list)

   # filled in from uops (via from_uop)
   global_size:list[int]=field(default_factory=lambda: [1,1,1])
@@ -123,7 +124,7 @@ class ProgramSpec:
     # TODO: this cast is wrong, u.src[0].ssimplify() can be sint
     if special_size is not None: special_size[int(u.arg[-1])] = cast(int, u.src[0].ssimplify())

-  return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, global_size, local_size,
+  return ProgramSpec(sink.arg.name, source.arg, device.arg, sink, uops, lib, list(prg.arg) if prg.arg else [], global_size, local_size,
                      sorted(_vars, key=lambda v: v.arg), sorted(dedup(_globals)), sorted(dedup(outs)), sorted(dedup(ins)))

 class Renderer:
@@ -134,6 +135,7 @@ class Renderer:
   has_local: bool = True
   has_threads: bool = False
   has_shared: bool = True
+  has_aux: bool = False   # additional program info, eg. image shapes
   # NOTE: these two should be in (x,y,z) order to match the max_sizes argument in get_grouped_dims
   global_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3)   # TODO: Ops.SPECIAL int32 indexes right now
   local_max: tuple[int, ...]|None = (0x8FFFFFFF,) * (3)    # TODO: Ops.SPECIAL int32 indexes right now
@@ -146,3 +148,4 @@ class Renderer:

   def __reduce__(self): return self.__class__, ()
   def render(self, uops:list[UOp]) -> str: raise NotImplementedError("needs a renderer")
+  def aux(self, uops:list[UOp]) -> dict: raise NotImplementedError("needs aux")
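How the aux plumbing above fits together: a renderer with `has_aux` returns per-kernel extras from `aux()`, `do_render` stashes them in the PRG uop's arg, `ProgramSpec` carries them as `p.aux`, and `CompiledRunner` splats them into the runtime constructor. A toy model of that path (stand-in classes, not tinygrad API):

    # toy stand-ins showing the data flow; names mirror the diff, values are fake
    class ToyRenderer:
      has_aux = True
      def aux(self, buf_dtypes): return (tuple(buf_dtypes),)   # like OpenCLRenderer.aux

    class ToyProgram:                                          # like CLProgram's new signature
      def __init__(self, name, lib, buf_dtypes=()): self.buf_dtypes = buf_dtypes

    aux = ToyRenderer().aux(["imagef((1,16,4))", "float*"])    # DEFINE_GLOBAL dtypes
    prg = ToyProgram("kernel", b"", *aux)                      # runtime(name, lib, *p.aux)
    assert prg.buf_dtypes == ("imagef((1,16,4))", "float*")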
@@ -286,6 +286,7 @@ class ClangJITRenderer(ClangRenderer):

 class OpenCLRenderer(CStyleLanguage):
   device = "CL"
+  has_aux = True

   # language options
   kernel_typedef = "__kernel void"
@@ -314,6 +315,8 @@ class OpenCLRenderer(CStyleLanguage):
     if any(uop.dtype.base == dtypes.half for uop in uops): prefix = (["#pragma OPENCL EXTENSION cl_khr_fp16 : enable"] + (prefix or []))
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)

+  def aux(self, uops:list[UOp]): return (tuple(u.dtype for u in uops if u.op == Ops.DEFINE_GLOBAL),)
+
 class IntelRenderer(OpenCLRenderer):
   device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
   tensor_cores = tc.intel
@@ -5,6 +5,7 @@ from tinygrad.runtime.autogen import opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG, mv_address, suppress_finalizing
 from tinygrad.renderer.cstyle import OpenCLRenderer, IntelRenderer
 from tinygrad.device import BufferSpec, LRUAllocator, Compiled, Compiler, CompileError, CompilerPair, CompilerSet
+from tinygrad.dtype import ImageDType

 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -33,8 +34,8 @@ class CLCompiler(Compiler):
     return bytes(binary)

 class CLProgram:
-  def __init__(self, device:CLDevice, name:str, lib:bytes):
-    self.dev, self.name, self.lib = device, name, lib
+  def __init__(self, device:CLDevice, name:str, lib:bytes, buf_dtypes=[]):
+    self.dev, self.name, self.lib, self.buf_dtypes = device, name, lib, buf_dtypes
     self.program = checked(cl.clCreateProgramWithBinary(device.context, 1, device.device_id, (ctypes.c_size_t * 1)(len(lib)),
                                                         to_char_p_p([lib], ctypes.c_ubyte), binary_status := ctypes.c_int32(),
                                                         errcode_ret := ctypes.c_int32()), errcode_ret)
@@ -50,7 +51,12 @@ class CLProgram:

   def __call__(self, *bufs:tuple[ctypes._CData, BufferSpec], global_size:tuple[int,int,int]=(1,1,1), local_size:tuple[int,int,int]|None=None,
                vals:tuple[int, ...]=(), wait=False) -> float|None:
-    for i,(b,_) in enumerate(bufs): check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
+    for i,(b,_) in enumerate(bufs):
+      if isinstance(dt:=self.buf_dtypes[i], ImageDType):
+        fmt = cl.cl_image_format(cl.CL_RGBA, {2:cl.CL_HALF_FLOAT, 4:cl.CL_FLOAT}[dt.itemsize])
+        desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b)
+        b = checked(cl.clCreateImage(self.dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None, status:=ctypes.c_int32()), status)
+      check(cl.clSetKernelArg(self.kernel, i, ctypes.sizeof(b), ctypes.byref(b)))
     for i,v in enumerate(vals,start=len(bufs)): check(cl.clSetKernelArg(self.kernel, i, 4, ctypes.byref(ctypes.c_int32(v))))
     if local_size is not None: global_size = cast(tuple[int,int,int], tuple(int(g*l) for g,l in zip(global_size, local_size)))
     event = cl.cl_event() if wait else None
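This is the "dynamic" part of the commit: the allocator no longer creates images up front; `__call__` wraps the ordinary cl buffer in a 2D image view at launch time, using `cl_image_desc.buffer` so the image aliases the buffer's storage. Pulled out as a sketch (same calls as the diff above, relying on the `cl`/`checked` helpers already imported in ops_cl.py):

    import ctypes

    # sketch: make a 2D image view over an existing cl buffer b with ImageDType dt
    def image_view(dev, b, dt):
      fmt = cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[dt.itemsize])
      desc = cl.cl_image_desc(cl.CL_MEM_OBJECT_IMAGE2D, dt.shape[1], dt.shape[0], buffer=b)
      return checked(cl.clCreateImage(dev.context, cl.CL_MEM_READ_WRITE, fmt, desc, None,
                                      status := ctypes.c_int32()), status)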
@@ -66,27 +72,15 @@ class CLProgram:

 class CLAllocator(LRUAllocator['CLDevice']):
   def _alloc(self, size:int, options:BufferSpec) -> tuple[ctypes._CData, BufferSpec]:
-    if options.image is not None:
-      return (checked(cl.clCreateImage2D(self.dev.context, cl.CL_MEM_READ_WRITE,
-                                         cl.cl_image_format(cl.CL_RGBA, {2: cl.CL_HALF_FLOAT, 4: cl.CL_FLOAT}[options.image.itemsize]),
-                                         options.image.shape[1], options.image.shape[0], 0, None, status := ctypes.c_int32()), status), options)
     return (checked(cl.clCreateBuffer(self.dev.context, cl.CL_MEM_READ_WRITE, size, None, status := ctypes.c_int32()), status), options)
   @suppress_finalizing
   def _free(self, opaque:tuple[ctypes._CData, BufferSpec], options:BufferSpec): check(cl.clReleaseMemObject(opaque[0]))
   def _copyin(self, dest:tuple[ctypes._CData, BufferSpec], src:memoryview):
-    if dest[1].image is not None:
-      check(cl.clEnqueueWriteImage(self.dev.queue, dest[0], False, (ctypes.c_size_t * 3)(0,0,0),
-                                   (ctypes.c_size_t * 3)(dest[1].image.shape[1],dest[1].image.shape[0],1), 0, 0, from_mv(src), 0, None, None))
-    else:
-      if mv_address(src) % 16: src = memoryview(bytearray(src))
-      check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
+    if mv_address(src) % 16: src = memoryview(bytearray(src))
+    check(cl.clEnqueueWriteBuffer(self.dev.queue, dest[0], False, 0, len(src)*src.itemsize, from_mv(src), 0, None, None))
     self.dev.pending_copyin.append(src)   # NOTE: these can't be freed until the GPU actually executes this command
   def _copyout(self, dest:memoryview, src:tuple[ctypes._CData, BufferSpec]):
-    if src[1].image is not None:
-      check(cl.clEnqueueReadImage(self.dev.queue, src[0], False, (ctypes.c_size_t * 3)(0,0,0),
-                                  (ctypes.c_size_t * 3)(src[1].image.shape[1],src[1].image.shape[0],1), 0, 0, from_mv(dest), 0, None, None))
-    else:
-      check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
+    check(cl.clEnqueueReadBuffer(self.dev.queue, src[0], False, 0, len(dest)*dest.itemsize, from_mv(dest), 0, None, None))
     self.dev.synchronize()

 class CLDevice(Compiled):
@@ -11,6 +11,7 @@ from tinygrad.renderer.cstyle import QCOMRenderer
 from tinygrad.renderer.nir import IR3Renderer
 from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
 from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC
+from tinygrad.dtype import ImageDType
 from tinygrad.runtime.support.system import System
 if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl  # noqa: F401  # pylint: disable=unused-import
@@ -192,8 +193,9 @@ class QCOMArgsState(HCQArgsState):
     super().__init__(buf, prg, bufs, vals=vals)
     ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)

-    ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None]
-    ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:])
+    ubos = [b for i,b in enumerate(bufs) if not isinstance(prg.buf_dtypes[i], ImageDType)]
+    uavs = [(i,b) for i,b in enumerate(bufs) if isinstance(prg.buf_dtypes[i], ImageDType)]
+    ibos, texs = uavs[:prg.ibo_cnt], uavs[prg.ibo_cnt:]
     for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')

     if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
@@ -205,19 +207,19 @@ class QCOMArgsState(HCQArgsState):
     for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)])

     def _tex(b, ibo=False):
-      fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
+      fmt = mesa.FMT6_32_32_32_32_FLOAT if (img:=b[1].image or prg.buf_dtypes[b[0]]).itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
       return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt),
-              qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]),
-              qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr),
+              qreg.a6xx_tex_const_1(width=img.shape[1], height=img.shape[0]),
+              qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=img.pitch, pitchalign=ctz(img.pitch)-6), 0, *data64_le(b[1].va_addr),
               qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0]

     self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off)
     self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off)

 class QCOMProgram(HCQProgram):
-  def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
+  def __init__(self, dev: QCOMDevice, name: str, lib: bytes, buf_dtypes=[]):
     self.dev: QCOMDevice = dev
-    self.name, self.lib, self.NIR = name, lib, isinstance(dev.renderer, IR3Renderer)
+    self.buf_dtypes, self.name, self.lib, self.NIR = buf_dtypes, name, lib, isinstance(dev.renderer, IR3Renderer)

     if self.NIR:
       from tinygrad.runtime.support.compiler_mesa import IR3Compiler
@@ -313,7 +315,7 @@ class QCOMTextureInfo:
 class QCOMAllocator(HCQAllocatorBase):
   def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer:
     # Recalculate real size for texture
-    if opts.image is not None: size = opts.image.pitch* opts.image.shape[0]
+    if opts.image is not None: size = opts.image.pitch * opts.image.shape[0]
     return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image)

   def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
@@ -3885,7 +3885,7 @@ class Tensor(OpMixin):
     return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)

   def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
-    base_image_type = dtypes.imageh if getenv("FLOAT16", 0) else dtypes.imagef
+    base_image_type, dtsz = (dtypes.imageh, 2) if getenv("FLOAT16", 0) else (dtypes.imagef, 4)

     (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
     x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
@@ -3898,6 +3898,20 @@ class Tensor(OpMixin):
       w = w.pad_to(None, None, cin, None, None)
       x = x.pad_to(None, None, cin, None, None).reshape(bs, groups*cin, iy, ix)

+    # hacks for pitch alignment
+    assert isinstance(ix, int) and isinstance(H, int)
+    added_width = 0
+    if (ix*groups*cin) % (64 // dtsz):
+      added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
+      ix = ix + added_width
+      x = x.pad_to(None, None, None, ix)
+
+    added_weight = 0
+    if (H*W*cin) % (64 // dtsz):
+      added_weight = round_up(H, 64 // (dtsz * math.gcd(W * cin, 64 // dtsz))) - H
+      H = H + added_weight
+      w = w.pad_to(None, None, None, H, None)
+
     # hack for non multiples of 4 on rcout
     added_output_channels = 0
     if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
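Worked example of the pitch-alignment hack above, assuming fp32 (`dtsz=4`, so a row must hold a multiple of 64//4 = 16 floats; the numbers are illustrative):

    import math
    from tinygrad.helpers import round_up

    dtsz, groups, cin, ix = 4, 1, 4, 27            # a 27-wide fp32 input with 4 channels
    assert (ix * groups * cin) % (64 // dtsz)      # 108 floats per row: not 64-byte aligned
    added_width = round_up(ix, 64 // (dtsz * math.gcd(groups * cin, 64 // dtsz))) - ix
    assert added_width == 1                        # pad width 27 -> 28
    assert ((ix + added_width) * groups * cin) % (64 // dtsz) == 0   # 112 floats = 448 bytes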
@@ -3917,13 +3931,18 @@ class Tensor(OpMixin):
     if IMAGE >= 2: x,w = x.cast(base_image_type((bs*iy, ix*groups*cin//4, 4))), w.cast(base_image_type((cout//4, H*W*cin, 4)))
     x, w = x.contiguous(), w.contiguous()

+    if added_weight: w, H = w[:, :-added_weight, ...], H - added_weight
+
     # expand out
     rcin_hi, rcin_lo = (cin//4, 4) if cin >= 4 else (1, 1)
     group_shape, rcout_expand = (groups//4, 4) if cin == 1 else (groups, 1), (rcout//4, 4) if rcout >= 4 else (1, 1)
-    x = x.reshape(bs, iy, ix, groups, rcin_hi, rcin_lo)
+    x = x.reshape(bs, iy, -1, groups, rcin_hi, rcin_lo)
     if cin_last: w = w.reshape(cout//4, H, rcin_hi, W, 4, rcin_lo)
     else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4)

+    # undo pitch alignment hack
+    if added_width: x = x[:, :, :-added_width, ...]
+
     # prepare input
     x = x.permute(0,3,4,5,1,2).pad(self._resolve_pool_pads(padding,2))._pool((H,W), stride, dilation)   # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W)
     x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *group_shape, 1, 1, rcin_hi, rcin_lo, H, W)
@@ -3931,9 +3950,20 @@ class Tensor(OpMixin):
     # prepare weights
     w = w.permute(0,4,2,5,1,3).reshape((1, 1, 1, *group_shape, *rcout_expand, rcin_hi, rcin_lo, H, W))

+    added_ox = 0
+    assert isinstance(ox, int) and isinstance(cout, int)
+    if (ox * cout) % (64 // dtsz):
+      added_ox = round_up(ox, 64 // (dtsz * math.gcd(cout, 64 // dtsz))) - ox
+      ox = ox + added_ox
+      x = x.pad_to(None, None, ox, None, None, None, None, None, None, None, None)
+
     # the conv!
     ret = (x*w).cast(base_image_type((bs*oy, ox*cout//4, 4)) if IMAGE >= 2 else dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)

+    if added_ox:
+      ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :-added_ox, ...]
+      ox = ox - added_ox
+
     # undo hack for non multiples of 4 on C.rcout
     if added_output_channels != 0:
       ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]