mirror of https://github.com/tinygrad/tinygrad.git
replace DEVICE of CONST after copy folding (#8673)
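In short: ops_folding previously erased a COPY whose source was a CONST by returning the CONST unchanged, so the folded constant kept the DEVICE it was created on, and an ALU combining it with a buffer on the copy's destination device ended up with sources on different devices. The new rule instead rebuilds the constant from the COPY itself (root.const_like(x.const_arg)), so it lands on the destination device. A minimal front-end reproduction sketch, assuming the public tinygrad API at this commit ("CLANG" is the CPU backend's device name here):

from tinygrad import Tensor

a = Tensor.ones((4,)).to("CLANG")    # a CONST that reaches "CLANG" through a COPY
b = Tensor.empty(4, device="CLANG")  # a real buffer already on "CLANG"
c = a + b                            # after folding, both sources of the ADD should report "CLANG"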
test/test_schedule.py
@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
 from tinygrad.codegen.kernel import verify_ast
-from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
+from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -66,6 +66,9 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
 
+@track_rewrites(named=True)
+def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
+
 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
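The new helper lets tests run the scheduler's pre-scheduling rewrites in isolation: graph_rewrite applies remove_movement_ops followed by ops_folding to a sink UOp with a fresh ScheduleContext, skipping the full create_schedule_with_vars path, and @track_rewrites(named=True) registers the call with the rewrite tracker (presumably so the individual rewrites can be inspected, e.g. with VIZ=1).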
@@ -2228,6 +2231,13 @@ class TestCopyFolding(unittest.TestCase):
     run_schedule(check_schedule(b, 0, filter_sink=False))
     self.assertListEqual(b.tolist(), [0, 0, 0])
 
+  def test_alu_after_copy(self):
+    a = Tensor.ones((4,)).to("CLANG").lazydata
+    b = Tensor.empty(4, device="CLANG").lazydata
+    add = a+b
+    add = schedule_graph_rewrite(add)
+    assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
+
 class TestTensorUOpSpec(unittest.TestCase):
   def test_const_must_be_unmasked(self):
     a = Tensor.ones((4, 4)).pad((2, 2))
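The test works purely at the UOp level: .lazydata exposes each Tensor's UOp, schedule_graph_rewrite folds the COPY of the ones-CONST, and the newly imported all_same asserts that every source of the resulting ADD reports the same device. Without the fix, the folded CONST keeps its original device, which is exactly the mismatch this assertion guards against.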
tinygrad/engine/schedule.py
@@ -380,8 +380,8 @@ ops_folding = symbolic_simple+PatternMatcher([
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
   # reduce of const is collapsed (TODO: make this a generic rule for stride0)
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
-  # CONST doesn't need COPY
-  (UPat(Ops.COPY, src=(UPat(), UPat.cvar("x"),)), lambda x: x),
+  # COPY(CONST) creates a new CONST on the destination device
+  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
   # no COPY to same device, except clone (arg is True)
   (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
    lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
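For reference, the changed rule in isolation. A minimal sketch, assuming tinygrad.ops at this commit exposes PatternMatcher, UPat, and Ops as the test file's imports suggest, and that the unnamed first source of COPY is the copy target:

from tinygrad.ops import Ops, UPat, PatternMatcher

copy_const_rule = PatternMatcher([
  # old rule: (UPat(Ops.COPY, src=(UPat(), UPat.cvar("x"),)), lambda x: x)
  # dropped the COPY and returned the CONST as-is, keeping its original DEVICE.
  # new rule: name the COPY "root" and rebuild the constant from it, so the fresh
  # CONST inherits root's metadata, including the destination device.
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
])

The key detail is const_like being called on root (the COPY) rather than on x (the source constant): the replacement UOp answers .device with the copy's destination, which is the "replace DEVICE of CONST" in the commit title.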