s/self.shape_len - self.upcasted/self.first_upcast (#5802)

missed the occurrences spelled with spaces (self.shape_len - self.upcasted).
[run_process_replay]
Author: chenyu
Date: 2024-07-29 18:23:42 -04:00 (committed by GitHub)
parent 1a19751902
commit 22e7289fe0

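For context: the rename is only behavior-preserving if first_upcast is defined on Kernel as exactly the expression it replaces. A minimal standalone sketch under that assumption (KernelSketch is a hypothetical stand-in, not tinygrad's actual Kernel):

class KernelSketch:
  def __init__(self, shape_len: int, upcasted: int):
    self.shape_len, self.upcasted = shape_len, upcasted

  @property
  def first_upcast(self) -> int:
    # upcasted axes occupy the tail of the shape, so the first one
    # sits at index shape_len - upcasted
    return self.shape_len - self.upcasted

k = KernelSketch(shape_len=5, upcasted=2)
assert k.first_upcast == 3  # axes 3 and 4 are the upcasted tail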

@@ -468,7 +468,7 @@ class Kernel:
       self.reshape_and_permute(None, tuple(permute))
     elif opt.op is OptOps.PADTO:
       check(not self.vars, "does not work with symbolic shape")
-      check(axis < self.shape_len - self.upcasted, "cannot pad upcasted")
+      check(axis < self.first_upcast, "cannot pad upcasted")
       # ok to pad SUM if all parent ops have f(0) = 0
       if self.first_reduce <= axis:
         check((r:=cast(LazyOp, self.reduceop)).op is ReduceOps.SUM and \
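
Under that assumed definition, the PADTO bound above keeps padding off the upcasted tail; an illustrative check with made-up sizes:

shape_len, upcasted = 5, 2           # hypothetical kernel dims
first_upcast = shape_len - upcasted  # == 3; axes 3 and 4 are upcasted
paddable = [axis for axis in range(shape_len) if axis < first_upcast]
assert paddable == [0, 1, 2]         # the upcasted tail axes cannot be padded
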
@@ -551,7 +551,7 @@ class Kernel:
     # **** below this line need to be optional and benchmarked ****
 
     # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
-    # to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
+    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
     # expression and run test/test_ops.py with IMAGE=2
     # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
     # this can be made much smarter
@@ -561,7 +561,7 @@ class Kernel:
       # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
       # for now skip upcasting here if there is a symbolic axis
       if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
-        prod(self.full_shape[self.shape_len - self.upcasted:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
+        prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
         if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
         to_upcast.append(axis)
     for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
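
In the heuristic above, the slice self.full_shape[self.first_upcast:] is the run of already-upcasted axes, whose product counts against the 7 * 7 budget. A small illustration with made-up shapes:

from math import prod

full_shape = (8, 8, 16, 4, 4)       # hypothetical; last two axes upcasted
first_upcast = len(full_shape) - 2  # upcasted == 2
assert prod(full_shape[first_upcast:]) == 16          # 4 * 4 already upcast
assert prod(full_shape[first_upcast:]) * 3 <= 7 * 7   # a masked size-3 axis still fits
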
@@ -712,7 +712,7 @@ class Kernel:
       if self.group_for_reduces:
         start = LazyOp(op.op, tuple(fixup_ast(x, apply_to_st) for x in op.src), arg)
         local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
-          (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
+          (1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
         local_buffer = MemBuffer(-1, start.dtype, ShapeTracker.from_shape(local_shape))
         local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
         local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
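
A sanity check on the local_shape arithmetic: its four segments should total shape_len. Assuming first_reduce == global_dims + local_dims (taken as given here) and first_upcast == shape_len - upcasted, the group_for_reduces and first_reduce terms cancel. A worked example with made-up axis counts:

global_dims, local_dims, group_for_reduces, upcasted, shape_len = 2, 1, 1, 2, 7
first_reduce = global_dims + local_dims  # reduce axes start after global+local
first_upcast = shape_len - upcasted      # upcasted axes form the tail
segments = (
  global_dims,                                      # (1,) * global_dims
  local_dims + group_for_reduces,                   # kept local/group axes
  first_upcast - group_for_reduces - first_reduce,  # (1,) * remaining reduce axes
  upcasted,                                         # upcasted axis sizes
)
assert sum(segments) == shape_len  # 2 + 2 + 1 + 2 == 7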