lil symbolic changes [pr] (#6841)

This commit is contained in:
George Hotz
2024-10-02 10:56:22 +08:00
committed by GitHub
parent 100ce7a684
commit 567e10efcb
3 changed files with 7 additions and 6 deletions

View File

@@ -668,8 +668,8 @@ reducer = PatternMatcher([
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
])
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
# *** uop graph ***

View File

@@ -236,6 +236,7 @@ class UOp(MathTrait):
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
@functools.lru_cache(None)
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@staticmethod
def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
@@ -687,9 +688,9 @@ spec = PatternMatcher([
])
def type_verify(uops:List[UOp]):
for u in uops:
for i,u in enumerate(uops):
chk = cast(bool, spec.rewrite(u))
assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
simple_pm = PatternMatcher([
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly

View File

@@ -28,7 +28,7 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
# always merge 1
if s == 1: continue
# merge last dim with this dim if merging or strides matched
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
else: ret.append((s, st, s if st != 0 else 0))
# merge this dim to next dim if size is 1
merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
@@ -300,7 +300,7 @@ class View:
# TODO: this <= and != is for symbolic!?
while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, 0)) > 0:
strides.append(new_stride)
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
if acc != merged_dim: break
else:
strides += [0,] * (len(new_shape) - len(strides))