diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 4864b707a1..ccdf181e7a 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -452,11 +452,11 @@ class Kernel: def fixup_ast(op:UOp) -> UOp: ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821 if op.op in GroupOp.Buffer and op in self.bufs: - st_uop = self.sts[self.bufs.index(op)].to_uop() + st = self.sts[self.bufs.index(op)] # NOTE: if CONST got masked after applying opts, we create a new VALID - if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.view(st_uop.arg).valid() + if op.op is Ops.CONST and any(v.mask is not None for v in st.views): return op.view(st).valid() # otherwise we just replace the VIEW source - return ret.replace(src=(ret.src[0].replace(arg=st_uop.arg),)+ret.src[1:]) + return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:]) if op.op is Ops.SINK: return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override, self.local_dims, self.upcasted, self.dont_use_locals)) @@ -515,14 +515,14 @@ class Kernel: tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \ for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \ (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) - st_uop = ShapeTracker.from_shape(local_shape).to_uop() - local_size = st_uop.arg.real_size() + st = ShapeTracker.from_shape(local_shape) + local_size = st.real_size() local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}") - local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), ret))) + local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), ret))) grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes)) if op is self.reduceops[-1]: return grouped_reduce - st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop() - return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), grouped_reduce))) + st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])) + return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), grouped_reduce))) return ret fixed_ast = fixup_ast(self.ast)