more uop programs (#13007)

* more uop program

* test_matmul_relu

* tests fix
This commit is contained in:
George Hotz
2025-10-30 14:57:59 +08:00
committed by GitHub
parent c18b283f58
commit e456f2cb1e
7 changed files with 63 additions and 16 deletions

View File

@@ -569,11 +569,21 @@ class TestUOpPrograms(unittest.TestCase):
def _run(self, prog:UOp, *tensors:Tensor):
ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True)
def test_simple(self):
out = Tensor.empty(10,10,dtype=dtypes.int)
ptr = UOp.placeholder(out.dtype, out.shape, slot=0)
i, j = UOp.range(10, axis_id=0), UOp.range(10, axis_id=1)
prog = ptr[i,j].set(42).end(i,j)
self._run(prog.sink(), out)
with Context(DEBUG=0): self.assertTrue((out == 42).all().item())
def test_matmul(self):
a = Tensor.rand(10,10)
b = Tensor.rand(10,10)
a = Tensor.randn(10,10)
b = Tensor.randn(10,10)
c = Tensor.empty(10,10)
ref = a@b
ref = (a@b)
with Context(DEBUG=0): Tensor.realize(a, b, c, ref)
# C[i,j] = sum_k A[i,k] * B[k,j]
@@ -581,16 +591,16 @@ class TestUOpPrograms(unittest.TestCase):
M = N = K = 10
DT = dtypes.float32
# Axes: i,j are spatial; k is a reduction axis over the shared dim K
i = UOp.range(M, axis_id=0) # rows of A/C
j = UOp.range(N, axis_id=1) # cols of B/C
k = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE) # reduction over K
# Placeholders (bind slots explicitly)
A = UOp.placeholder(DT, (M, K), slot=0)
B = UOp.placeholder(DT, (K, N), slot=1)
C = UOp.placeholder(DT, (M, N), slot=2)
# Axes: i,j are spatial; k is a reduction axis over the shared dim K
i = UOp.range(M, axis_id=0) # rows of A/C
j = UOp.range(N, axis_id=1) # cols of B/C
k = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE) # reduction over K
# Zero-init: write a scalar 0 to each (i,j).
C = C[i, j].set(0.0)
@@ -601,9 +611,27 @@ class TestUOpPrograms(unittest.TestCase):
prog = C.end(i, j, k)
# run program
# TODO: make this work with opts_to_apply
self._run(prog.sink(arg=KernelInfo(opts_to_apply=())), a, b, c)
with Context(DEBUG=0): self.assertLessEqual((c-ref).square().mean().item(), 1e-6)
def test_matmul_relu(self):
a, b, c = Tensor.randn(10,10), Tensor.randn(10,10), Tensor.empty(10,10)
ref = (a@b).relu()
with Context(DEBUG=0): Tensor.realize(a, b, c, ref)
A, B, C = a.uop.placeholder_like(0), b.uop.placeholder_like(1), c.uop.placeholder_like(2)
i, j, k = UOp.range(10, 0), UOp.range(10, 1), UOp.range(10, 2, axis_type=AxisType.REDUCE)
C = C[i, j].set(0.0)
C = C[i, j].set(C.after(k)[i, j] + A[i, k] * B[k, j], end=k)
C = C[i, j].set(C[i, j].maximum(0.0))
prog = C.end(i, j)
self._run(prog.sink(arg=KernelInfo(opts_to_apply=())), a, b, c)
with Context(DEBUG=0): self.assertLessEqual((c-ref).square().mean().item(), 1e-6)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -18,13 +18,18 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
pm_preprocess = PatternMatcher([
(UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
(UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
if SPEC: type_verify(sink, kernel_spec)
# preprocess
sink = graph_rewrite(sink, pm_mops, name="early movement ops")
sink = graph_rewrite(sink, pm_preprocess+pm_mops, name="early movement ops")
# first we optimize
if optimize:

View File

@@ -80,7 +80,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
subs = {}
for r in s_topo:
# look for local INDEXes that are not used in the GLOBAL store, then add them as an INVALID
if r.op is Ops.STORE and r.src[0].src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
if r.op is Ops.STORE and r.buf_target().ptrdtype.addrspace == AddrSpace.GLOBAL:
idx = r.src[0]
missing_locals = [all_ranges[rng] for rng in local_dims if all_ranges[rng] not in idx.ranges]
if len(missing_locals):

View File

@@ -76,6 +76,9 @@ def do_contract(con:UOp):
return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)
expander = PatternMatcher([
# push broadcast through AFTER
(UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
(UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
# BUFFERIZE puts UNROLLs for ranges as contract
(UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))),

View File

@@ -18,8 +18,6 @@ sys.setrecursionlimit(10000)
pm_mops = PatternMatcher([
(UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
(UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
(UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
])
# *****************

View File

@@ -578,6 +578,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
return s
def buf_target(self) -> UOp:
# the buffer that's being loaded from or store to
match self.op:
case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
case Ops.VECTORIZE:
assert all_same(self.src)
return self.src[0].buf_target()
case _: raise RuntimeError(f"buf_target called on non load/index/store {self.op}")
@property
def buffer(self) -> Buffer|MultiBuffer:
from tinygrad.device import Buffer, MultiBuffer
@@ -750,10 +760,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
if len(shape) > 1: ret = ret.reshape(shape)
return ret
def placeholder_like(self, slot:int):
assert all_int(self.shape), "no placeholder-like on symbolic shape"
return UOp.placeholder(self.dtype, self.shape, slot)
# set is store+after
def set(self:UOp, val:UOp|ConstType):
return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val))
# set is store+end+after
def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp:
return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val).end(*argfix(end)))
@dataclass(frozen=True)
class KernelInfo:

View File

@@ -502,7 +502,7 @@ pm_simplify_valid = PatternMatcher([
])
# this is symbolic 2.0
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP}
REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
# LOAD/STORE -> NOOP
(UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),