refactor image pitch (#13928)

This commit is contained in:
Christopher Milan
2025-12-31 10:22:38 -08:00
committed by GitHub
parent 051fe6c8bc
commit 13973e4dea
7 changed files with 99 additions and 114 deletions

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad.device import Device, BufferSpec
from tinygrad.dtype import dtypes
@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run")
class TestQcom(unittest.TestCase):
  """Checks the texture pitch that the QCOM allocator assigns to image buffers."""

  def test_image_pitch(self):
    """Allocate an image buffer for every listed dtype and compare its pitch against the expected value."""
    qcom = Device["QCOM"]

    def check_pitch(imgdt, expected_pitch):
      # the same byte count is handed to both alloc and free
      nbytes = imgdt.shape[0] * imgdt.shape[1] * 16
      spec = BufferSpec(image=imgdt)
      img = qcom.allocator.alloc(nbytes, spec)
      pitch = img.texture_info.pitch
      assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}"
      qcom.allocator.free(img, nbytes, spec)

    # Match opencl pitches for perf
    # each entry is (image shape, expected pitch)
    half_cases = [
      ((1, 201), 0x680), ((16, 216), 0x700), ((16, 9), 0x80), ((48, 64), 0x200),
      ((32, 128), 0x400), ((96, 128), 0x400), ((64, 256), 0x840), ((64, 9), 0x80),
      ((192, 256), 0x840), ((64, 768), 0x1840), ((256, 49), 0x1C0), ((128, 9), 0x80),
      ((16, 1024), 0x2080), ((64, 512), 0x1040), ((16, 512), 0x1080), ((132, 64), 0x200),
      ((4, 512), 0x1200), ((8, 512), 0x1100), ((128, 128), 0x400), ((32, 512), 0x1040),
      ((26, 64), 0x200), ((32, 516), 0x1040), ((32, 1024), 0x2040), ((16, 2048), 0x4080),
      ((8, 2048), 0x4100), ((4, 4096), 0x8200),
    ]
    float_cases = [
      ((16, 49), 0x380), ((16, 1024), 0x4080), ((256, 64), 0x400), ((64, 512), 0x2040),
      ((16, 512), 0x2080), ((132, 64), 0x400), ((4, 512), 0x2200), ((4, 16), 0x200),
      ((2, 16), 0x400), ((8, 512), 0x2100), ((12, 64), 0x400), ((3, 32), 0x400),
      ((128, 128), 0x840), ((32, 512), 0x2040), ((8, 3072), 0xC100), ((4, 2048), 0x8200),
      ((4, 1024), 0x4200), ((4, 4096), 0x10200), ((10, 384), 0x1900), ((24, 64), 0x400),
      ((128, 12), 0xC0), ((10, 24), 0x200), ((1, 129), 0x840), ((1, 32), 0x200),
      ((1, 64), 0x400), ((1, 1239), 0x4D80), ((1, 1), 0x40),
    ]
    for shape, expected in half_cases: check_pitch(dtypes.imageh(shape), expected)
    for shape, expected in float_cases: check_pitch(dtypes.imagef(shape), expected)
if __name__ == "__main__":
unittest.main()

View File

@@ -44,6 +44,66 @@ class TestImageCopy(unittest.TestCase):
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
class TestImageDType(unittest.TestCase):
def test_image_pitch(self):
  """Compare the `pitch` property of each listed image dtype against the expected value."""
  def check_pitch(imgdt, expected_pitch):
    assert imgdt.pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{imgdt.pitch:X}, expected 0x{expected_pitch:X}"

  # Match opencl pitches for perf
  # each entry is (image shape, expected pitch)
  half_cases = [
    ((1, 201), 0x680), ((16, 216), 0x700), ((16, 9), 0x80), ((48, 64), 0x200),
    ((32, 128), 0x400), ((96, 128), 0x400), ((64, 256), 0x840), ((64, 9), 0x80),
    ((192, 256), 0x840), ((64, 768), 0x1840), ((256, 49), 0x1C0), ((128, 9), 0x80),
    ((16, 1024), 0x2080), ((64, 512), 0x1040), ((16, 512), 0x1080), ((132, 64), 0x200),
    ((4, 512), 0x1200), ((8, 512), 0x1100), ((128, 128), 0x400), ((32, 512), 0x1040),
    ((26, 64), 0x200), ((32, 516), 0x1040), ((32, 1024), 0x2040), ((16, 2048), 0x4080),
    ((8, 2048), 0x4100), ((4, 4096), 0x8200),
  ]
  float_cases = [
    ((16, 49), 0x380), ((16, 1024), 0x4080), ((256, 64), 0x400), ((64, 512), 0x2040),
    ((16, 512), 0x2080), ((132, 64), 0x400), ((4, 512), 0x2200), ((4, 16), 0x200),
    ((2, 16), 0x400), ((8, 512), 0x2100), ((12, 64), 0x400), ((3, 32), 0x400),
    ((128, 128), 0x840), ((32, 512), 0x2040), ((8, 3072), 0xC100), ((4, 2048), 0x8200),
    ((4, 1024), 0x4200), ((4, 4096), 0x10200), ((10, 384), 0x1900), ((24, 64), 0x400),
    ((128, 12), 0xC0), ((10, 24), 0x200), ((1, 129), 0x840), ((1, 32), 0x200),
    ((1, 64), 0x400), ((1, 1239), 0x4D80), ((1, 1), 0x40),
  ]
  for shape, expected in half_cases: check_pitch(dtypes.imageh(shape), expected)
  for shape, expected in float_cases: check_pitch(dtypes.imagef(shape), expected)
