mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
early src delete (#3996)
* early src delete * fix bad test * fix test_linearizer
This commit is contained in:
@@ -190,7 +190,10 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.rand(tc.dims[1], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[0], dtype=tc.dtype_in)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
r = a.matmul(b, acc_dtype=tc.dtype_out)
|
||||
realized_ast, _ = helper_realized_ast(r)
|
||||
sched = create_schedule([r.lazydata])
|
||||
realized_ast = sched[-1].ast[0]
|
||||
run_schedule(sched)
|
||||
out = r.numpy()
|
||||
k = Linearizer(realized_ast)
|
||||
k.apply_tensor_cores(1)
|
||||
k.linearize()
|
||||
@@ -198,7 +201,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len([x for x in k.applied_opts if x.op is OptOps.TC]) == 1, "tensor core opt not included"
|
||||
np_c = np_a @ np_b
|
||||
(tc_atol, tc_rtol) = (1e-2, 1e-3) if tc.dtype_out == dtypes.half else (5e-3, 1e-4)
|
||||
np.testing.assert_allclose(np_c, r.numpy(), atol=tc_atol, rtol=tc_rtol)
|
||||
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
|
||||
|
||||
def test_limit_dims_to_max_5d_global(self):
|
||||
t = Tensor.empty(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
|
||||
@@ -228,11 +231,13 @@ class TestLinearizer(unittest.TestCase):
|
||||
a.assign(b.where(2, a))
|
||||
sched = create_schedule([a.lazydata])
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(*sched[-1].ast)
|
||||
sched_copy = sched[:]
|
||||
run_schedule(sched)
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
lin = Linearizer(*sched_copy[-1].ast)
|
||||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
def test_simplify_uop(self):
|
||||
def helper_test_simplify(uop, dtype, vin, arg=None):
|
||||
|
||||
Reference in New Issue
Block a user