From aa9b013d7942fa12270e5f726e9f56180bc77afd Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 2 Mar 2024 10:37:14 -0800
Subject: [PATCH] add constant folding for WHERE in uops (#3584)

* add constant folding for WHERE in uops

* prereqs for generic constant folding

* fix test

* disable slow overflow logic

* make that test faster
---
 test/test_linearizer.py        | 10 +++++++++-
 test/test_ops.py               |  2 +-
 test/test_uops.py              | 18 +++++++++++++++++-
 tinygrad/codegen/uops.py       | 27 ++++++++++++++++++++++++++-
 tinygrad/runtime/ops_python.py | 28 +++-------------------------
 5 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 01375ba343..be8b65305d 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -27,6 +27,14 @@ class TestLinearizer(unittest.TestCase):
     np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
     np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
 
+  def test_load_removed(self):
+    a = Tensor.rand(1).realize()
+    b = Tensor.rand(1).realize()
+    ta = Tensor.where(Tensor(True), a, b).numpy()
+    tb = Tensor.where(Tensor(False), a, b).numpy()
+    np.testing.assert_equal(a.numpy(), ta)
+    np.testing.assert_equal(b.numpy(), tb)
+
   def test_load_dedup(self):
     # for different leaves in the AST, the same loads may occur.
 
@@ -209,7 +217,7 @@ class TestLinearizer(unittest.TestCase):
     c0 = UOp(UOps.CONST, dtypes.float, vin=(), arg=0.0)
     c1 = UOp(UOps.CONST, dtypes.float, vin=(), arg=1.0)
     assert helper_test_simplify(UOps.ALU, dtypes.float, vin=(UOp(UOps.CONST, dtypes.bool, vin=(), arg=True), c0, c1),
-                                arg=TernaryOps.WHERE).uop == UOps.ALU
+                                arg=TernaryOps.WHERE).uop == UOps.CONST
 
 def helper_realized_ast(r:Tensor):
   s = create_schedule([r.lazydata])
diff --git a/test/test_ops.py b/test/test_ops.py
index 3a304697b8..53c6b4aaea 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1082,7 +1082,7 @@ class TestOps(unittest.TestCase):
 
   @unittest.skipIf(IMAGE>0, "no conv3d on images")
   def test_padded_conv3d(self):
-    helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
+    helper_test_op([(1,4,5,5,5), (4,4,3,3,3)],
       lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(),
       lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5)
 
diff --git a/test/test_uops.py b/test/test_uops.py
index 8856d6d7de..904d023310 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -7,7 +7,7 @@ from tinygrad.device import Buffer, Device
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
 from tinygrad.device import CompiledASTRunner, Compiled
 from tinygrad.codegen.linearizer import UOps, UOp
-from tinygrad.runtime.ops_python import exec_alu
+from tinygrad.codegen.uops import exec_alu
 from test.test_dtype import is_dtype_supported
 
 def _uops_to_prg(uops):
@@ -113,5 +113,21 @@ class TestExecALU(TestUOps):
   def test_sqrt(self):
     self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.int, (0,)), 0)
 
+  @unittest.skip("not enabled because it's slow")
+  def test_overflow(self):
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (256, 0)), 0)
+    self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1)), 255)
+    self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.uint8, (0, 1000)), 24)
+
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (127, 0)), 127)
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-128, 0)), -128)
+    self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-100, 100)), 56)
+    self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-1000, 0)), 24)
+    self.assertEqual(exec_alu(BinaryOps.SUB, dtypes.int8, (-130, 0)), 126)
+
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (1.0, 1.0)), 2)
+    self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.int8, (-math.exp2(7), 0)), -128)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 9560ae5398..f1849ef198 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import functools
+import functools, math
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, all_same
@@ -25,6 +25,30 @@ class UOp:
   def __repr__(self):
     return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
 
+def exec_alu(arg, dtype, p):
+  if arg == TernaryOps.WHERE: ret = p[1] if p[0] else p[2]
+  elif arg == UnaryOps.LOG2: ret = math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
+  elif arg == UnaryOps.EXP2:
+    try: ret = math.exp(p[0]*math.log(2))
+    except OverflowError: ret = math.inf
+  elif arg == UnaryOps.SQRT: ret = math.sqrt(p[0]) if p[0] >= 0 else math.nan
+  elif arg == UnaryOps.SIN: ret = math.sin(p[0])
+  elif arg == UnaryOps.NEG: ret = -p[0]
+  elif arg == BinaryOps.MUL: ret = p[0]*p[1]
+  elif arg == BinaryOps.ADD: ret = p[0]+p[1]
+  elif arg == BinaryOps.SUB: ret = p[0]-p[1]
+  elif arg == BinaryOps.XOR: ret = p[0]^p[1]
+  elif arg == BinaryOps.MAX: ret = max(p[0], p[1])
+  elif arg == BinaryOps.CMPEQ: ret = p[0] == p[1]
+  elif arg == BinaryOps.CMPLT: ret = p[0] < p[1]
+  elif arg == BinaryOps.DIV: ret = p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
+  elif arg == BinaryOps.MOD: ret = p[0]%p[1]
+  return ret
+  #else: raise NotImplementedError(f"no support for {arg}")
+  #if not dtypes.is_int(dtype): return ret
+  #adjusted = 0 if dtypes.is_unsigned(dtype) else 2 ** (dtype.itemsize * 8 - 1)
+  #return (ret + adjusted) % 2 ** (dtype.itemsize * 8) - adjusted
+
 def uop_alu_resolve(u:UOp) -> sint:
   if u.uop == UOps.CONST: return u.arg
   elif u.uop == UOps.DEFINE_VAR: return u.arg
@@ -68,6 +92,7 @@ class UOpGraph:
     # constant folding
     if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
     if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1]  # a conditional with the same results either way is a noop
+    if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
     if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
       return self.add(UOps.CONST, dtype, arg=vin[0].arg * vin[1].arg, insert_before=insert_before)
     # zero folding
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index ad9a41cb63..bca0c58b4f 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -2,36 +2,14 @@
 # works to test the tensor cores, and all the uops in general
 # this is the (living) definition of uops
 from typing import Tuple, List, Optional, Any, Dict
-import pickle, base64, itertools, time, math, struct
+import pickle, base64, itertools, time, struct
 from tinygrad.dtype import DType, dtypes, ImageDType
 from tinygrad.helpers import all_same, getenv, flatten
 from tinygrad.device import Compiled, Allocator, Compiler
-from tinygrad.codegen.uops import UOp, UOps
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
+from tinygrad.codegen.uops import UOp, UOps, exec_alu
+from tinygrad.ops import BinaryOps, TernaryOps
 from tinygrad.codegen.kernel import LinearizerOptions
 
-def exec_alu(arg, dtype, p):
-  # TODO: make this complete and correctly honor the dtypes
-  # TODO: use this for constant folding
-  if arg == TernaryOps.WHERE: return p[1] if p[0] else p[2]
-  if arg == UnaryOps.LOG2: return math.log2(p[0]) if p[0] > 0 else -math.inf if p[0] == 0 else math.nan
-  if arg == UnaryOps.EXP2:
-    try: return math.exp(p[0]*math.log(2))
-    except OverflowError: return math.inf
-  if arg == UnaryOps.SQRT: return math.sqrt(p[0]) if p[0] >= 0 else math.nan
-  if arg == UnaryOps.SIN: return math.sin(p[0])
-  if arg == UnaryOps.NEG: return -p[0]
-  if arg == BinaryOps.MUL: return p[0]*p[1]
-  if arg == BinaryOps.ADD: return p[0]+p[1]
-  if arg == BinaryOps.SUB: return p[0]-p[1]
-  if arg == BinaryOps.XOR: return p[0]^p[1]
-  if arg == BinaryOps.MAX: return max(p[0], p[1])
-  if arg == BinaryOps.CMPEQ: return p[0] == p[1]
-  if arg == BinaryOps.CMPLT: return p[0] < p[1]
-  if arg == BinaryOps.DIV: return p[0]//p[1] if dtypes.is_int(dtype) else (p[0]/p[1] if p[1] != 0 else math.nan)
-  if arg == BinaryOps.MOD: return p[0]%p[1]
-  raise NotImplementedError(f"no support for {arg}")
-
 def _load(m, i):
   if i<0 or i>=len(m): raise IndexError(f"load out of bounds, size is {len(m)} and access is {i}")
   return m[i]
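
Note (not part of the patch): with a constant-bool condition, the new fold in UOpGraph.add returns vin[1] or vin[2] directly, so no ALU(WHERE) uop and no dead load survive to codegen; that is what test_load_removed and the UOps.CONST assertion above verify. The same semantics are visible through exec_alu, which the patch moves into tinygrad/codegen/uops.py. A minimal sketch against the patched tree:

  from tinygrad.codegen.uops import exec_alu
  from tinygrad.ops import TernaryOps
  from tinygrad.dtype import dtypes

  # WHERE(cond, a, b) evaluates to a when cond is truthy, else b
  assert exec_alu(TernaryOps.WHERE, dtypes.float, (True, 1.0, 0.0)) == 1.0
  assert exec_alu(TernaryOps.WHERE, dtypes.float, (False, 1.0, 0.0)) == 0.0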
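
Note (not part of the patch): the commented-out tail of exec_alu is the "slow overflow logic" the commit message disables; it wraps an integer result into the dtype's representable range, which is what the skipped test_overflow expects. A standalone sketch of what those lines compute, using a hypothetical wrap() helper that is not part of tinygrad:

  def wrap(ret: int, bits: int, unsigned: bool) -> int:
    # shift signed values into the unsigned range, reduce mod 2**bits, shift back
    adjusted = 0 if unsigned else 2 ** (bits - 1)
    return (ret + adjusted) % 2 ** bits - adjusted

  # values mirror the skipped test_overflow cases above
  assert wrap(250 + 250, 8, unsigned=True) == 244   # uint8 ADD wraps at 256
  assert wrap(0 - 1000, 8, unsigned=True) == 24     # uint8 SUB
  assert wrap(-100 - 100, 8, unsigned=False) == 56  # int8 SUB
  assert wrap(-130, 8, unsigned=False) == 126       # int8 two's-complement wrap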