refactor ast ops dtype access [run_process_replay] (#6093)

* refactor ast ops dtype access [run_process_replay]

* fix assert message
This commit is contained in:
qazal
2024-08-16 00:13:33 +08:00
committed by GitHub
parent 9ef82e1f2b
commit 11d62668a3
3 changed files with 10 additions and 10 deletions

View File

@@ -279,16 +279,16 @@ class Kernel:
def _create_tc_opts(self, reduceop:LazyOp, tc:TensorCore, axis:int, opt_level:int) -> Optional[TensorCoreOptions]:
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].arg == tc.dtype_out): return None
if has_cast and not(reduceop.src[0].op is UnaryOps.CAST and reduceop.src[0].dtype == tc.dtype_out): return None
mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0]
if mul_op.op is not BinaryOps.MUL: return None
def buf_index(src: LazyOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
if src.op is BufferOps.LOAD and src.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
try:
if opt_level >= 1 and src.op is UnaryOps.CAST and src.arg == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
if opt_level >= 1 and src.op is UnaryOps.CAST and src.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.src[0].arg))
except ValueError: return None
return None
if (buf0:=buf_index(mul_op.src[0])) is None or (buf1:=buf_index(mul_op.src[1])) is None: return None
@@ -746,7 +746,7 @@ class Kernel:
# group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
mem_bytes = sum(max(x.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
itertools.groupby([x for x in self.ast.parents if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
key=lambda x: (x.op, x.arg.idx)))
return Program(ansiname, src, self.opts.device, self.uops.uops, mem_estimate=mem_bytes,

View File

@@ -160,20 +160,20 @@ class IndependentLowerer:
def _to_uop(self, x:LazyOp) -> UOp:
if x.op in BufferOps:
idx, valid = st_to_uops(x.arg.st, self.ridxs if x.op is BufferOps.LOAD and x.arg.idx == -1 else self.idxs,
x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.arg.dtype)
x.dtype.base if isinstance(x.dtype, ImageDType) and (not isinstance(x.arg, MemBuffer) or x.arg.idx == -1) else x.dtype)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is BufferOps.CONST:
dtype = x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype
dtype = x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype
return valid.where(UOp.const(dtype, x.arg.val), UOp.const(dtype, 0))
if x.arg.idx < 0:
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.arg.dtype.base if isinstance(x.arg.dtype, ImageDType) else x.arg.dtype),
buf = UOp(UOps.DEFINE_LOCAL, PtrDType(x.dtype.base if isinstance(x.dtype, ImageDType) else x.dtype),
arg=(f"temp{-x.arg.idx}", x.arg.st.real_size()))
else:
buf = UOp(UOps.DEFINE_GLOBAL, x.arg.dtype if isinstance(x.arg.dtype, ImageDType) else PtrDType(x.arg.dtype), (), x.arg.idx)
buf = UOp(UOps.DEFINE_GLOBAL, x.dtype if isinstance(x.dtype, ImageDType) else PtrDType(x.dtype), (), x.arg.idx)
if x.op is BufferOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[0]),)),) if len(x.src) else ()
load_dtype = x.arg.dtype.scalar()
load_dtype = x.dtype.scalar()
if idx.dtype == dtypes.int.vec(3):
# this should all simplify if there's consts for id4. if not, w/e
idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]

View File

@@ -214,7 +214,7 @@ def type_verify(uops):
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.CONST:
assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg