diff --git a/test/test_jit_cases.py b/test/test_jit_cases.py new file mode 100644 index 0000000000..dd474e3532 --- /dev/null +++ b/test/test_jit_cases.py @@ -0,0 +1,78 @@ +import unittest +from tinygrad import TinyJit, Tensor + +# The JIT functions as a "capturing" JIT. +# Whatever kernels ran in the JIT the second run through the function will be the kernels that will run from then on. +# Explicit inputs to the function are updated in the JIT graph to the new inputs. + +# JITs have four tensor types +# 1. Tensors that are explicit in the input, aka what's passed in. TODO: support lists/dicts/classes, anything get_state works on +# 2. Tensors that are explicit in the output, aka what's returned. TODO: same as above +# 3. Tensors that are implicit in the input as a closure. +# 4. Tensors that are implicit in the output because they were assigned to and realized. + +# explicit inputs and outputs are realized on their way in and out of the JIT +# there's a whole bunch of edge cases and weirdness here that needs to be tested and clarified. + +class TestJitCases(unittest.TestCase): + def test_explicit(self): + # this function has an explicit input and an explicit output + @TinyJit + def f(x:Tensor): + ret:Tensor = x*2 + return ret + + for i in range(5): + out = f(Tensor([i])) + self.assertEqual(out.item(), i*2) + + def test_implicit_input(self): + # x is the implicit input (like a weight) + x = Tensor([0]) + + # this function has an implicit input and an explicit output + @TinyJit + def f(): + ret:Tensor = x*2 + return ret + + for i in range(5): + # NOTE: this must be realized here, otherwise the update doesn't happen + # if we were explicitly tracking the implicit input Tensors, we might not need this realize + x.assign(Tensor([i])).realize() + out = f() + self.assertEqual(out.item(), i*2) + + def test_implicit_output(self): + # out is the implicit output (it's assigned to) + out = Tensor([0]) + + # this function has an explicit input and an implicit output + @TinyJit + def f(x:Tensor): + # NOTE: this must be realized here + # if we were explicitly tracking the implicit output Tensors, we might not need this realize + out.assign(x*2).realize() + + for i in range(5): + f(Tensor([i])) + self.assertEqual(out.item(), i*2) + + def test_implicit_io(self): + # x is the implicit input (like a weight) + # out is the implicit output (it's assigned to) + x = Tensor([0]) + out = Tensor([0]) + + # this function has an implicit input and an implicit output + @TinyJit + def f(): + out.assign(x*2).realize() # NOTE: this must be realized here + + for i in range(5): + x.assign(Tensor([i])).realize() + f() + self.assertEqual(out.item(), i*2) + +if __name__ == '__main__': + unittest.main()