diff --git a/test/helpers.py b/test/helpers.py index 575e64390d..6f31dd1941 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,5 +1,4 @@ from tinygrad.device import JITRunner -from tinygrad.helpers import DTYPES_DICT, dtypes from tinygrad.ops import LazyOp, LoadOps from tinygrad.nn.state import get_parameters @@ -24,7 +23,4 @@ def assert_jit_cache_len(fxn, expected_len): assert len(fxn.jit_cache) == 1 # until we have a better way of typing the prg in JitItem assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph') - assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len - -float_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_float(v)] -int_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_int(v)] + assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len \ No newline at end of file diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7e1798f167..50bcd53c40 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1,19 +1,17 @@ # ruff: noqa: E501 import numpy as np import unittest, os -from hypothesis import given, strategies as st from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores from tinygrad.codegen.linearizer import Linearizer, UOp, UOps from tinygrad.device import Compiled, Device, Buffer -from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, ReduceOps, TernaryOps, get_lazyop_info +from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.tensor import Tensor from tinygrad.jit import CacheCollector from tinygrad.realize import run_schedule -from tinygrad.helpers import DType, dtypes, prod -from test.helpers import float_dtypes, int_dtypes +from tinygrad.helpers import dtypes, prod @unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends") class TestLinearizer(unittest.TestCase): @@ -108,35 +106,6 @@ class TestLinearizer(unittest.TestCase): lin = Linearizer(sched[0].ast) assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse" - reduce_ops = (Tensor.max, Tensor.min, Tensor.sum) - @given(st.sampled_from(float_dtypes+int_dtypes), st.sampled_from(reduce_ops)) - def test_reduce_acc(self, d:DType, op): - a = Tensor.rand(1024,1024, dtype=d) - out = op(a) - - ast = [si for si in out.lazydata.schedule() if si.ast.op not in LoadOps][0].ast - reduceop = [op for op in ast.get_lazyops() if op.op in ReduceOps][0] - uops = Linearizer(ast).linearize().uops - phi = [u for u in uops if u.uop == UOps.PHI][0] - - assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype - - @unittest.skip("TODO different memory and mulacc dtypes are not working yet") - @given(st.sampled_from(float_dtypes), st.sampled_from(float_dtypes)) - def test_mulacc_midcast(self, d1:DType, d2:DType): - a = Tensor.rand(1024,1024, dtype=d1) - b = Tensor.rand(1024,1024, dtype=d1) - out = (a*b).cast(d2).sum(-1) - - ast = [si for si in out.lazydata.schedule() if si.ast.op not in LoadOps][0].ast - reduceop = [op for op in ast.get_lazyops() if op.op in ReduceOps][0] - uops = Linearizer(ast).linearize().uops - mulacc = [u for u in uops if u.uop == UOps.ALU and u.arg == TernaryOps.MULACC][0] - phi = [u for u in uops if u.uop == UOps.PHI][0] - - assert mulacc.vin[0].dtype == mulacc.vin[1].dtype == d2 - assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype - def test_simplify_uop(self): def helper_test_simplify(uop, dtype, vin, arg=None): ast = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),)))) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 347631810d..9248db670c 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -6,7 +6,7 @@ from enum import Enum, auto from dataclasses import dataclass from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten -from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info, vars_from_ast +from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode from tinygrad.codegen.kernel import LocalBuffer, Kernel @@ -50,10 +50,9 @@ class Linearizer(Kernel): def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val - def get_reduce_acc(self): - dtype = get_lazyop_info(self.reduceop).dtype.scalar() - if cast(LazyOp,self.reduceop).op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 - if cast(LazyOp,self.reduceop).op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False + def get_reduce_acc(self, op, dtype:DType): + if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0 + elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b), MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL), @@ -84,9 +83,8 @@ class Linearizer(Kernel): (g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None else: g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs) - localtype = get_lazyop_info(self.reduceop).dtype.scalar() if acc is not None else buf.dtype - if isinstance(localtype, ImageDType): localtype = dtypes.float - if amt > 1: localtype = localtype.vec(amt) + localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt) + if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt) e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars) @@ -249,7 +247,7 @@ class Linearizer(Kernel): fake_reduce_idxs = [x*0 for x in reduce_idxs] # define accumulator - acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc()) + acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype)) if self.tensor_core: def calc_tc_idxs(local_size: int, aliases: List[List[int]]): @@ -355,7 +353,7 @@ class Linearizer(Kernel): # NOTE: this structure is the same as the reduce op above # define late accumulator - acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc()) + acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501 # late reduce loop loop_ctx = render_loop(end_local_idxs)