diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3b6828688d..ae30efe400 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -200,7 +200,7 @@ to_si = PatternMatcher([ # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # CONST(VIEW) becomes VALID too, TODO: doesn't have to - (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), + (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel @@ -444,8 +444,8 @@ do_realize = PatternMatcher([ def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp): assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}" - ctx.var_vals[ret:=var.replace(src=())] = val.const_arg - return ret.valid(unwrap(bind.st)) + ctx.var_vals[var.replace(src=())] = val.const_arg + return var def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7f30ebe5a0..731244918b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1323,14 +1323,14 @@ merge_views = PatternMatcher([ (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), # merge unmasked const views - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), + (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) # push VIEW to parents view_left = merge_views+PatternMatcher([ # VIEW(CONST) becomes VALID - (UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)), # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),