mirror of https://github.com/tinygrad/tinygrad.git
more uop programs (#13007)

* more uop program
* test_matmul_relu
* tests fix

@@ -569,11 +569,21 @@ class TestUOpPrograms(unittest.TestCase):
   def _run(self, prog:UOp, *tensors:Tensor):
     ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True)

   def test_simple(self):
     out = Tensor.empty(10,10,dtype=dtypes.int)

     ptr = UOp.placeholder(out.dtype, out.shape, slot=0)
     i, j = UOp.range(10, axis_id=0), UOp.range(10, axis_id=1)
     prog = ptr[i,j].set(42).end(i,j)
     self._run(prog.sink(), out)

     with Context(DEBUG=0): self.assertTrue((out == 42).all().item())

   def test_matmul(self):
-    a = Tensor.rand(10,10)
-    b = Tensor.rand(10,10)
+    a = Tensor.randn(10,10)
+    b = Tensor.randn(10,10)
     c = Tensor.empty(10,10)
-    ref = a@b
+    ref = (a@b)
     with Context(DEBUG=0): Tensor.realize(a, b, c, ref)

     # C[i,j] = sum_k A[i,k] * B[k,j]

@@ -581,16 +591,16 @@ class TestUOpPrograms(unittest.TestCase):
     M = N = K = 10
     DT = dtypes.float32

-    # Axes: i,j are spatial; k is a reduction axis over the shared dim K
-    i = UOp.range(M, axis_id=0)  # rows of A/C
-    j = UOp.range(N, axis_id=1)  # cols of B/C
-    k = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE)  # reduction over K
-
     # Placeholders (bind slots explicitly)
     A = UOp.placeholder(DT, (M, K), slot=0)
     B = UOp.placeholder(DT, (K, N), slot=1)
     C = UOp.placeholder(DT, (M, N), slot=2)

+    # Axes: i,j are spatial; k is a reduction axis over the shared dim K
+    i = UOp.range(M, axis_id=0)  # rows of A/C
+    j = UOp.range(N, axis_id=1)  # cols of B/C
+    k = UOp.range(K, axis_id=2, axis_type=AxisType.REDUCE)  # reduction over K
+
     # Zero-init: write a scalar 0 to each (i,j).
     C = C[i, j].set(0.0)

@@ -601,9 +611,27 @@ class TestUOpPrograms(unittest.TestCase):
     prog = C.end(i, j, k)

     # run program
     # TODO: make this work with opts_to_apply
     self._run(prog.sink(arg=KernelInfo(opts_to_apply=())), a, b, c)

     with Context(DEBUG=0): self.assertLessEqual((c-ref).square().mean().item(), 1e-6)

+  def test_matmul_relu(self):
+    a, b, c = Tensor.randn(10,10), Tensor.randn(10,10), Tensor.empty(10,10)
+    ref = (a@b).relu()
+    with Context(DEBUG=0): Tensor.realize(a, b, c, ref)
+
+    A, B, C = a.uop.placeholder_like(0), b.uop.placeholder_like(1), c.uop.placeholder_like(2)
+    i, j, k = UOp.range(10, 0), UOp.range(10, 1), UOp.range(10, 2, axis_type=AxisType.REDUCE)
+
+    C = C[i, j].set(0.0)
+    C = C[i, j].set(C.after(k)[i, j] + A[i, k] * B[k, j], end=k)
+    C = C[i, j].set(C[i, j].maximum(0.0))
+
+    prog = C.end(i, j)
+
+    self._run(prog.sink(arg=KernelInfo(opts_to_apply=())), a, b, c)
+    with Context(DEBUG=0): self.assertLessEqual((c-ref).square().mean().item(), 1e-6)

 if __name__ == '__main__':
   unittest.main(verbosity=2)

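For reference, a plain-NumPy sketch of what test_matmul_relu's UOp program above is expected to compute — illustrative only, not part of the diff; the comments map each line back to the ranges and set() calls in the test:

import numpy as np

a = np.random.randn(10, 10).astype(np.float32)
b = np.random.randn(10, 10).astype(np.float32)
c = np.zeros((10, 10), dtype=np.float32)  # C[i,j].set(0.0)
for i in range(10):                       # UOp.range(10, 0)
  for j in range(10):                     # UOp.range(10, 1)
    for k in range(10):                   # REDUCE range, closed by end=k
      c[i, j] += a[i, k] * b[k, j]        # C.after(k)[i,j] + A[i,k]*B[k,j]
    c[i, j] = max(c[i, j], 0.0)           # C[i,j].maximum(0.0)
assert np.square(c - np.maximum(a @ b, 0)).mean() <= 1e-6
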
@@ -18,13 +18,18 @@ from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_s
 from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops
 from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize

+pm_preprocess = PatternMatcher([
+  (UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
+  (UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
+])
+
 def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
   if ren is None: ren = Renderer()

   if SPEC: type_verify(sink, kernel_spec)

   # preprocess
-  sink = graph_rewrite(sink, pm_mops, name="early movement ops")
+  sink = graph_rewrite(sink, pm_preprocess+pm_mops, name="early movement ops")

   # first we optimize
   if optimize:

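The new pm_preprocess (like the matching broadcast rules added to the expander further down) hoists a RESHAPE out from under an AFTER/END wrapper so later passes see the movement op on top, with the wrapper's extra dependencies preserved. A toy sketch of that rewrite on a stand-in node class — the Node type and hoist_reshape function here are illustrative, not tinygrad's UOp/UPat API:

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()
  arg: object = None

def hoist_reshape(a: Node) -> Node | None:
  # AFTER(RESHAPE(x, shape), *deps) -> RESHAPE(AFTER(x, *deps), shape)
  if a.op == "AFTER" and a.src and a.src[0].op == "RESHAPE":
    r = a.src[0]
    return Node("RESHAPE", (Node("AFTER", (r.src[0],) + a.src[1:]),), r.arg)
  return None

x, dep = Node("BUFFER"), Node("STORE")
g = Node("AFTER", (Node("RESHAPE", (x,), (10, 10)), dep))
print(hoist_reshape(g))  # RESHAPE now sits above the AFTER, deps preserved
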
@@ -80,7 +80,7 @@ def add_gpudims(ctx:Renderer, s:UOp):
   subs = {}
   for r in s_topo:
     # look for local INDEXes that are not used in the GLOBAL store, then add them as an INVALID
-    if r.op is Ops.STORE and r.src[0].src[0].ptrdtype.addrspace == AddrSpace.GLOBAL:
+    if r.op is Ops.STORE and r.buf_target().ptrdtype.addrspace == AddrSpace.GLOBAL:
       idx = r.src[0]
       missing_locals = [all_ranges[rng] for rng in local_dims if all_ranges[rng] not in idx.ranges]
       if len(missing_locals):

@@ -76,6 +76,9 @@ def do_contract(con:UOp):
   return UOp(Ops.UNROLL, con.dtype, (ex.src[0].gep(tuple(idxs)),), new_ex_args)

 expander = PatternMatcher([
+  # push broadcast through AFTER
+  (UPat.var("x").broadcast(name="b").after(name="a", allow_any_len=True), lambda x,b,a: x.after(*a.src[1:]).broadcast(len(b.src))),
+  (UPat.var("x").broadcast(name="b").end(name="a", allow_any_len=True), lambda x,b,a: x.end(*a.src[1:]).broadcast(len(b.src))),
   # BUFFERIZE puts UNROLLs for ranges as contract
   (UPat(Ops.BUFFERIZE, src=(UPat(Ops.UNROLL), UPat(Ops.UNROLL)), name="x"),
    lambda x: x.replace(src=tuple(UOp(Ops.CONTRACT, dtype=s.dtype.vec(x.src[1].src[0].dtype.count), src=(s,), arg=x.src[1].arg) for s in x.src))),

@@ -18,8 +18,6 @@ sys.setrecursionlimit(10000)
 pm_mops = PatternMatcher([
   (UPat(GroupOp.Movement, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"),
    lambda r,idx: r.src[0].index(*apply_movement_op(r.op, r.src[0].shape, r.marg, idx.src[1:]), dtype=idx.dtype, arg=idx.arg)), # type: ignore
-  (UPat(Ops.RESHAPE, name="r").after(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:]).reshape(r.shape)),
-  (UPat(Ops.RESHAPE, name="r").end(name="a", allow_any_len=True), lambda r,a: a.replace(src=(r.src[0],)+a.src[1:])),
 ])

 # *****************

@@ -578,6 +578,16 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     while len(s.src) and s.op not in {Ops.BUFFER, Ops.BUFFERIZE, Ops.MSTACK}: s = s.src[0]
     return s

+  def buf_target(self) -> UOp:
+    # the buffer that's being loaded from or stored to
+    match self.op:
+      case Ops.DEFINE_GLOBAL | Ops.DEFINE_LOCAL | Ops.DEFINE_REG: return self
+      case Ops.AFTER | Ops.INDEX | Ops.STORE | Ops.LOAD: return self.src[0].buf_target()
+      case Ops.VECTORIZE:
+        assert all_same(self.src)
+        return self.src[0].buf_target()
+      case _: raise RuntimeError(f"buf_target called on non load/index/store {self.op}")
+
   @property
   def buffer(self) -> Buffer|MultiBuffer:
     from tinygrad.device import Buffer, MultiBuffer

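buf_target chases src[0] through wrapper ops until it reaches the underlying DEFINE_* buffer, which is why the add_gpudims check earlier in this diff can drop the brittle r.src[0].src[0] double-dereference. A minimal standalone sketch of the same traversal — toy Node class, not the real UOp:

from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple = ()

def buf_target(n: Node) -> Node:
  if n.op in {"DEFINE_GLOBAL", "DEFINE_LOCAL", "DEFINE_REG"}: return n
  if n.op in {"AFTER", "INDEX", "STORE", "LOAD"}: return buf_target(n.src[0])
  raise RuntimeError(f"buf_target called on non load/index/store {n.op}")

buf = Node("DEFINE_GLOBAL")
store = Node("STORE", (Node("INDEX", (buf,)),))
assert buf_target(store) is buf  # STORE -> INDEX -> DEFINE_GLOBAL
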
@@ -750,10 +760,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     ret = UOp(Ops.DEFINE_GLOBAL, dtype.ptr(prod(shape)), arg=slot)
     if len(shape) > 1: ret = ret.reshape(shape)
     return ret
+  def placeholder_like(self, slot:int):
+    assert all_int(self.shape), "no placeholder-like on symbolic shape"
+    return UOp.placeholder(self.dtype, self.shape, slot)

-  # set is store+after
-  def set(self:UOp, val:UOp|ConstType):
-    return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val))
+  # set is store+end+after
+  def set(self:UOp, val:UOp|ConstType, end:UOp|tuple[UOp, ...]=()) -> UOp:
+    return self.src[0].after(self.store(UOp.const(self.dtype, val) if not isinstance(val, UOp) else val).end(*argfix(end)))

 @dataclass(frozen=True)
 class KernelInfo:

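The new end parameter of set() accepts either a single range or a tuple of ranges because it is splatted through argfix. A sketch of that unwrapping behavior, assuming argfix works as in tinygrad's helpers (re-implemented here so the snippet runs standalone):

def argfix(*x):
  # unwrap a single tuple/list argument into a flat tuple
  return tuple(x[0]) if x and isinstance(x[0], (tuple, list)) else x

assert argfix(()) == ()                  # set(val) -> no ranges ended
assert argfix("k") == ("k",)             # set(val, end=k)
assert argfix(("i", "j")) == ("i", "j")  # set(val, end=(i, j))
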
@@ -502,7 +502,7 @@ pm_simplify_valid = PatternMatcher([
 ])

 # this is symbolic 2.0
-REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP}
+REMOVE_FROM_SINK_LIKE = {Ops.UNROLL, Ops.NOOP, Ops.VECTORIZE, Ops.SINK}
 sym = symbolic_flat+pm_simplify_valid+PatternMatcher([
   # LOAD/STORE -> NOOP
   (UPat.var('x').store(UPat.var('x').load(), allow_any_len=True), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),

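The LOAD/STORE -> NOOP rule above drops a STORE whose value is just a LOAD of the same location, since it writes nothing new (the real pattern special-cases register address space, where it instead returns the store target). A toy version of the same simplification on tuple-encoded ops — illustrative only:

def simplify_store(dst, val):
  # store(x, load(x)) writes back what is already there: drop it
  if val == ("LOAD", dst): return None
  return ("STORE", dst, val)

assert simplify_store(("buf", 0), ("LOAD", ("buf", 0))) is None
assert simplify_store(("buf", 0), ("CONST", 42)) == ("STORE", ("buf", 0), ("CONST", 42))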