diff --git a/test/device/test_qcom.py b/test/device/test_qcom.py deleted file mode 100644 index 827a364d1d..0000000000 --- a/test/device/test_qcom.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python -import unittest -from tinygrad.device import Device, BufferSpec -from tinygrad.dtype import dtypes - -@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run") -class TestQcom(unittest.TestCase): - def test_image_pitch(self): - dev = Device["QCOM"] - - def __validate(imgdt, expected_pitch): - img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferSpec(image=imgdt)) - pitch = img.texture_info.pitch - assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}" - dev.allocator.free(img, imgdt.shape[0] * imgdt.shape[1] * 16, options) - - # Match opencl pitches for perf - __validate(dtypes.imageh((1, 201)), 0x680) - __validate(dtypes.imageh((16, 216)), 0x700) - __validate(dtypes.imageh((16, 9)), 0x80) - __validate(dtypes.imageh((48, 64)), 0x200) - __validate(dtypes.imageh((32, 128)), 0x400) - __validate(dtypes.imageh((96, 128)), 0x400) - __validate(dtypes.imageh((64, 256)), 0x840) - __validate(dtypes.imageh((64, 9)), 0x80) - __validate(dtypes.imageh((192, 256)), 0x840) - __validate(dtypes.imageh((64, 768)), 0x1840) - __validate(dtypes.imageh((256, 49)), 0x1C0) - __validate(dtypes.imageh((128, 9)), 0x80) - __validate(dtypes.imageh((16, 1024)), 0x2080) - __validate(dtypes.imageh((64, 512)), 0x1040) - __validate(dtypes.imageh((16, 512)), 0x1080) - __validate(dtypes.imageh((132, 64)), 0x200) - __validate(dtypes.imageh((4, 512)), 0x1200) - __validate(dtypes.imageh((8, 512)), 0x1100) - __validate(dtypes.imageh((128, 128)), 0x400) - __validate(dtypes.imageh((32, 512)), 0x1040) - __validate(dtypes.imageh((26, 64)), 0x200) - __validate(dtypes.imageh((32, 516)), 0x1040) - __validate(dtypes.imageh((32, 1024)), 0x2040) - __validate(dtypes.imageh((16, 2048)), 0x4080) - __validate(dtypes.imageh((8, 2048)), 0x4100) - __validate(dtypes.imageh((4, 4096)), 0x8200) - - __validate(dtypes.imagef((16, 49)), 0x380) - __validate(dtypes.imagef((16, 1024)), 0x4080) - __validate(dtypes.imagef((256, 64)), 0x400) - __validate(dtypes.imagef((64, 512)), 0x2040) - __validate(dtypes.imagef((16, 512)), 0x2080) - __validate(dtypes.imagef((132, 64)), 0x400) - __validate(dtypes.imagef((4, 512)), 0x2200) - __validate(dtypes.imagef((4, 16)), 0x200) - __validate(dtypes.imagef((2, 16)), 0x400) - __validate(dtypes.imagef((8, 512)), 0x2100) - __validate(dtypes.imagef((12, 64)), 0x400) - __validate(dtypes.imagef((3, 32)), 0x400) - __validate(dtypes.imagef((128, 128)), 0x840) - __validate(dtypes.imagef((32, 512)), 0x2040) - __validate(dtypes.imagef((8, 3072)), 0xC100) - __validate(dtypes.imagef((4, 2048)), 0x8200) - __validate(dtypes.imagef((4, 1024)), 0x4200) - __validate(dtypes.imagef((4, 4096)), 0x10200) - __validate(dtypes.imagef((10, 384)), 0x1900) - __validate(dtypes.imagef((24, 64)), 0x400) - __validate(dtypes.imagef((128, 12)), 0xC0) - __validate(dtypes.imagef((10, 24)), 0x200) - __validate(dtypes.imagef((1, 129)), 0x840) - __validate(dtypes.imagef((1, 32)), 0x200) - __validate(dtypes.imagef((1, 64)), 0x400) - __validate(dtypes.imagef((1, 1239)), 0x4D80) - __validate(dtypes.imagef((1, 1)), 0x40) - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 1ec95939be..ecf37db581 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -44,6 +44,66 @@ class TestImageCopy(unittest.TestCase): @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported") class TestImageDType(unittest.TestCase): + def test_image_pitch(self): + def __validate(imgdt, expected_pitch): + assert imgdt.pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{imgdt.pitch:X}, expected 0x{expected_pitch:X}" + + # Match opencl pitches for perf + __validate(dtypes.imageh((1, 201)), 0x680) + __validate(dtypes.imageh((16, 216)), 0x700) + __validate(dtypes.imageh((16, 9)), 0x80) + __validate(dtypes.imageh((48, 64)), 0x200) + __validate(dtypes.imageh((32, 128)), 0x400) + __validate(dtypes.imageh((96, 128)), 0x400) + __validate(dtypes.imageh((64, 256)), 0x840) + __validate(dtypes.imageh((64, 9)), 0x80) + __validate(dtypes.imageh((192, 256)), 0x840) + __validate(dtypes.imageh((64, 768)), 0x1840) + __validate(dtypes.imageh((256, 49)), 0x1C0) + __validate(dtypes.imageh((128, 9)), 0x80) + __validate(dtypes.imageh((16, 1024)), 0x2080) + __validate(dtypes.imageh((64, 512)), 0x1040) + __validate(dtypes.imageh((16, 512)), 0x1080) + __validate(dtypes.imageh((132, 64)), 0x200) + __validate(dtypes.imageh((4, 512)), 0x1200) + __validate(dtypes.imageh((8, 512)), 0x1100) + __validate(dtypes.imageh((128, 128)), 0x400) + __validate(dtypes.imageh((32, 512)), 0x1040) + __validate(dtypes.imageh((26, 64)), 0x200) + __validate(dtypes.imageh((32, 516)), 0x1040) + __validate(dtypes.imageh((32, 1024)), 0x2040) + __validate(dtypes.imageh((16, 2048)), 0x4080) + __validate(dtypes.imageh((8, 2048)), 0x4100) + __validate(dtypes.imageh((4, 4096)), 0x8200) + + __validate(dtypes.imagef((16, 49)), 0x380) + __validate(dtypes.imagef((16, 1024)), 0x4080) + __validate(dtypes.imagef((256, 64)), 0x400) + __validate(dtypes.imagef((64, 512)), 0x2040) + __validate(dtypes.imagef((16, 512)), 0x2080) + __validate(dtypes.imagef((132, 64)), 0x400) + __validate(dtypes.imagef((4, 512)), 0x2200) + __validate(dtypes.imagef((4, 16)), 0x200) + __validate(dtypes.imagef((2, 16)), 0x400) + __validate(dtypes.imagef((8, 512)), 0x2100) + __validate(dtypes.imagef((12, 64)), 0x400) + __validate(dtypes.imagef((3, 32)), 0x400) + __validate(dtypes.imagef((128, 128)), 0x840) + __validate(dtypes.imagef((32, 512)), 0x2040) + __validate(dtypes.imagef((8, 3072)), 0xC100) + __validate(dtypes.imagef((4, 2048)), 0x8200) + __validate(dtypes.imagef((4, 1024)), 0x4200) + __validate(dtypes.imagef((4, 4096)), 0x10200) + __validate(dtypes.imagef((10, 384)), 0x1900) + __validate(dtypes.imagef((24, 64)), 0x400) + __validate(dtypes.imagef((128, 12)), 0xC0) + __validate(dtypes.imagef((10, 24)), 0x200) + __validate(dtypes.imagef((1, 129)), 0x840) + __validate(dtypes.imagef((1, 32)), 0x200) + __validate(dtypes.imagef((1, 64)), 0x400) + __validate(dtypes.imagef((1, 1239)), 0x4D80) + __validate(dtypes.imagef((1, 1)), 0x40) + def test_image_and_back(self): data = Tensor.randn(9*27*4).realize() tst = data.numpy() diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 454d785738..7b37a321c2 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Final, ClassVar, Callable, Literal import math, struct, ctypes, functools from dataclasses import dataclass, fields -from tinygrad.helpers import getenv, prod +from tinygrad.helpers import getenv, prod, round_up, next_power2 from enum import Enum, auto class InvalidTypeMetaClass(type): @@ -101,6 +101,15 @@ class ImageDType(PtrDType): assert addrspace == AddrSpace.GLOBAL, "images can't be local" return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '') + @property + def pitch(self): + imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize)) + pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6 + align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2)) + + granularity = 128 if self.itemsize == 4 else 256 + pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0 + return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add class dtypes: @staticmethod diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5e08ef2fa4..aca12c4d3b 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -50,6 +50,7 @@ def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fs def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt def round_down(num:int, amt:int) -> int: return -round_up(-num, amt) +def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length() # cstyle div and mod def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0 def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index 6eaf1fa7b9..ec62d829c0 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -22,7 +22,7 @@ class HCQGraph(MultiGraphRunner): for (j,i), input_idx in self.input_replace.items(): x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64)) - self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable + self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, image=self.hcq_bufs[j][i].image) # Create fake buffer with variable # Allocate kernel args. kernargs_size: dict[Compiled, int] = collections.defaultdict(int) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index a27f9b98a1..deec87f582 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -10,7 +10,7 @@ from tinygrad.runtime.ops_cl import CLCompiler, CLDevice from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.renderer.nir import IR3Renderer from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing -from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC +from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC from tinygrad.runtime.support.system import System if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import @@ -25,7 +25,7 @@ def _qreg_exec(__reg, __val=0, **kwargs): return __val qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'}) -def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length() +def ctz(v): return (v & -v).bit_length() - 1 def parity(val: int): for i in range(4,1,-1): val ^= val >> (1 << i) @@ -192,9 +192,8 @@ class QCOMArgsState(HCQArgsState): super().__init__(buf, prg, bufs, vals=vals) ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size) - ubos, uavs = [b for b in bufs if b.texture_info is None], [b for b in bufs if b.texture_info is not None] + ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None] ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:]) - for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little') if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) @@ -205,8 +204,15 @@ class QCOMArgsState(HCQArgsState): for i, b in enumerate(ubos): self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=prg.buf_offs[i]) for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)]) - self.bind_sints_to_buf(*flatten([b.texture_info.desc + ([0] * 8) for b in texs]), buf=self.buf, fmt='I', offset=prg.tex_off) - self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off) + def _tex(b, ibo=False): + fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT + return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt), + qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]), + qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr), + qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0] + + self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off) + self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off) class QCOMProgram(HCQProgram): def __init__(self, dev: QCOMDevice, name: str, lib: bytes): @@ -305,28 +311,10 @@ class QCOMTextureInfo: self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo class QCOMAllocator(HCQAllocatorBase): - def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: + def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer: # Recalculate real size for texture - if options.image is not None: - imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize)) - pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6 - align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2)) - - granularity = 128 if options.image.itemsize == 4 else 256 - pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0 - pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add - size = pitch * imgh - - buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size) - - if options.image is not None: - tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT - desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh), - qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0, - *data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)] - - buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]]) - return buf + if opts.image is not None: size = opts.image.pitch* opts.image.shape[0] + return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image) def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0): with cpu_profile(prof_text, self.dev.device, is_copy=True): @@ -335,13 +323,13 @@ class QCOMAllocator(HCQAllocatorBase): src_off, dest_off = src_off+src_stride, dest_off+dest_stride def _copyin(self, dest:HCQBuffer, src:memoryview): - stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch) + stride, pitch = (dest.image.shape[1] * 4 * dest.image.itemsize, dest.image.pitch) if dest.image else (src.nbytes, src.nbytes) self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}") def _copyout(self, dest:memoryview, src:HCQBuffer): self.dev.synchronize() - stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch) + stride, pitch = (src.image.shape[1] * 4 * src.image.itemsize, src.image.pitch) if src.image else (src.size, src.size) self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY") def _as_buffer(self, src:HCQBuffer) -> memoryview: @@ -388,7 +376,7 @@ class QCOMDevice(HCQCompiled): super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal, functools.partial(QCOMComputeQueue, self), None) - def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer: + def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False, **kwargs) -> HCQBuffer: flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP if uncached: flags |= flag("KGSL_CACHEMODE", kgsl.KGSL_CACHEMODE_UNCACHED) @@ -396,15 +384,15 @@ class QCOMDevice(HCQCompiled): va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000) if fill_zeroes: ctypes.memset(va_addr, 0, size) - return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self) + return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self, **kwargs) - def _gpu_map(self, ptr:int, size:int) -> HCQBuffer: + def _gpu_map(self, ptr:int, size:int, **kwargs) -> HCQBuffer: ptr_aligned, size_aligned = (ptr & ~0xfff), round_up(size + (ptr & 0xfff), 0x1000) try: - mapinfo = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR) - return HCQBuffer(mapinfo.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mapinfo, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self) + mi = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR) + return HCQBuffer(mi.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mi, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs) except OSError as e: - if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self) + if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs) raise RuntimeError("Failed to map external pointer to GPU memory") from e def _gpu_free(self, mem:HCQBuffer): diff --git a/tinygrad/runtime/support/hcq.py b/tinygrad/runtime/support/hcq.py index f64dbd6188..a3bfbe1315 100644 --- a/tinygrad/runtime/support/hcq.py +++ b/tinygrad/runtime/support/hcq.py @@ -8,6 +8,7 @@ from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEve from tinygrad.uop.ops import sym_infer, sint, UOp from tinygrad.runtime.autogen import libc from tinygrad.runtime.support.memory import BumpAllocator +from tinygrad.dtype import ImageDType class MMIOInterface: def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt @@ -455,14 +456,14 @@ class HCQCompiled(Compiled, Generic[SignalType]): if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini() class HCQBuffer: - def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None, + def __init__(self, va_addr:sint, size:int, image:ImageDType|None=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None, owner:HCQCompiled|None=None): - self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view + self.va_addr, self.size, self.image, self.meta, self._base, self.view = va_addr, size, image, meta, _base, view self._devs, self.owner = ([owner] if owner is not None else []), owner self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer: - return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta, + return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, image=self.image, meta=self.meta, _base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None)) def cpu_view(self) -> MMIOInterface: