mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
add dtype.ptr() [pr] (#6839)
This commit is contained in:
@@ -19,6 +19,7 @@ class DType:
|
||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||
def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self)
|
||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||
|
||||
# dependent typing?
|
||||
@@ -29,6 +30,7 @@ class ImageDType(DType):
|
||||
local: bool = False # images are never local
|
||||
def scalar(self) -> DType: return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def ptr(self) -> Union[PtrDType, ImageDType]: return self
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
||||
|
||||
@@ -122,7 +122,7 @@ reduceop_fusor = PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp.define_global(x.dtype, ctx.bufs.index(x.arg)))])
|
||||
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype.ptr(), (), ctx.bufs.index(x.arg)))])
|
||||
|
||||
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
|
||||
if not AST_REWRITE: return sink
|
||||
|
||||
@@ -246,8 +246,6 @@ class UOp(MathTrait):
|
||||
@functools.lru_cache(None)
|
||||
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@staticmethod
|
||||
def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
|
||||
Reference in New Issue
Block a user