add dtype.ptr() [pr] (#6839)

This commit is contained in:
qazal
2024-10-02 15:03:05 +08:00
committed by GitHub
parent be12409b51
commit 29363fb85e
3 changed files with 3 additions and 3 deletions

View File

@@ -19,6 +19,7 @@ class DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
# Return the pointer form of this dtype by wrapping it in a PtrDType.
# The return annotation also admits ImageDType because the ImageDType override of ptr() returns itself.
def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self)
# Return the scalar dtype behind a vectorized one. When count > 1, the last len(str(count))
# characters are cut from the name (the size suffix appended by vec()) and the remainder is
# looked up in DTYPES_DICT; a dtype whose count is not greater than 1 is returned unchanged.
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# dependent typing?
@@ -29,6 +30,7 @@ class ImageDType(DType):
local: bool = False # images are never local
# The scalar dtype of an image dtype is its base dtype.
def scalar(self) -> DType: return self.base
# Vectorizing is delegated to the base dtype: the result is whatever self.base.vec(sz) returns.
def vec(self, sz:int): return self.base.vec(sz)
# ptr() does not wrap an image dtype: the image dtype itself is returned unchanged.
def ptr(self) -> Union[PtrDType, ImageDType]: return self
# Render as "dtypes.<name>(<shape>)".
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# @dataclass(frozen=True, init=False, repr=False, eq=False)

View File

@@ -122,7 +122,7 @@ reduceop_fusor = PatternMatcher([
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# NOTE(review): the next two assignments appear to be the old and new side of a single changed line
# in this diff view (the +/- markers are not shown) — confirm against the commit.
# Old form: builds the DEFINE_GLOBAL node through the UOp.define_global helper.
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp.define_global(x.dtype, ctx.bufs.index(x.arg)))])
# New form: rewrites each UOps.BUFFER node "x" into a DEFINE_GLOBAL UOp constructed directly, with
# dtype x.dtype.ptr(), no sources, and arg set to the position of x.arg in ctx.bufs.
enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype.ptr(), (), ctx.bufs.index(x.arg)))])
def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
if not AST_REWRITE: return sink

View File

@@ -246,8 +246,6 @@ class UOp(MathTrait):
@functools.lru_cache(None)
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@staticmethod
# Build a DEFINE_GLOBAL UOp with no sources and the given arg. An ImageDType is used as the node
# dtype as-is; any other dtype is wrapped in PtrDType first.
# NOTE(review): this helper looks like one of the lines deleted by this commit — confirm against the diff markers.
def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
@staticmethod
# Build a RANGE UOp of the given dtype. Its two sources are constants of that dtype holding
# start and end, and its arg is the one-element tuple (idx,).
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
# Build a REDUCE UOp that keeps this UOp's dtype. Its sources are this UOp followed by the given
# rng UOps, and op is passed as the fourth positional argument (the arg slot, judging by the
# keyword-form constructor calls elsewhere in this hunk).
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)