mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
lil symbolic changes [pr] (#6841)
This commit is contained in:
@@ -668,8 +668,8 @@ reducer = PatternMatcher([
|
||||
(UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
|
||||
])
|
||||
|
||||
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
|
||||
lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
|
||||
no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
|
||||
name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
|
||||
|
||||
# *** uop graph ***
|
||||
|
||||
|
||||
@@ -236,6 +236,7 @@ class UOp(MathTrait):
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@staticmethod
|
||||
def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
|
||||
@@ -687,9 +688,9 @@ spec = PatternMatcher([
|
||||
])
|
||||
|
||||
def type_verify(uops:List[UOp]):
|
||||
for u in uops:
|
||||
for i,u in enumerate(uops):
|
||||
chk = cast(bool, spec.rewrite(u))
|
||||
assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
|
||||
simple_pm = PatternMatcher([
|
||||
# bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
|
||||
|
||||
@@ -28,7 +28,7 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
|
||||
# always merge 1
|
||||
if s == 1: continue
|
||||
# merge last dim with this dim if merging or strides matched
|
||||
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
|
||||
if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
|
||||
else: ret.append((s, st, s if st != 0 else 0))
|
||||
# merge this dim to next dim if size is 1
|
||||
merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
|
||||
@@ -300,7 +300,7 @@ class View:
|
||||
# TODO: this <= and != is for symbolic!?
|
||||
while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, 0)) > 0:
|
||||
strides.append(new_stride)
|
||||
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
||||
if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
|
||||
if acc != merged_dim: break
|
||||
else:
|
||||
strides += [0,] * (len(new_shape) - len(strides))
|
||||
|
||||
Reference in New Issue
Block a user