support shape broadcast in UOp.alu (#15442)

i think it can integrate tighter, but now Tensor also does ufix from UOp and implicit dtype upcast
Author: chenyu
Date: 2026-03-24 10:14:57 -04:00
Committed by: GitHub
Parent: a33ac869aa
Commit: b7960841af
4 changed files with 49 additions and 13 deletions
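
In short: binary ops built through UOp.alu now broadcast shaped operands to a common shape before the op is constructed, using the same left-aligned rule Tensor already used. A minimal sketch of the user-visible effect, assuming the current import paths (UOp and Ops live in tinygrad.uop.ops, dtypes in tinygrad.dtype), mirroring the tests added below:

from tinygrad.uop.ops import UOp, Ops
from tinygrad.dtype import dtypes

# two shaped consts with broadcast-compatible but unequal shapes
a = UOp.const(dtypes.float, 1, shape=(4, 8))
b = UOp.const(dtypes.float, 2, shape=(4, 1))

# the + operator routes through UOp.alu, which now broadcasts (4, 1) up to (4, 8)
c = a + b
assert c.shape == (4, 8) and c.op is Ops.ADD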


@@ -879,5 +879,40 @@ class TestUOpGetItem(unittest.TestCase):
     result = p[:, :]
     self.assertNotEqual(result.op, Ops.INDEX)
 
+class TestUOpBroadcast(unittest.TestCase):
+  def test_broadcast_row(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(4, 1))
+    c = a + b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.ADD)
+
+  def test_broadcast_col(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(1, 8))
+    c = a + b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.ADD)
+
+  def test_broadcast_lower_dim(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    b = UOp.const(dtypes.float, 2, shape=(8,))
+    c = a * b
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.MUL)
+
+  def test_broadcast_scalar(self):
+    a = UOp.const(dtypes.float, 1, shape=(4, 8))
+    c = a * 2
+    self.assertEqual(c.shape, (4, 8))
+    self.assertEqual(c.op, Ops.MUL)
+
+  def test_broadcast_symbolic_same_shape(self):
+    t = Variable("t", 1, 10)
+    a = UOp.const(dtypes.float, 1, shape=(1, 1, t))
+    b = UOp.const(dtypes.float, 2, shape=(1, 1, t))
+    c = a + b
+    self.assertEqual(c.op, Ops.ADD)
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
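
The shape rule these tests pin down is the usual NumPy-style one. A self-contained sketch for concrete integer shapes (the real helpers, shown in the uop.ops hunk below, additionally handle symbolic sint dimensions via smax):

# integer-only mirror of _align_left/_broadcast_shape, for illustration
def align_left(*shapes: tuple[int, ...]) -> tuple[tuple[int, ...], ...]:
  # pad shorter shapes with leading 1s so every shape has the same rank
  max_dim = max(len(s) for s in shapes)
  return tuple((1,)*(max_dim-len(s)) + s for s in shapes)

def broadcast_shape(*shapes: tuple[int, ...]) -> tuple[int, ...]:
  # per dimension, a 0 anywhere wins (empty stays empty); otherwise take the max
  return tuple(0 if 0 in dims else max(dims) for dims in zip(*align_left(*shapes)))

assert broadcast_shape((4, 8), (4, 1)) == (4, 8)  # test_broadcast_row
assert broadcast_shape((4, 8), (1, 8)) == (4, 8)  # test_broadcast_col
assert broadcast_shape((4, 8), (8,))   == (4, 8)  # test_broadcast_lower_dim
assert broadcast_shape((0, 8), (1, 8)) == (0, 8)  # zero-size dims propagate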


@@ -3,18 +3,12 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Self
 from tinygrad.uop import Ops
 from tinygrad.helpers import prod, argfix, argsort, flatten, dedup, make_tuple, ceildiv
-from tinygrad.uop.ops import resolve, smax
+from tinygrad.uop.ops import resolve, smax, _align_left
 if TYPE_CHECKING:
   from tinygrad.uop.ops import sint
 
-def _align_left(*shapes: tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
-  # unsqueeze left to make every shape same length
-  max_dim = max(len(shape) for shape in shapes)
-  return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
-
 class MovementMixin:
   # required to implement
   def _mop(self, op: Ops, arg) -> Self:
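
This hunk only relocates _align_left: the movement mixin now imports it from tinygrad.uop.ops, where _broadcast_shape (added in the last file) needs it, so behavior is unchanged. A quick check of what the relocated helper does, assuming the new import path above:

from tinygrad.uop.ops import _align_left

# rank-align only; no sizes are resolved
assert _align_left((4, 8), (8,)) == ((4, 8), (1, 8))
assert _align_left((2, 3, 4), (4,), ()) == ((2, 3, 4), (1, 1, 4), (1, 1, 1))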


@@ -11,8 +11,8 @@ from tinygrad.helpers import IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, ASM_GEMM
 from tinygrad.helpers import suppress_finalizing, disable_gc
 from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin
-from tinygrad.mixin.movement import _align_left
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
+from tinygrad.uop.ops import _broadcast_shape
 from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import run_schedule
@@ -83,9 +83,6 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
   assert isinstance(ret, Tensor), "sum didn't return a Tensor"
   return ret
 
-def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
-  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
-
 def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
   # reduce such that if mask contains repeated indices the last one remains
   for dim in reversed(axes):
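
tensor.py's private copy of _broadcast_shape is gone; it now imports the shared helper from tinygrad.uop.ops. Tensor-level broadcasting should be observationally unchanged; a quick sanity check (not part of the diff):

from tinygrad import Tensor

x = Tensor.ones(4, 8)
y = Tensor.ones(4, 1)
assert (x + y).shape == (4, 8)  # same left-aligned broadcast as before
assert (x * 2).shape == (4, 8)  # python scalars still ufix against the tensor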


@@ -45,6 +45,11 @@ def _suop(lst, uop_fxn, python_fxn):
 def smax(*lst) -> sint: return _suop(argfix(*lst), UOp.maximum, max)
 def smin(*lst) -> sint: return _suop(argfix(*lst), UOp.minimum, min)
 def srender(x:sint) -> str: return x.render() if isinstance(x, UOp) else str(x)
+def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
+  max_dim = max(len(s) for s in shapes)
+  return tuple((1,)*(max_dim-len(s))+s for s in shapes)
+def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
+  return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
 def ssimplify(uop:sint): return uop.ssimplify() if isinstance(uop, UOp) else uop
 def sym_infer(uop: UOp|int, var_vals: dict[str, int]) -> int: return uop.sym_infer(var_vals) if isinstance(uop, UOp) else uop
@@ -473,9 +478,14 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     assert all(x.arg[-1] == AxisType.UPCAST for x in rngs), "all contract ranges must be upcast"
     return UOp(Ops.CONTRACT, dtype=self.dtype.vec(prod([x.vmax+1 for x in rngs])), src=(self,), arg=tuple((x.arg[0], x.vmax+1) for x in rngs))
   def alu(self, op, *src:UOp, **kwargs):
-    out_dtype = (self, *src)[-1].dtype
+    all_srcs = (self, *src)
+    # broadcast shaped operands to a common shape (None and () are falsy, so only real shapes participate)
+    if (shapes := [s for x in all_srcs if (s:=x._shape)]) and not all_same(shapes):
+      out_shape = _broadcast_shape(*shapes)
+      all_srcs = tuple(x._broadcast_to(out_shape) if x._shape else x for x in all_srcs)
+    out_dtype = all_srcs[-1].dtype
     if op in {Ops.CMPLT, Ops.CMPNE, Ops.CMPEQ}: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
-    return UOp(op, out_dtype, (self,)+src, **kwargs)
+    return UOp(op, out_dtype, all_srcs, **kwargs)
   @staticmethod
   def const(dtype:DType, b:ConstLike, device:str|tuple[str, ...]|None=None, shape:tuple[sint, ...]|None=None):
     if isinstance(b, UOp): return b.unbind()[0] if b.op is Ops.BIND else b
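
Reading the new alu body: operands that carry a shape are broadcast to the common shape first, then the result dtype is taken from the last (now broadcast) source, with comparisons still forced to bool. A sketch of the observable behavior, assuming the comparison operators route through alu as the CMPLT branch above indicates:

from tinygrad.uop.ops import UOp
from tinygrad.dtype import dtypes

a = UOp.const(dtypes.float, 1, shape=(4, 8))
b = UOp.const(dtypes.float, 2, shape=(1, 8))

lt = a < b                      # Ops.CMPLT via UOp.alu
assert lt.shape == (4, 8)       # shapes broadcast before the op is built
assert lt.dtype == dtypes.bool  # comparison result dtype is still bool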