mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 13:58:00 -05:00
remove a TODO prod(k.full_shape[k.first_upcast:]) (#11191)
IMAGE=2 test/test_ops.py works now
This commit is contained in:
@@ -53,9 +53,6 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
-# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
|
||||
-# to trigger the above bug, remove prod(k.full_shape[k.first_upcast:]) from the below
|
||||
-# expression and run test/test_ops.py with IMAGE=2
|
||||
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
|
||||
# this can be made much smarter
|
||||
to_upcast: list[int] = []
|
||||
@@ -64,7 +61,7 @@ def hand_coded_optimizations(k:Kernel) -> list[Opt]:
|
||||
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
|
||||
# for now skip upcasting here if there is a symbolic axis
|
||||
if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
|
||||
-prod(k.full_shape[k.first_upcast:]) * prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||
+prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
|
||||
if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
|
||||
to_upcast.append(axis)
|
||||
for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
|
||||
|
||||
Reference in New Issue
Block a user