Revert "Use the reduceop dtype to define the acc in linearizer (#2625)" (#2783)

This reverts commit f3ed96a929.
This commit is contained in:
chenyu
2023-12-15 16:29:10 -05:00
committed by GitHub
parent f3ed96a929
commit e4bbbc5bc3
3 changed files with 11 additions and 48 deletions

View File

@@ -1,5 +1,4 @@
from tinygrad.device import JITRunner
from tinygrad.helpers import DTYPES_DICT, dtypes
from tinygrad.ops import LazyOp, LoadOps
from tinygrad.nn.state import get_parameters
@@ -24,7 +23,4 @@ def assert_jit_cache_len(fxn, expected_len):
assert len(fxn.jit_cache) == 1
# until we have a better way of typing the prg in JitItem
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
float_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_float(v)]
int_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_int(v)]
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

View File

@@ -1,19 +1,17 @@
# ruff: noqa: E501
import numpy as np
import unittest, os
from hypothesis import given, strategies as st
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps
from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, ReduceOps, TernaryOps, get_lazyop_info
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.helpers import DType, dtypes, prod
from test.helpers import float_dtypes, int_dtypes
from tinygrad.helpers import dtypes, prod
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
class TestLinearizer(unittest.TestCase):
@@ -108,35 +106,6 @@ class TestLinearizer(unittest.TestCase):
lin = Linearizer(sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
reduce_ops = (Tensor.max, Tensor.min, Tensor.sum)
@given(st.sampled_from(float_dtypes+int_dtypes), st.sampled_from(reduce_ops))
def test_reduce_acc(self, d:DType, op):
  # For a sampled dtype `d` and a sampled reduction `op` (max/min/sum), linearize the
  # reduce kernel and check that the PHI uop and both of its inputs carry the dtype
  # that get_lazyop_info reports for the reduce LazyOp.
  reduced = op(Tensor.rand(1024,1024, dtype=d))
  # keep only schedule items whose ast op is not a LoadOps member; the first one is used
  compute_items = [item for item in reduced.lazydata.schedule() if item.ast.op not in LoadOps]
  ast = compute_items[0].ast
  # first LazyOp in the ast whose op is a ReduceOps member
  reduce_nodes = [node for node in ast.get_lazyops() if node.op in ReduceOps]
  reduceop = reduce_nodes[0]
  linearized = Linearizer(ast).linearize()
  phi_uops = [u for u in linearized.uops if u.uop == UOps.PHI]
  phi = phi_uops[0]
  assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype
@unittest.skip("TODO different memory and mulacc dtypes are not working yet")
@given(st.sampled_from(float_dtypes), st.sampled_from(float_dtypes))
def test_mulacc_midcast(self, d1:DType, d2:DType):
  # Multiply two `d1` tensors, cast the product to `d2`, then sum over the last axis.
  # Checks that the MULACC uop's first two inputs have dtype `d2`, and that the PHI uop
  # and its two inputs have the dtype get_lazyop_info reports for the reduce LazyOp.
  # Currently skipped (see the skip reason above).
  lhs = Tensor.rand(1024,1024, dtype=d1)
  rhs = Tensor.rand(1024,1024, dtype=d1)
  summed = (lhs*rhs).cast(d2).sum(-1)
  # keep only schedule items whose ast op is not a LoadOps member; the first one is used
  compute_items = [item for item in summed.lazydata.schedule() if item.ast.op not in LoadOps]
  ast = compute_items[0].ast
  reduce_nodes = [node for node in ast.get_lazyops() if node.op in ReduceOps]
  reduceop = reduce_nodes[0]
  uops = Linearizer(ast).linearize().uops
  mulacc_uops = [u for u in uops if u.uop == UOps.ALU and u.arg == TernaryOps.MULACC]
  mulacc = mulacc_uops[0]
  phi_uops = [u for u in uops if u.uop == UOps.PHI]
  phi = phi_uops[0]
  assert mulacc.vin[0].dtype == mulacc.vin[1].dtype == d2
  assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype
def test_simplify_uop(self):
def helper_test_simplify(uop, dtype, vin, arg=None):
ast = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))

View File

@@ -6,7 +6,7 @@ from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info, vars_from_ast
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
from tinygrad.codegen.kernel import LocalBuffer, Kernel
@@ -50,10 +50,9 @@ class Linearizer(Kernel):
def cast(self, val: UOp, dtype) -> UOp:
  # Hand back `val` untouched when it already has `dtype`; otherwise build the result
  # through self.uop(UOps.CAST, dtype, (val,)), with `val` as the single input.
  if val.dtype != dtype:
    return self.uop(UOps.CAST, dtype, (val,))
  return val
def get_reduce_acc(self):
  # Initial accumulator value for self.reduceop, chosen from the scalar dtype that
  # get_lazyop_info reports for it:
  #   SUM -> 0.0 for float dtypes, 0 otherwise
  #   MAX -> -inf for float dtypes, -2**31 for int dtypes, False otherwise
  # Any other op falls through and returns None.
  acc_dtype = get_lazyop_info(self.reduceop).dtype.scalar()
  if cast(LazyOp,self.reduceop).op == ReduceOps.SUM:
    if dtypes.is_float(acc_dtype): return 0.0
    return 0
  if cast(LazyOp,self.reduceop).op == ReduceOps.MAX:
    if dtypes.is_float(acc_dtype): return -math.inf
    if dtypes.is_int(acc_dtype): return -2**31
    return False
  return None
def get_reduce_acc(self, op, dtype:DType):
  # Initial accumulator value for reduce `op` over `dtype`:
  #   SUM -> 0.0 for float dtypes, 0 otherwise
  #   MAX -> -inf for float dtypes, -2**31 for int dtypes, False otherwise
  # Any other op falls through and returns None.
  # NOTE(review): -2**31 is used for every int dtype regardless of width or signedness;
  # confirm this is intended for dtypes other than 32-bit signed ints.
  if op == ReduceOps.SUM:
    if dtypes.is_float(dtype): return 0.0
    return 0
  if op == ReduceOps.MAX:
    if dtypes.is_float(dtype): return -math.inf
    if dtypes.is_int(dtype): return -2**31
    return False
  return None
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
@@ -84,9 +83,8 @@ class Linearizer(Kernel):
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
else:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
localtype = get_lazyop_info(self.reduceop).dtype.scalar() if acc is not None else buf.dtype
if isinstance(localtype, ImageDType): localtype = dtypes.float
if amt > 1: localtype = localtype.vec(amt)
localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt)
if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt)
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
@@ -249,7 +247,7 @@ class Linearizer(Kernel):
fake_reduce_idxs = [x*0 for x in reduce_idxs]
# define accumulator
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc())
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
if self.tensor_core:
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
@@ -355,7 +353,7 @@ class Linearizer(Kernel):
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc())
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501
# late reduce loop
loop_ctx = render_loop(end_local_idxs)