mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 15:45:27 -05:00
replace DEVICE of CONST after copy folding (#8673)
This commit is contained in:
@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
|
||||
from tinygrad.codegen.kernel import verify_ast
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, ops_folding
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
@@ -66,6 +66,9 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
|
||||
|
||||
@track_rewrites(named=True)
|
||||
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+ops_folding, ScheduleContext())
|
||||
|
||||
class TestSchedule(unittest.TestCase):
|
||||
def test_basic_binop_fusion(self):
|
||||
a = Tensor.empty(10)
|
||||
@@ -2228,6 +2231,13 @@ class TestCopyFolding(unittest.TestCase):
|
||||
run_schedule(check_schedule(b, 0, filter_sink=False))
|
||||
self.assertListEqual(b.tolist(), [0, 0, 0])
|
||||
|
||||
def test_alu_after_copy(self):
|
||||
a = Tensor.ones((4,)).to("CLANG").lazydata
|
||||
b = Tensor.empty(4, device="CLANG").lazydata
|
||||
add = a+b
|
||||
add = schedule_graph_rewrite(add)
|
||||
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
|
||||
|
||||
class TestTensorUOpSpec(unittest.TestCase):
|
||||
def test_const_must_be_unmasked(self):
|
||||
a = Tensor.ones((4, 4)).pad((2, 2))
|
||||
|
||||
Reference in New Issue
Block a user