diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index 6459727a89..0a117b472b 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -55,5 +55,35 @@ class TestLinearizer(unittest.TestCase):
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
     assert num_ops <= 1, "more alu uops than needed"
 
+  def test_zero_fold(self):
+    if not isinstance(Device[Device.DEFAULT], Compiled):
+      self.skipTest("Only Compiled uses linearizer")
+
+    a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
+    r = Tensor.stack([a, b])
+    ast = r.lazydata.op
+    r = r.realize() # realize an output buffer
+    k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
+    k.process()
+    k.upcast()
+    k.linearize()
+    num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
+    assert num_ops == 0, "more alu uops than needed"
+
+  @unittest.skip("constant folding not supported yet")
+  def test_constant_fold(self):
+    if not isinstance(Device[Device.DEFAULT], Compiled):
+      self.skipTest("Only Compiled uses linearizer")
+
+    a, b = Tensor(2), Tensor(3)
+    r = a * b
+    ast = r.lazydata.op
+    r = r.realize() # realize an output buffer
+    k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
+    k.process()
+    k.linearize()
+    num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
+    assert num_ops <= 0, "more load or alu uops than needed"
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 816ab3a30d..f7ef66986f 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final
+from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final, Set
 import itertools, math, functools
 from collections import defaultdict
 from enum import Enum, auto
@@ -93,8 +93,8 @@ class Linearizer(OptimizedKernel):
     render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
     return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True)
 
-  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self),
-                      NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b),
+  render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self, cachable=True),
+                      NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b, cachable=True),
                       MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
                       DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV),
                       ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD),
@@ -133,11 +133,11 @@ class Linearizer(OptimizedKernel):
           assert valid.min == 1
           self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const)
         elif this_const is not None:
-          self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const)
+          self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const, cachable=True)
           if valid.min == 0 and valid.max == 1:
             valid_rendered = valid.render(self.render_ops, self)
-            alt = self.uop(UOps.CONST, localtype, [], invalid_value)
-            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE)
+            alt = self.uop(UOps.CONST, localtype, [], invalid_value, cachable=True)
+            self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE, cachable=True)
         else:
           self.load_cache[key] = self.uop(UOps.LOAD, localtype, [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value))
       ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key])
@@ -359,10 +359,35 @@ class Linearizer(OptimizedKernel):
     # end the global (and maybe local) loop
     self.uop(UOps.ENDLOOP, None, [], (loop_global_idxs+loop_local_idxs, "global+local") if not self.group_for_reduce else (loop_global_idxs, "global"))
 
+    # (recursively) remove childless uops
+    UOPS_WO_SIDE_EFFECTS = {UOps.CONST, UOps.ALU, UOps.LOAD, UOps.CAST, UOps.GEP}
+    while 1:
+      has_child: Set[UOp] = set()
+      for ru in self.uops:
+        for vu in ru.vin:
+          has_child.add(vu)
+      nu: List[UOp] = [x for x in self.uops if x in has_child or x.uop not in UOPS_WO_SIDE_EFFECTS]
+      if len(nu) == len(self.uops): break
+      if DEBUG >= 4: print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
+      self.uops = nu
+
     return self
 
   def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp:
     key = (uop, dtype, tuple(vin), arg)
+    if uop == UOps.STORE and len(vin) == 2 and vin[0] == vin[1]: return vin[0] # self store is noop
+    if uop == UOps.ALU:
+      # rewrites. NOTE: the rewritten NEG op is still around...
+      if arg == BinaryOps.ADD and vin[1].uop == UOps.ALU and vin[1].arg == UnaryOps.NEG: return self.uop(UOps.ALU, dtype, [vin[0], vin[1].vin[0]], BinaryOps.SUB, cachable=cachable)
+      # constant folding
+      if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST: return self.uop(UOps.CONST, dtype, [], -vin[0].arg, cachable=True)
+      # zero folding
+      for x in [0,1]:
+        if arg == BinaryOps.ADD and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[1-x]
+        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 1.0: return vin[1-x]
+        if arg == BinaryOps.MUL and vin[x].uop == UOps.CONST and vin[x].arg == 0.0: return vin[x]
+      if arg == BinaryOps.SUB and vin[1].uop == UOps.CONST and vin[1].arg == 0.0: return vin[0]
+      if arg == BinaryOps.DIV and vin[1].uop == UOps.CONST and vin[1].arg == 1.0: return vin[0]
     if cachable and key in self.saved_exprs: return self.saved_exprs[key]
     self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops)))
    if DEBUG >= 4: print(self.uops[-1])
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 26bf8b4f97..a7373ff974 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -105,7 +105,7 @@ class Interpreted:
     if DEBUG >= 3: st = time.perf_counter()
     ret = self.from_underlying(self.fxn_for_op[ast.op](*([self.to_underlying(x) for x in srcs] + ([ast.arg] if ast.arg is not None else []))))
     if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op: ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.to_underlying(ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
-    if DEBUG >= 3: print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
+    if DEBUG >= 5 or (self.buffer != FlopCounter and DEBUG >= 3): print(f"*** {'exec' if created_context else ' '} {GlobalCounters.mem_used/1e9:5.2f} GB {(time.perf_counter()-st)*1e3:7.2f} ms op: {ast.op:20s} out({ret.dtype.name}): {str(ret._buf.shape) if hasattr(ret._buf, 'shape') else str(len(ret._buf)):30s} in({len(srcs)}):", list(set(x._buf.shape if hasattr(x._buf, 'shape') else len(x._buf) for x in srcs)), ast.arg if ast.arg is not None else "")
     if not created_context: context[ast] = ret
     if output is not None and output.output_buffer is not None:
       assert output.output_buffer.size == ret.size, output.output_buffer.dtype == ret.dtype
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 0e418c5a7d..bda2449807 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -173,7 +173,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T
     elif uop == UOps.ALU:
       assert dtype is not None
       val = lang.code_for_op[args](*[r[x] for x in vin])
-      if child_count[u] == 1: r[u] = val
+      assert child_count[u] != 0, f"childless ALU op found {u}"
+      if child_count[u] <= 1: r[u] = val
       else:
         r[u] = ssa('alu')
         kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};")
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 36db35961a..23006cac16 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -114,7 +114,7 @@ def _reshape(view: View, new_shape:Tuple[int, ...]) -> Tuple[View, bool]:
   new_view = View(new_shape)
   if view.contiguous: return new_view, False # NOTE: if it's contiguous it can't have an offset
   if (merged_view := merge_views(view, new_view)) is not None: return merged_view, False
-  if DEBUG >= 4: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
+  if DEBUG >= 5: print(f"WARNING: creating new view with reshape {view} -> {new_shape}")
   return new_view, True
 
 @functools.lru_cache(maxsize=None)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index b86406cb59..1dc1ada3c4 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -603,6 +603,7 @@ class Tensor:
   def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
   def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self)
+    if x.__class__ is not Tensor and x == -1.0: return -self
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
   def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: