mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 14:28:09 -05:00
cleanup kernel.py ShapeTracker replacement [pr] (#10573)
This commit is contained in:
@@ -452,11 +452,11 @@ class Kernel:
|
||||
def fixup_ast(op:UOp) -> UOp:
|
||||
ret = op.replace(src=tuple(fixup_ast(x) for x in op.src)) # noqa: F821
|
||||
if op.op in GroupOp.Buffer and op in self.bufs:
|
||||
st_uop = self.sts[self.bufs.index(op)].to_uop()
|
||||
st = self.sts[self.bufs.index(op)]
|
||||
# NOTE: if CONST got masked after applying opts, we create a new VALID
|
||||
if op.op is Ops.CONST and any(v.mask is not None for v in unwrap(st_uop.st).views): return op.view(st_uop.arg).valid()
|
||||
if op.op is Ops.CONST and any(v.mask is not None for v in st.views): return op.view(st).valid()
|
||||
# otherwise we just replace the VIEW source
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st_uop.arg),)+ret.src[1:])
|
||||
return ret.replace(src=(ret.src[0].replace(arg=st),)+ret.src[1:])
|
||||
if op.op is Ops.SINK:
|
||||
return ret.replace(arg = KernelInfo(to_function_name(self.name) if name_override is None else name_override,
|
||||
self.local_dims, self.upcasted, self.dont_use_locals))
|
||||
@@ -515,14 +515,14 @@ class Kernel:
|
||||
tuple([self.full_shape[i] if self.sts[reduce_idx].shape[i] != self.sts[reduce_idx+1].shape[i] else 1 \
|
||||
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
||||
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||
local_size = st_uop.arg.real_size()
|
||||
st = ShapeTracker.from_shape(local_shape)
|
||||
local_size = st.real_size()
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), f"temp{self.reduceops.index(op)}")
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), ret)))
|
||||
local_load = UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), ret)))
|
||||
grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
st_uop = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)])).to_uop()
|
||||
return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st_uop.arg), UOp.store(local_buffer.view(st_uop.arg), grouped_reduce)))
|
||||
st = ShapeTracker.from_shape(tuple([1 if i in grouped_axes else a for i,a in enumerate(local_shape)]))
|
||||
return UOp(Ops.LOAD, op.dtype, (local_buffer.view(st), UOp.store(local_buffer.view(st), grouped_reduce)))
|
||||
|
||||
return ret
|
||||
fixed_ast = fixup_ast(self.ast)
|
||||
|
||||
Reference in New Issue
Block a user