add size of the buffer to the ptr dtype (#8322)

This commit is contained in:
George Hotz
2024-12-18 12:46:35 -08:00
committed by GitHub
parent 52243b258c
commit 6608ba316d
4 changed files with 13 additions and 10 deletions

View File

@@ -356,7 +356,7 @@ class TestEqStrDType(unittest.TestCase):
def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")
class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)

View File

@@ -49,7 +49,7 @@ def fold_expanded(ex, buf):
rootsrc[0] if isinstance(rootsrc, tuple) else None)
else:
# for non image, we upcast the index pointer
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
# generate the folded new_srcs
if is_load:
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv
from tinygrad.helpers import getenv, prod
ConstType = Union[float, int, bool]
@@ -38,7 +38,8 @@ class DType(metaclass=DTypeMetaClass):
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
def ptr(self, size=-1, local=False) -> PtrDType:
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
@dataclass(frozen=True, eq=False)
@@ -46,6 +47,7 @@ class PtrDType(DType):
_base: DType
local: bool
v: int
size: int = -1 # -1 is unlimited size
@property
def base(self): return self._base
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
@@ -53,15 +55,16 @@ class PtrDType(DType):
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
def __repr__(self):
return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
@dataclass(frozen=True, eq=False)
class ImageDType(PtrDType):
shape: Tuple[int, ...] = () # shape of the Image
def ptr(self, local=False) -> PtrDType:
def ptr(self, size=-1, local=False) -> PtrDType:
assert not local, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@@ -131,9 +134,9 @@ class dtypes:
# NOTE: these are image dtypes
@staticmethod
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
@staticmethod
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
default_float: ClassVar[DType] = float32
default_int: ClassVar[DType] = int32

View File

@@ -152,7 +152,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
assert x.arg[0] != -1, "fake -1 BUFFERS should not make it here"
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(), (), len(ctx.bufs)-1)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.arg[1]), (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: