mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix view_left for unsafe pad ops [pr] (#9478)
This commit is contained in:
@@ -1971,6 +1971,16 @@ class TestSwizzle(unittest.TestCase):
|
||||
t = a_reduce+b_reduce
|
||||
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
|
||||
|
||||
def test_unsafe_pad(self):
|
||||
x = Tensor.full((2,2), 1.0).contiguous()
|
||||
y = x*x.sum((1,)).reciprocal()
|
||||
t = y.pad(((0,1),None)).contiguous()
|
||||
swizzled = swizzle_rewrite(t.lazydata)
|
||||
sched = check_schedule(swizzled.sink(), 3)
|
||||
output_buffer = sched[-1].bufs[0]
|
||||
run_schedule(sched)
|
||||
self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])
|
||||
|
||||
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
|
||||
zero_pm = UPat(Ops.CONST, arg=0)
|
||||
class TestView(unittest.TestCase):
|
||||
|
||||
@@ -113,7 +113,7 @@ sym = symbolic_simple+PatternMatcher([
|
||||
|
||||
# **** UOp realization
|
||||
|
||||
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
|
||||
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS}
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrouperContext:
|
||||
|
||||
@@ -992,8 +992,11 @@ merge_views = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
|
||||
])
|
||||
|
||||
# view before elementwise ops
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# do not push masked view before unsafe pad ops
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
|
||||
lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
|
||||
# view before elementwise ops
|
||||
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
|
||||
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user