From 9eeba968cdcdee4a1cf6af5b8c555d31e86e8380 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 21 Nov 2023 12:02:31 -0500
Subject: [PATCH] fix the variable arg order (#2382)

---
 test/test_symbolic_jit.py      | 16 +++++++++++++++-
 test/test_symbolic_ops.py      | 15 ++++++++++++++-
 tinygrad/codegen/linearizer.py |  2 +-
 tinygrad/ops.py                |  5 +++--
 tinygrad/runtime/ops_metal.py  |  2 +-
 5 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/test/test_symbolic_jit.py b/test/test_symbolic_jit.py
index ac987762e3..bff3c51d33 100644
--- a/test/test_symbolic_jit.py
+++ b/test/test_symbolic_jit.py
@@ -122,7 +122,7 @@ class TestSymbolicJit(unittest.TestCase):
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert len(jf.jit_cache) == 1

-  def test_two_vars_plus1(self):
+  def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
     jf = TinyJit(f)
     for i in range(1, 5):
@@ -136,6 +136,20 @@ class TestSymbolicJit(unittest.TestCase):
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert len(jf.jit_cache) == 1

+  def test_two_vars_plus1_ji(self):
+    def f(a, b): return (a@b+1).realize()
+    jf = TinyJit(f)
+    for i in range(1, 5):
+      for j in range(1, 5):
+        vi = Variable("i", 1, 10).bind(i)
+        vj = Variable("j", 1, 10).bind(j)
+        a = Tensor.rand(j, 3)
+        b = Tensor.rand(3, i)
+        symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
+        expected = f(a, b).numpy()
+        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+    assert len(jf.jit_cache) == 1
+
   def test_jit_symbolic_shape_mismatch(self):
     @TinyJit
     def add(a, b): return (a+b).realize()
diff --git a/test/test_symbolic_ops.py b/test/test_symbolic_ops.py
index b36fb2257c..73036b7740 100644
--- a/test/test_symbolic_ops.py
+++ b/test/test_symbolic_ops.py
@@ -98,7 +98,7 @@ class TestSymbolicOps(unittest.TestCase):
         expected = f(a, b).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

-  def test_two_vars_plus1(self):
+  def test_two_vars_plus1_ij(self):
     def f(a, b): return (a@b+1).realize()
     for i in range(1, 5):
       for j in range(1, 5):
@@ -110,6 +110,19 @@ class TestSymbolicOps(unittest.TestCase):
         expected = f(a, b).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

+  def test_two_vars_plus1_ji(self):
+    # reverse the order of variables
+    def f(a, b): return (a@b+1).realize()
+    for i in range(1, 5):
+      for j in range(1, 5):
+        vi = Variable("i", 1, 10).bind(i)
+        vj = Variable("j", 1, 10).bind(j)
+        a = Tensor.rand(j, 3)
+        b = Tensor.rand(3, i)
+        symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
+        expected = f(a, b).numpy()
+        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+
   def test_shrink(self):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 8721969950..7537c02c2b 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -170,7 +170,7 @@ class Linearizer(Kernel):
       if isinstance(buf, MemBuffer):
         self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
     # add var vals
-    for var in vars_from_ast(self.ast):
+    for var in sorted(vars_from_ast(self.ast)):
       assert var.expr is not None
       self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
     # define local buffers
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6a2aff9ba0..46568aa92c 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -202,6 +202,7 @@ class InterpretedASTRunner(ASTRunner):
     super().__init__(ast)

   def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
+    var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
     st = time.perf_counter()
     ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
     et = time.perf_counter() - st
@@ -286,8 +287,8 @@ class CompiledASTRunner(ASTRunner):
     return global_size, local_size

   def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
-    if var_vals is None: var_vals = {}
-    var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
+    # filter the var_vals
+    var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
     global_size, local_size = self.launch_dims(var_vals)
     if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
       # TODO: this is copied from get_program
diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py
index 2dee917220..db7ebf4b15 100644
--- a/tinygrad/runtime/ops_metal.py
+++ b/tinygrad/runtime/ops_metal.py
@@ -113,7 +113,7 @@ class MetalBatchExecutor(BatchExecutor):
         icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
         if i == 0: write_resources.append(b._buf)
         else: read_resources.append(b._buf)
-      var_vals_keys = list(var_vals.keys())
+      var_vals_keys = sorted(var_vals.keys())
       for i,v in enumerate(prg.vars):
         icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
       global_size, local_size = prg.launch_dims(var_vals)
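
The common thread across the hunks: every site that enumerates a kernel's symbolic variables now goes through sorted(...), so the codegen side (which fixes the kernel's argument layout) and the runtime side (which packs the bound values) agree on one canonical order. Below is a minimal sketch of the failure mode this guards against, using plain strings in place of tinygrad's Variable objects; define_kernel_params and pack_args are illustrative names, not tinygrad APIs.

# Sketch: if codegen and runtime enumerate variables in different orders,
# bound values land in the wrong argument slots. Sorting both sides by the
# same key makes the order canonical.

def define_kernel_params(ast_vars):
  # "codegen" side: this order becomes the kernel's argument layout
  return sorted(ast_vars)

def pack_args(var_vals, params):
  # "runtime" side: values must be packed in the order the kernel expects
  return [var_vals[p] for p in params]

params = define_kernel_params({"j", "i"})  # always ["i", "j"], regardless of set iteration order
assert pack_args({"i": 3, "j": 4}, params) == [3, 4]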