def test_image_and_back(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Final, ClassVar, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
from tinygrad.helpers import getenv, prod, round_up, next_power2
from enum import Enum, auto
class InvalidTypeMetaClass(type):
@@ -101,6 +101,15 @@ class ImageDType(PtrDType):
assert addrspace == AddrSpace.GLOBAL, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@property
def pitch(self):
  """Return the row pitch for this image, derived from `self.shape` and `self.itemsize`.

  The result is the row size (`shape[1] * 4 * itemsize`) rounded up to `1 << pitchalign`, plus one
  extra alignment unit when the width sits close below a power-of-two / granularity boundary.
  NOTE(review): presumably the unit is bytes and the factor 4 is the channel count per pixel -- confirm.
  """
  height, width = self.shape[0], self.shape[1]
  log2_itemsize = int(math.log2(self.itemsize))
  # alignment exponent: never below 6, larger for images with few rows
  if height > 1: pitchalign = max(6, 11 - int(math.log2(height)))
  else: pitchalign = 6
  # how far below the boundary a width may sit and still get the extra padding
  if pitchalign == 6: align_up = max(1, (8 // log2_itemsize + 1) - height // 32)
  else: align_up = 2 ** (pitchalign - log2_itemsize - 2)
  granularity = 128 if self.itemsize == 4 else 256
  alignment = 1 << pitchalign
  boundary = min(next_power2(width), round_up(width, granularity))
  if boundary - align_up + 1 <= width and width > granularity//2: pitch_add = alignment
  else: pitch_add = 0
  return round_up(width * 4 * self.itemsize, alignment) + pitch_add
class dtypes:
@staticmethod

View File

@@ -50,6 +50,7 @@ def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fs
def ceildiv(num, amt):
  """Return num divided by amt, rounded up; a float quotient is converted to int."""
  # negating the divisor turns floor division into ceiling division
  quotient = -(num // -amt)
  if isinstance(quotient, float): return int(quotient)
  return quotient
def round_up(num:int, amt:int) -> int:
  """Return num rounded up to a multiple of amt."""
  multiples = (num + amt - 1) // amt
  return multiples * amt
def round_down(num:int, amt:int) -> int:
  """Return the negation of round_up applied to the negated num."""
  negated_ceiling = round_up(-num, amt)
  return -negated_ceiling
def next_power2(x):
  """Return the smallest power of two that is >= x for positive x; 0 maps to 1."""
  if x == 0: return 1
  # bit_length of x-1 is the exponent of the first power of two not below x
  exponent = (x - 1).bit_length()
  return 1 << exponent
# cstyle div and mod
def cdiv(x:int, y:int) -> int:
  """Divide x by y truncating toward zero; a zero divisor yields 0."""
  if y == 0: return 0
  # divide the magnitudes, then restore the sign of the true quotient
  sign = -1 if x*y < 0 else 1
  return abs(x)//abs(y)*sign
def cmod(x:int, y:int) -> int:
  """Return the remainder matching cdiv, i.e. x - cdiv(x, y)*y."""
  quotient = cdiv(x, y)
  return x - quotient*y

View File

@@ -22,7 +22,7 @@ class HCQGraph(MultiGraphRunner):
for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, image=self.hcq_bufs[j][i].image) # Create fake buffer with variable
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)

View File

@@ -10,7 +10,7 @@ from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.renderer.nir import IR3Renderer
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC
from tinygrad.runtime.support.system import System
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -25,7 +25,7 @@ def _qreg_exec(__reg, __val=0, **kwargs):
return __val
qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'})
def next_power2(x):
  """Return the smallest power of two that is >= x for positive x; 0 maps to 1."""
  if x == 0: return 1
  # bit_length of x-1 is the exponent of the first power of two not below x
  exponent = (x - 1).bit_length()
  return 1 << exponent
def ctz(v):
  """Return the number of trailing zero bits of v (-1 when v is 0)."""
  # v & -v keeps only the lowest set bit
  lowest_set_bit = v & -v
  return lowest_set_bit.bit_length() - 1
def parity(val: int):
for i in range(4,1,-1): val ^= val >> (1 << i)
@@ -192,9 +192,8 @@ class QCOMArgsState(HCQArgsState):
super().__init__(buf, prg, bufs, vals=vals)
ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)
ubos, uavs = [b for b in bufs if b.texture_info is None], [b for b in bufs if b.texture_info is not None]
ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None]
ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:])
for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
@@ -205,8 +204,15 @@ class QCOMArgsState(HCQArgsState):
for i, b in enumerate(ubos): self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=prg.buf_offs[i])
for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)])
self.bind_sints_to_buf(*flatten([b.texture_info.desc + ([0] * 8) for b in texs]), buf=self.buf, fmt='I', offset=prg.tex_off)
self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off)
# Build the list of descriptor words for one image buffer `b`, read from b.image (shape, itemsize, pitch) and b.va_addr.
# ibo=False (default) builds the texture form; ibo=True builds the IBO form. The two differ only in the first word:
# the texture form additionally passes 0x8 and the swizzle x,y,z,w -> 0,1,2,3 to a6xx_tex_const_0.
# NOTE(review): presumably 16 words in total (8 descriptor words + 8 zero words), assuming data64_le yields two values -- confirm.
def _tex(b, ibo=False):
# 4-byte items select the 32-bit float format, anything else the 16-bit float format
fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
# width comes from shape[1], height from shape[0]; pitchalign is passed as (trailing zero bits of the pitch) - 6
return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt),
qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]),
qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr),
qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0]
self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off)
self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off)
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -305,28 +311,10 @@ class QCOMTextureInfo:
self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo
class QCOMAllocator(HCQAllocatorBase):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer:
# Recalculate real size for texture
if options.image is not None:
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
granularity = 128 if options.image.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
size = pitch * imgh
buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
if options.image is not None:
tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
*data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
return buf
if opts.image is not None: size = opts.image.pitch* opts.image.shape[0]
return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image)
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
with cpu_profile(prof_text, self.dev.device, is_copy=True):
@@ -335,13 +323,13 @@ class QCOMAllocator(HCQAllocatorBase):
src_off, dest_off = src_off+src_stride, dest_off+dest_stride
def _copyin(self, dest:HCQBuffer, src:memoryview):
stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
stride, pitch = (dest.image.shape[1] * 4 * dest.image.itemsize, dest.image.pitch) if dest.image else (src.nbytes, src.nbytes)
self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}")
def _copyout(self, dest:memoryview, src:HCQBuffer):
self.dev.synchronize()
stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
stride, pitch = (src.image.shape[1] * 4 * src.image.itemsize, src.image.pitch) if src.image else (src.size, src.size)
self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY")
def _as_buffer(self, src:HCQBuffer) -> memoryview:
@@ -388,7 +376,7 @@ class QCOMDevice(HCQCompiled):
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
functools.partial(QCOMComputeQueue, self), None)
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False, **kwargs) -> HCQBuffer:
flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
if uncached: flags |= flag("KGSL_CACHEMODE", kgsl.KGSL_CACHEMODE_UNCACHED)
@@ -396,15 +384,15 @@ class QCOMDevice(HCQCompiled):
va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000)
if fill_zeroes: ctypes.memset(va_addr, 0, size)
return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self)
return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self, **kwargs)
def _gpu_map(self, ptr:int, size:int) -> HCQBuffer:
def _gpu_map(self, ptr:int, size:int, **kwargs) -> HCQBuffer:
ptr_aligned, size_aligned = (ptr & ~0xfff), round_up(size + (ptr & 0xfff), 0x1000)
try:
mapinfo = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
return HCQBuffer(mapinfo.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mapinfo, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
mi = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
return HCQBuffer(mi.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mi, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
except OSError as e:
if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
raise RuntimeError("Failed to map external pointer to GPU memory") from e
def _gpu_free(self, mem:HCQBuffer):

View File

@@ -8,6 +8,7 @@ from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEve
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.dtype import ImageDType
class MMIOInterface:
def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
@@ -455,14 +456,14 @@ class HCQCompiled(Compiled, Generic[SignalType]):
if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
class HCQBuffer:
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
def __init__(self, va_addr:sint, size:int, image:ImageDType|None=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
owner:HCQCompiled|None=None):
self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view
self.va_addr, self.size, self.image, self.meta, self._base, self.view = va_addr, size, image, meta, _base, view
self._devs, self.owner = ([owner] if owner is not None else []), owner
self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices
def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer:
return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta,
return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, image=self.image, meta=self.meta,
_base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))
def cpu_view(self) -> MMIOInterface: