diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index eef03c56e6..5f077c5ccb 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -668,8 +668,8 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
 ])
 
-no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE), name="x"),
-  lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
+no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
+  name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
 
 # *** uop graph ***
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index d07b243fe8..354f78dd75 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -236,6 +236,7 @@ class UOp(MathTrait):
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
   @staticmethod
+  @functools.lru_cache(None)
   def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @staticmethod
   def define_global(dtype:DType, arg): return UOp(UOps.DEFINE_GLOBAL, dtype if isinstance(dtype, ImageDType) else PtrDType(dtype), (), arg)
@@ -687,9 +688,9 @@ spec = PatternMatcher([
 ])
 
 def type_verify(uops:List[UOp]):
-  for u in uops:
+  for i,u in enumerate(uops):
     chk = cast(bool, spec.rewrite(u))
-    assert chk is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
+    assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
 
 simple_pm = PatternMatcher([
   # bool MUL is AND, ADD/MAX is OR. prevents other rules to rewrite bool ADD/MUL incorrectly
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 7d073b4bd4..a72c36bcb4 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -28,7 +28,7 @@ def _merge_dims(shape:Tuple[int, ...], strides:Tuple[int, ...], mask:Optional[Tu
     # always merge 1
     if s == 1: continue
     # merge last dim with this dim if merging or strides matched
-    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st else 0)
+    if merging or last_st == s * st: ret[-1] = (last_s * s, st, (s if merging else last_pre_expand_s * s) if st != 0 else 0)
     else: ret.append((s, st, s if st != 0 else 0))
     # merge this dim to next dim if size is 1
     merging = (mask[i][1] - mask[i][0] == 1) if mask is not None else s == 1
@@ -300,7 +300,7 @@
       # TODO: this <= and != is for symbolic!?
      while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, 0)) > 0:
        strides.append(new_stride)
-        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
+        if new_dim != 1: new_stride *= (new_dim if (acc := acc * new_dim) < real_dim else 0)
      if acc != merged_dim: break
    else: strides += [0,] * (len(new_shape) - len(strides))
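
Side note on the ops.py change (not part of the patch): functools.lru_cache(None) keys on the call arguments, so after this change repeated UOp.define_var calls with the same (name, dtype, min_val, max_val) return the cached UOp object itself, deduplicating identical DEFINE_VARs by identity. A minimal sketch of that behavior, assuming a tinygrad tree where these imports resolve as shown:

    # sketch only: shows the identity caching added by @functools.lru_cache(None) on define_var
    from tinygrad.dtype import dtypes
    from tinygrad.ops import UOp

    a = UOp.define_var("i", dtypes.int32, 0, 10)
    b = UOp.define_var("i", dtypes.int32, 0, 10)           # same args -> cache hit
    print(a is b)                                          # True: the very same UOp object
    print(a is UOp.define_var("j", dtypes.int32, 0, 10))   # False: different name, new cache entry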