mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
refactor ast ops dtype access [run_process_replay] (#6093)
* refactor ast ops dtype access [run_process_replay]
* fix assert message
This commit is contained in:
@@ -279,16 +279,16 @@ class Kernel:
|
||||
|
||||
def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
|
||||
has_cast = tc.dtype_in != tc.dtype_out
|
||||
if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
|
||||
if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
|
||||
|
||||
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
|
||||
if mul_op.op is not BinaryOps.MUL: return None
|
||||
|
||||
def buf_index(src: LazyOp) -> Optional[int]:
|
||||
# TODO: apply tc even if the sources are not from LOAD
|
||||
if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
|
||||
if src.op is BufferOps.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
|
||||
try:
|
||||
if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
if opt_level >= 1 and src.op is UnaryOps.CAST and src.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
|
||||
except ValueError: return None
|
||||
return None
|
||||
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
|
||||
@@ -746,7 +746,7 @@ class Kernel:
|
||||
|
||||
# group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
|
||||
mem_bytes = sum(max(x.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
|
||||
itertools.groupby([x for x in self.ast.parents if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
|
||||
key=lambda x: (x.op, x.arg.idx)))
|
||||
return Program(ansiname, src, self.opts.device, self.uops.uops, mem_estimate=mem_bytes,
|
||||
|
||||
@@ -160,20 +160,20 @@ class IndependentLowerer:
|
||||
def _to_uop(self, x:LazyOp) -> UOp:
|
||||
if x.op in BufferOps:
|
||||
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
|
||||
x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
|
||||
x.dtype.base if isinstance(x.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.dtype)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is BufferOps.CONST:
|
||||
dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
|
||||
dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
|
||||
return valid.where(UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
|
||||
if x.arg.idx < 0:
|
||||
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype),
|
||||
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype),
|
||||
arg=(f"temp{-x.arg.idx}", x.arg.st.real_size()))
|
||||
else:
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), x.arg.idx)
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, x.dtype if isinstance(x.dtype, ImageDType) else PtrDType(x.dtype), (), x.arg.idx)
|
||||
if x.op is BufferOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
|
||||
load_dtype = x.arg.dtype.scalar()
|
||||
load_dtype = x.dtype.scalar()
|
||||
if idx.dtype == dtypes.int.vec(3):
|
||||
# this should all simplify if there's consts for id4. if not, w/e
|
||||
idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
|
||||
|
||||
@@ -214,7 +214,7 @@ def type_verify(uops):
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
||||
if uop is UOps.CONST:
|
||||
assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
|
||||
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
|
||||
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
||||
|
||||
Reference in New Issue
Block a user