removing double reshapes was wrong (#12375)

This commit is contained in:
George Hotz
2025-10-01 09:25:35 +08:00
committed by GitHub
parent 969a1b35ca
commit a128fa0f8a

View File

@@ -16,18 +16,7 @@ ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER,
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
double_reshape = PatternMatcher([
# RESHAPE on RESHAPE is the second reshape
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"),
lambda x: x.replace(src=(x.src[0].src[0],), tag=((x.src[0].tag or ())+(x.tag or ())) or None)),
])
earliest_rewrites = double_reshape+PatternMatcher([
# non shape changing RESHAPE is NOOP
#(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
#(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0].f(Ops.NOOP, tag=x.tag)),
earliest_rewrites = PatternMatcher([
# just removing it works...
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),