mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 21:38:10 -05:00
multi reduce Tensor.var passing verify_lazyop (#5346)
* what about this * reset late gate
This commit is contained in:
@@ -101,6 +101,25 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len(mutable_bufs) == len(stores) == 2
|
||||
assert [u.arg[0] for u in mutable_bufs] == [0, 1]
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "AMD", "remu doesn't have multiple wave syncs yet")
|
||||
def test_var_multireduce(self):
|
||||
Tensor.manual_seed(0)
|
||||
x = Tensor.randn(3, 27, 32).realize()
|
||||
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
|
||||
first_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1)).expand((3, 27, 32, 32))))
|
||||
first_reduce = LazyOp(ReduceOps.SUM, (first_x,), (3,))
|
||||
mean = first_reduce * LazyOp(BufferOps.CONST, (), ConstBuffer(0.03125, dtypes.float, ShapeTracker.from_shape(()).reshape((1, 1, 1, 1)).expand((3, 27, 32, 1)))) # noqa: E501
|
||||
# store = LazyOp(BufferOps.STORE, (mean,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 32, 1))))
|
||||
# verify_lazyop(store)
|
||||
second_x = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.float, x.lazydata.st.reshape((3, 27, 32, 1))))
|
||||
squares = (second_x-mean)*(second_x-mean)
|
||||
squares_sum = LazyOp(ReduceOps.SUM, (squares,), (2,))
|
||||
store = LazyOp(BufferOps.STORE, (squares_sum,), MemBuffer(0, dtypes.float, ShapeTracker.from_shape((3, 27, 1, 1))))
|
||||
helper_linearizer_ast((store, ), [x])
|
||||
# tinygrad ref
|
||||
y_tiny = x.var(axis=2, correction=0)
|
||||
np.testing.assert_allclose(y_tiny.numpy(), x.numpy().var(axis=2, ddof=0), atol=1e-4, rtol=1e-4)
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||
def test_end_local(self):
|
||||
|
||||
@@ -4,14 +4,11 @@ from tinygrad.codegen.linearizer import Linearizer
|
||||
#from tinygrad.codegen.lowerer import Lowerer
|
||||
from tinygrad.engine.graph import print_tree
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, BufferOps, MemBuffer, LazyOp, ReduceOps, verify_lazyop
|
||||
from tinygrad.ops import BufferOps, MemBuffer, LazyOp, ReduceOps, verify_lazyop
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
class LazyOp(LazyOp):
|
||||
def __add__(self, other:LazyOp): return LazyOp(BinaryOps.ADD, (self, other))
|
||||
|
||||
class InvalidLazyOpException(Exception): pass
|
||||
def lower(*ast:LazyOp):
|
||||
if DEBUG >= 3:
|
||||
|
||||
@@ -264,6 +264,8 @@ class Linearizer(Kernel):
|
||||
def render_reduceop(self, reduceop:LazyOp, accs:Dict[LazyOp, List[UOp]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
|
||||
global_idxs, local_idxs, upcast_idxs, full_upcast_idxs, reduce_idxs, fake_reduce_idxs,
|
||||
alias_buf_idxs:List[Tuple[int, int, List]]) -> Tuple[List[NumNode|Variable], List[NumNode|Variable]]:
|
||||
# reset late_gate
|
||||
self.late_gate = None
|
||||
# reduce loop
|
||||
loop_ctx = self.render_loop(reduce_idxs, (i:=self.reduceops.index(reduceop))*2+2, True)
|
||||
|
||||
|
||||
@@ -77,6 +77,12 @@ class LazyOp:
|
||||
const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
|
||||
return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
|
||||
|
||||
# TODO: support non-lazyop
|
||||
def __add__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, x))
|
||||
def __sub__(self, x:LazyOp): return LazyOp(BinaryOps.ADD, (self, -x))
|
||||
def __mul__(self, x:LazyOp): return LazyOp(BinaryOps.MUL, (self, x))
|
||||
def __neg__(self): return LazyOp(UnaryOps.NEG, (self,))
|
||||
|
||||
# **************** independent FlopCounter ****************
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user