replace DEVICE of CONST after copy folding (#8673)

Author: qazal
Date:   2025-01-19 11:33:39 -05:00 (committed by GitHub)
Parent: d957a4f108
Commit: 2faf8774fe
2 changed files with 14 additions and 4 deletions

test/test_schedule.py

@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
 from tinygrad.codegen.kernel import verify_ast
-from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
+from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis

@@ -66,6 +66,9 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

+@track_rewrites(named=True)
+def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
+
 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)

@@ -2228,6 +2231,13 @@ class TestCopyFolding(unittest.TestCase):
     run_schedule(check_schedule(b, 0, filter_sink=False))
     self.assertListEqual(b.tolist(), [0, 0, 0])

+  def test_alu_after_copy(self):
+    a = Tensor.ones((4,)).to("CLANG").lazydata
+    b = Tensor.empty(4, device="CLANG").lazydata
+    add = a+b
+    add = schedule_graph_rewrite(add)
+    assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
+
 class TestTensorUOpSpec(unittest.TestCase):
   def test_const_must_be_unmasked(self):
     a = Tensor.ones((4, 4)).pad((2, 2))
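
Pulled together, the new test amounts to the following standalone repro. A minimal sketch, assuming a tinygrad checkout at this commit: every import appears in the diff above, and schedule_graph_rewrite is the helper the test file adds.

from tinygrad import Tensor
from tinygrad.ops import UOp, graph_rewrite, track_rewrites
from tinygrad.engine.schedule import ScheduleContext, remove_movement_ops, ops_folding
from tinygrad.helpers import all_same

@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())

a = Tensor.ones((4,)).to("CLANG").lazydata    # COPY of a CONST onto CLANG
b = Tensor.empty(4, device="CLANG").lazydata  # buffer created on CLANG
add = schedule_graph_rewrite(a+b)             # run movement-op removal + ops_folding
# before this commit the folded CONST kept the source device and this failed;
# now every source of the ADD reports the destination device
assert all_same([x.device for x in add.src]), [x.device for x in add.src]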

tinygrad/engine/schedule.py

@@ -380,8 +380,8 @@ ops_folding = symbolic_simple+PatternMatcher([
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
   # reduce of const is collapsed (TODO: make this a generic rule for stride0)
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
-  # CONST doesn't need COPY
-  (UPat(Ops.COPY, src=(UPat(), UPat.cvar("x"),)), lambda x: x),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
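
Why rebuilding the CONST fixes the device: the matched CONST x still reports the source device, while the COPY node (bound as root) reports the destination device, so root.const_like(x.const_arg) yields a fresh CONST that lands where the copy was headed. An annotated sketch of the two rules; the names old_rule/new_rule are illustrative, and the src layout is what the UPat above implies (first src the target device, second the CONST):

# matched shape: COPY(<dest device>, CONST)
#   root.device  -> destination device (the COPY's device)
#   x.device     -> source device (where the CONST was created)
#   x.const_arg  -> the scalar value the CONST holds
old_rule = lambda x: x                                  # keeps x, and with it the source device
new_rule = lambda root,x: root.const_like(x.const_arg)  # fresh CONST with root's dtype and device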