mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-14 17:38:06 -05:00
removing double reshapes was wrong (#12375)
This commit is contained in:
@@ -16,18 +16,7 @@ ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
|
||||
Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD, Ops.KERNEL}
|
||||
|
||||
double_reshape = PatternMatcher([
|
||||
# RESHAPE on RESHAPE is the second reshape
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.RESHAPE),), name="x"),
|
||||
lambda x: x.replace(src=(x.src[0].src[0],), tag=((x.src[0].tag or ())+(x.tag or ())) or None)),
|
||||
])
|
||||
|
||||
earliest_rewrites = double_reshape+PatternMatcher([
|
||||
# non shape changing RESHAPE is NOOP
|
||||
#(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
|
||||
#(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0].f(Ops.NOOP, tag=x.tag)),
|
||||
|
||||
earliest_rewrites = PatternMatcher([
|
||||
# just removing it works...
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user