diff --git a/test/null/test_uop_graph.py b/test/null/test_uop_graph.py
index 5548ba1a82..08e27c7d4f 100644
--- a/test/null/test_uop_graph.py
+++ b/test/null/test_uop_graph.py
@@ -879,5 +879,40 @@ class TestUOpGetItem(unittest.TestCase):
     result = p[:, :]
     self.assertNotEqual(result.op, Ops.INDEX)
 
+class TestUOpBroadcast(unittest.TestCase):
+  def test_broadcast_row(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(4, 1))
+    c = a + b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.ADD)
+
+  def test_broadcast_col(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(1, 8))
+    c = a + b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.ADD)
+
+  def test_broadcast_lower_dim(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(8,))
+    c = a * b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.MUL)
+
+  def test_broadcast_scalar(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    c = a * 2
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.MUL)
+
+  def test_broadcast_symbolic_same_shape(self):
+    t = Variable("t", 1, 10)
+    a = UOp.const(dtypes.float, 1, shape=(1, 1, t))
+    b = UOp.const(dtypes.float, 2, shape=(1, 1, t))
+    c = a + b
+    self.assertEqual(c.op, Ops.ADD)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py
index a7b0b0bc1e..693484f097 100644
--- a/tinygrad/mixin/movement.py
+++ b/tinygrad/mixin/movement.py
@@ -3,18 +3,12 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Self
 from tinygrad.uop import Ops
 from tinygrad.helpers import prod, argfix, argsort, flatten, dedup, make_tuple, ceildiv
-from tinygrad.uop.ops import resolve, smax
+from tinygrad.uop.ops import resolve, smax, _align_left
 
 if TYPE_CHECKING:
   from tinygrad.uop.ops import sint
 
-def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
-  # unsqueeze left to make every shape same length
-  max_dim = max(len(shape) for shape in shapes)
-  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
-
-
 class MovementMixin:
   # required to implement
   def _mop(self, op: Ops, arg) -> Self:
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 79cf2ac7fe..571d111f42 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -11,8 +11,8 @@ from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin
-from tinygrad.mixin.movement import _align_left
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
+from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import run_schedule
@@ -83,9 +83,6 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
-def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
-  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
-
 def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
   # reduce such that if mask contains repeated indices the last one remains
   for dim in reversed(axes):
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 88bcebe564..4f9f270874 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -45,6 +45,11 @@ def _suop(lst, uop_fxn, python_fxn):
 def smax(*lst) -> sint: return _suop(argfix(*lst), UOp.maximum, max)
 def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
 def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
+def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
+  max_dim = max(len(s) for s in shapes)
+  return tuple((1,)*(max_dim-len(s))+s for s in shapes)
+def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
+  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
 def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
 def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
@@ -473,9 +478,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
     return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
   def alu(self, op, *src:UOp, **kwargs):
-    out_dtype = (self, *src)[-1].dtype
+    all_srcs = (self, *src)
+    # broadcast shaped operands to a common shape (None and () are falsy, so only real shapes participate)
+    if (shapes := [s for x in all_srcs if (s:=x._shape)]) and not all_same(shapes):
+      out_shape = _broadcast_shape(*shapes)
+      all_srcs = tuple(x._broadcast_to(out_shape) if x._shape else x for x in all_srcs)
+    out_dtype = all_srcs[-1].dtype
     if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
-    return UOp(op, out_dtype, (self,)+src, **kwargs)
+    return UOp(op, out_dtype, all_srcs, **kwargs)
   @staticmethod
   def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
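
Note: the broadcasting rule the new tests exercise is the numpy-style one implemented by the `_align_left`/`_broadcast_shape` helpers moved into ops.py above: shapes are left-padded with 1s to a common rank, then each output dimension is the maximum across inputs, with a 0-sized dimension always winning. A minimal standalone sketch with plain int dims (names are illustrative; the real helpers also accept symbolic `sint` dims, which is why they use `smax` instead of `max`):

# standalone sketch of the broadcast rule above; int-only, not the tinygrad helpers themselves
def align_left(*shapes: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
  # unsqueeze left: pad shorter shapes with leading 1s so every shape has the same rank
  max_dim = max(len(s) for s in shapes)
  return tuple((1,)*(max_dim-len(s)) + s for s in shapes)

def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
  # per dimension: a 0-sized dim broadcasts to 0, otherwise the largest size wins
  return tuple(0 if 0 in dims else max(dims) for dims in zip(*align_left(*shapes)))

assert broadcast_shape((4, 8), (4, 1)) == (4, 8)  # test_broadcast_row
assert broadcast_shape((4, 8), (1, 8)) == (4, 8)  # test_broadcast_col
assert broadcast_shape((4, 8), (8,)) == (4, 8)    # test_broadcast_lower_dim

With the `UOp.alu` change, this promotion happens for any elementwise op whose operands carry shapes: e.g. adding a `(4, 8)` const to a `(4, 1)` const yields a `(4, 8)` ADD, as the new tests check.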