mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add size of the buffer to the ptr dtype (#8322)
This commit is contained in:
@@ -356,7 +356,7 @@ class TestEqStrDType(unittest.TestCase):
|
||||
def test_strs(self):
|
||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||
self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
|
||||
self.assertEqual(str(dtypes.float32.ptr(16)), "dtypes.float.ptr(16)")
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
|
||||
@@ -49,7 +49,7 @@ def fold_expanded(ex, buf):
|
||||
rootsrc[0] if isinstance(rootsrc, tuple) else None)
|
||||
else:
|
||||
# for non image, we upcast the index pointer
|
||||
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(new_src[0].dtype.local))
|
||||
new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
|
||||
# generate the folded new_srcs
|
||||
if is_load:
|
||||
new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable, Literal
|
||||
import math, struct, ctypes, functools
|
||||
from dataclasses import dataclass, fields
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv, prod
|
||||
|
||||
ConstType = Union[float, int, bool]
|
||||
|
||||
@@ -38,7 +38,8 @@ class DType(metaclass=DTypeMetaClass):
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self == dtypes.void: return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz, self)
|
||||
def ptr(self, local=False) -> PtrDType: return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1)
|
||||
def ptr(self, size=-1, local=False) -> PtrDType:
|
||||
return PtrDType(self.priority, self.itemsize, self.name, self.fmt, self.count, None, self, local, 1, size)
|
||||
def scalar(self) -> DType: return self._scalar if self._scalar is not None else self
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
@@ -46,6 +47,7 @@ class PtrDType(DType):
|
||||
_base: DType
|
||||
local: bool
|
||||
v: int
|
||||
size: int = -1 # -1 is unlimited size
|
||||
@property
|
||||
def base(self): return self._base
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
@@ -53,15 +55,16 @@ class PtrDType(DType):
|
||||
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
|
||||
if sz == 1: return self # sz=1 is a scalar
|
||||
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
|
||||
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
def ptr(self, size=-1, local=False): raise RuntimeError("can't make a pointer from a pointer")
|
||||
@property
|
||||
def vcount(self): return self.v
|
||||
def __repr__(self): return f"{self.base.__repr__()}.ptr({'local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
def __repr__(self):
|
||||
return f"{self.base.__repr__()}.ptr({self.size}{', local=True' if self.local else ''})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class ImageDType(PtrDType):
|
||||
shape: Tuple[int, ...] = () # shape of the Image
|
||||
def ptr(self, local=False) -> PtrDType:
|
||||
def ptr(self, size=-1, local=False) -> PtrDType:
|
||||
assert not local, "images can't be local"
|
||||
return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
|
||||
@@ -131,9 +134,9 @@ class dtypes:
|
||||
|
||||
# NOTE: these are image dtypes
|
||||
@staticmethod
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, shp)
|
||||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, None, dtypes.float32, False, 1, prod(shp), shp)
|
||||
@staticmethod
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, shp)
|
||||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, None, dtypes.float32, False, 1, prod(shp), shp)
|
||||
|
||||
default_float: ClassVar[DType] = float32
|
||||
default_int: ClassVar[DType] = int32
|
||||
|
||||
@@ -152,7 +152,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> Optional[UOp]:
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
assert x.arg[0] != -1, "fake -1 BUFFERS should not make it here"
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(), (), len(ctx.bufs)-1)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.arg[1]), (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
|
||||
Reference in New Issue
Block a user