mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
This reverts commit f3ed96a929.
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from tinygrad.device import JITRunner
|
||||
from tinygrad.helpers import DTYPES_DICT, dtypes
|
||||
from tinygrad.ops import LazyOp, LoadOps
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
@@ -24,7 +23,4 @@ def assert_jit_cache_len(fxn, expected_len):
|
||||
assert len(fxn.jit_cache) == 1
|
||||
# until we have a better way of typing the prg in JitItem
|
||||
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
|
||||
float_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_float(v)]
|
||||
int_dtypes = [v for v in DTYPES_DICT.values() if dtypes.is_int(v)]
|
||||
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
|
||||
@@ -1,19 +1,17 @@
|
||||
# ruff: noqa: E501
|
||||
import numpy as np
|
||||
import unittest, os
|
||||
from hypothesis import given, strategies as st
|
||||
|
||||
from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOp, UOps
|
||||
from tinygrad.device import Compiled, Device, Buffer
|
||||
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, ReduceOps, TernaryOps, get_lazyop_info
|
||||
from tinygrad.ops import BufferOps, MemBuffer, ConstBuffer, LazyOp, LoadOps, TernaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.jit import CacheCollector
|
||||
from tinygrad.realize import run_schedule
|
||||
from tinygrad.helpers import DType, dtypes, prod
|
||||
from test.helpers import float_dtypes, int_dtypes
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "linearizer is only for compiled backends")
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
@@ -108,35 +106,6 @@ class TestLinearizer(unittest.TestCase):
|
||||
lin = Linearizer(sched[0].ast)
|
||||
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
reduce_ops = (Tensor.max, Tensor.min, Tensor.sum)
|
||||
@given(st.sampled_from(float_dtypes+int_dtypes), st.sampled_from(reduce_ops))
|
||||
def test_reduce_acc(self, d:DType, op):
|
||||
a = Tensor.rand(1024,1024, dtype=d)
|
||||
out = op(a)
|
||||
|
||||
ast = [si for si in out.lazydata.schedule() if si.ast.op not in LoadOps][0].ast
|
||||
reduceop = [op for op in ast.get_lazyops() if op.op in ReduceOps][0]
|
||||
uops = Linearizer(ast).linearize().uops
|
||||
phi = [u for u in uops if u.uop == UOps.PHI][0]
|
||||
|
||||
assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype
|
||||
|
||||
@unittest.skip("TODO different memory and mulacc dtypes are not working yet")
|
||||
@given(st.sampled_from(float_dtypes), st.sampled_from(float_dtypes))
|
||||
def test_mulacc_midcast(self, d1:DType, d2:DType):
|
||||
a = Tensor.rand(1024,1024, dtype=d1)
|
||||
b = Tensor.rand(1024,1024, dtype=d1)
|
||||
out = (a*b).cast(d2).sum(-1)
|
||||
|
||||
ast = [si for si in out.lazydata.schedule() if si.ast.op not in LoadOps][0].ast
|
||||
reduceop = [op for op in ast.get_lazyops() if op.op in ReduceOps][0]
|
||||
uops = Linearizer(ast).linearize().uops
|
||||
mulacc = [u for u in uops if u.uop == UOps.ALU and u.arg == TernaryOps.MULACC][0]
|
||||
phi = [u for u in uops if u.uop == UOps.PHI][0]
|
||||
|
||||
assert mulacc.vin[0].dtype == mulacc.vin[1].dtype == d2
|
||||
assert phi.dtype == phi.vin[0].dtype == phi.vin[1].dtype == get_lazyop_info(reduceop).dtype
|
||||
|
||||
def test_simplify_uop(self):
|
||||
def helper_test_simplify(uop, dtype, vin, arg=None):
|
||||
ast = LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=42, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(), strides=(), offset=0, mask=None, contiguous=True),))))
|
||||
|
||||
@@ -6,7 +6,7 @@ from enum import Enum, auto
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, getenv, all_same, to_function_name, flatten
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, get_lazyop_info, vars_from_ast
|
||||
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps, vars_from_ast
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, VariableOrNum, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
||||
from tinygrad.codegen.kernel import LocalBuffer, Kernel
|
||||
@@ -50,10 +50,9 @@ class Linearizer(Kernel):
|
||||
|
||||
def cast(self, val: UOp, dtype) -> UOp: return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
|
||||
|
||||
def get_reduce_acc(self):
|
||||
dtype = get_lazyop_info(self.reduceop).dtype.scalar()
|
||||
if cast(LazyOp,self.reduceop).op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
|
||||
if cast(LazyOp,self.reduceop).op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
|
||||
def get_reduce_acc(self, op, dtype:DType):
|
||||
if op == ReduceOps.SUM: return 0.0 if dtypes.is_float(dtype) else 0
|
||||
elif op == ReduceOps.MAX: return -math.inf if dtypes.is_float(dtype) else -2**31 if dtypes.is_int(dtype) else False
|
||||
|
||||
render_ops: Any = { Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr], NumNode: lambda self, ops, ctx: ctx.const(self.b),
|
||||
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL),
|
||||
@@ -84,9 +83,8 @@ class Linearizer(Kernel):
|
||||
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
|
||||
else:
|
||||
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
|
||||
localtype = get_lazyop_info(self.reduceop).dtype.scalar() if acc is not None else buf.dtype
|
||||
if isinstance(localtype, ImageDType): localtype = dtypes.float
|
||||
if amt > 1: localtype = localtype.vec(amt)
|
||||
localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt)
|
||||
if isinstance(buf.dtype, ImageDType): localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt)
|
||||
|
||||
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
|
||||
|
||||
@@ -249,7 +247,7 @@ class Linearizer(Kernel):
|
||||
fake_reduce_idxs = [x*0 for x in reduce_idxs]
|
||||
|
||||
# define accumulator
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc())
|
||||
acc = self.global_load(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))
|
||||
|
||||
if self.tensor_core:
|
||||
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
|
||||
@@ -355,7 +353,7 @@ class Linearizer(Kernel):
|
||||
# NOTE: this structure is the same as the reduce op above
|
||||
|
||||
# define late accumulator
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc())
|
||||
acc = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype)) # noqa: E501
|
||||
|
||||
# late reduce loop
|
||||
loop_ctx = render_loop(end_local_idxs)
|
||||
|
||||
Reference in New Issue
Block a user