@@ -26,9 +26,9 @@ Transforms the ast into an optimized ast. This is where BEAM search and heuristi
 
 ## tinygrad/codegen
 
-Transform the optimized ast into a linearized list of UOps.
+Transform the optimized ast into a linearized and rendered program.
 
-::: tinygrad.codegen.full_rewrite
+::: tinygrad.codegen.full_rewrite_to_program
     options:
       members: false
       show_labels: false
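The doc change above matches the API shift applied throughout this diff: tinygrad/codegen now goes all the way to a rendered program instead of stopping at a linearized UOp list. A minimal sketch of the new call pattern, assuming the get_program API used in the later hunks; pulling the kernel ast out of Tensor.schedule() is an illustrative assumption, not something this diff itself does:

# sketch: get a kernel SINK ast from a schedule, then build a rendered program from it
# (get_program, .src and .uops mirror the updated call sites in this diff)
from tinygrad import Tensor, Device
from tinygrad.uop.ops import Ops
from tinygrad.engine.realize import get_program

sink = next(s.ast for s in Tensor.arange(16).sum().schedule() if s.ast.op is Ops.SINK)
prg = get_program(sink, Device.default.renderer)  # linearize + render in one step
print(prg.src)        # rendered kernel source
print(len(prg.uops))  # the linearized UOps are still available on the program spec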
@@ -1 +1 @@
-{"$schema": "https://opencode.ai/config.json", "formatter": false}
+{"$schema": "https://opencode.ai/config.json", "formatter": false, "lsp": false}
test/external/external_benchmark_op_conv.py (vendored, 14 changed lines)
@@ -1,10 +1,9 @@
 # ruff: noqa: E501 E712 F401
+from dataclasses import replace
 from tinygrad import dtypes, Device
 from tinygrad.uop.ops import UOp, AxisType, Ops, KernelInfo
-from tinygrad.codegen import full_rewrite
 from tinygrad.codegen.opt import Opt, OptOps # pylint: disable=unused-import
-from tinygrad.renderer import ProgramSpec
-from tinygrad.engine.realize import CompiledRunner
+from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.dtype import ImageDType, Invalid
@@ -86,16 +85,11 @@ def dm_conv_172():
 
 ast = {143: vision_conv_143, 153: vision_conv_153, 172: dm_conv_172}[getenv("NUM", 143)]()
 
-compiler = Device.default.compiler
 renderer = Device.default.renderer
 allocator = Device.default.allocator
 
-uops = full_rewrite(ast, renderer)
-src = renderer.render(uops)
-
-lib = compiler.compile(src)
-ps = ProgramSpec("conv", src, Device.DEFAULT, ast, uops)
-cr = CompiledRunner(ps, precompiled=lib)
+ps = get_program(ast, renderer)
+cr = CompiledRunner(replace(ps, device=Device.DEFAULT))
 
 gs = sorted(dedup([u for u in ast.toposort() if u.op is Ops.DEFINE_GLOBAL]), key=lambda u: u.arg)
 # print(len(gs))
@@ -2,18 +2,27 @@ import os, time, struct, functools, unittest
 from typing import Any, Callable
 import numpy as np
 from tinygrad import Tensor, dtypes, Device
-from tinygrad.uop.ops import UOp, Ops
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.engine.realize import Runner
+from tinygrad.engine.realize import Runner, get_program
 from tinygrad.dtype import DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import T, CI
-from tinygrad.codegen import full_rewrite
+from tinygrad.renderer import Renderer
+from tinygrad.codegen import full_rewrite_to_sink, line_rewrite, pm_linearize_cleanups
+from tinygrad.codegen.late.linearizer import linearize
 
 # decorator to skip slow tests by default, run with RUN_SLOW=1 to include them
 slow = unittest.skipUnless(os.getenv("RUN_SLOW"), "slow test, set RUN_SLOW=1 to run")
 from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler
 
+def get_uops(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
+  """Extract linearized UOps from a sink. Test helper that only does linearization (no render)."""
+  if ren is None: ren = Renderer()
+  if sink.arg is None: sink = sink.replace(arg=KernelInfo())
+  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
+  return line_rewrite(linearize(full_sink), pm_linearize_cleanups)
+
 def derandomize_model(model):
   for p in get_parameters(model):
     p.replace(Tensor.empty(p.shape, device=p.device, dtype=p.dtype))
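The new get_uops helper above is the linearize-only replacement for tests that previously called full_rewrite directly; later hunks route to_uops_list and the symbolic tests through it. A small usage sketch, assuming the default Renderer fallback built into the helper; the stored constant is only illustrative:

# sketch: linearize a tiny STORE kernel without rendering, mirroring how
# test_symbolic calls get_uops further down in this diff
from tinygrad import dtypes
from tinygrad.uop.ops import UOp, Ops
from test.helpers import get_uops

glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
sink = UOp.store(glbl.index(UOp.const(dtypes.int, 0)), UOp.const(dtypes.int, 42)).sink()
uops = get_uops(sink)  # no renderer passed, so the helper falls back to the base Renderer
assert any(u.op is Ops.STORE for u in uops)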
@@ -51,13 +60,12 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
   bufs = []
   for buf_dt, data in inputs or []:
     bufs.append(buf:=allocator.alloc(len(data) * buf_dt.itemsize))
-    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + buf_dt.fmt, *data)))
+    allocator._copyin(buf, memoryview(struct.pack(str(len(data)) + (buf_dt.fmt or ""), *data)))
   g = UOp(Ops.DEFINE_GLOBAL, uop.dtype.ptr(), arg=0, src=())
-  opts = PythonRenderer()
-  lst = full_rewrite(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), opts)
-  prog = PythonProgram("run", PythonCompiler().compile(opts.render(lst)))
+  prg = get_program(UOp.store(g.index(UOp.const(dtypes.int, 0)), uop).sink(), PythonRenderer())
+  prog = PythonProgram("run", PythonCompiler().compile(prg.src))
   prog(out_buf:=allocator.alloc(uop.dtype.itemsize), *bufs)
-  return out_buf.cast(uop.dtype.fmt).tolist()[0]
+  return out_buf.cast(uop.dtype.fmt or "").tolist()[0]
 
 def not_support_multi_device():
   # CL and CUDA don't support multi device if in CI
@@ -1,30 +1,26 @@
 import unittest
 import numpy as np
+from dataclasses import replace
 from tinygrad.device import Buffer, Device, is_dtype_supported
 from tinygrad.dtype import dtypes, ConstType
-from tinygrad.engine.realize import CompiledRunner
-from tinygrad.helpers import dedup, flatten, prod
+from tinygrad.engine.realize import CompiledRunner, get_program
+from tinygrad.helpers import prod
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.wgsl import WGSLRenderer
 from tinygrad.runtime.ops_python import PythonRenderer
 from tinygrad.uop.ops import UOp, Ops, python_alu
-from tinygrad.renderer import ProgramSpec
 from tinygrad.tensor import Tensor, _to_np_dtype
-from tinygrad.codegen import full_rewrite
 
-def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
+def _test_uop_result(inputs:list[Tensor], prg, local_size=None):
   for x in inputs: x.realize()
-  # NOTE: we only toposort the stores
-  uops: list[UOp] = []
-  def _recursive_add(uop:UOp) -> list[UOp]: return flatten([_recursive_add(x) for x in uop.src])+[uop]
-  uops = dedup(flatten(_recursive_add(st) for st in stores))
+  uops = prg.uops
   outbufs = [Buffer(Device.DEFAULT, sz:=(1 if local_size is None else prod(local_size)), (dtype:=u.src[1].dtype), \
     initial_value=np.zeros(sz, dtype=_to_np_dtype(dtype)).data) for u in uops if u.op is Ops.STORE]
   inbufs = [x.uop.base.buffer for x in inputs]
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  ei = CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test",
-                                  src, Device.DEFAULT, uops[-1], uops=uops, local_size=local_size))
+  prg = replace(prg, device=Device.DEFAULT)
+  if local_size is not None: prg = replace(prg, local_size=local_size)
+  ei = CompiledRunner(prg)
   ei.exec(outbufs+inbufs)
   return [np.frombuffer(x.as_buffer(), _to_np_dtype(x.dtype)) for x in outbufs]
 
@@ -37,8 +33,8 @@ def _setup_and_test_alu(alu_op:Ops, input_val:ConstType, *alu_src_uops:UOp):
   alu = ld.alu(alu_op, *alu_src_uops)
   store = UOp.store(a.index(idx), alu)
   sink = UOp(Ops.SINK, dtypes.void, (store,))
-  uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-  return _test_uop_result([Tensor([input_val])], uops)[0]
+  prg = get_program(sink, Device[Device.DEFAULT].renderer)
+  return _test_uop_result([Tensor([input_val])], prg)[0]
 
 class TestRendererFailures(unittest.TestCase):
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -47,8 +43,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu = (lidx0:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 4),), 'lidx0')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0.valid(gate_alu)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
 
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, PythonRenderer)), "test is for ptx or python renderer")
@@ -58,8 +54,8 @@ class TestRendererFailures(unittest.TestCase):
     gate_alu_1 = (lidx1:=UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 2),), 'lidx1')).ne(0)
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index((lidx0+lidx1*4).valid(gate_alu_0&gate_alu_1)), UOp.const(dtypes.int, 1)))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 2, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 2, 1])[0]
     np.testing.assert_equal(ret, [0, 0, 0, 0, 0, 1, 1, 1])
 
   @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
@@ -104,8 +100,8 @@ class TestPTXFailures(unittest.TestCase):
     if_uop = UOp(Ops.IF, dtypes.void, (gate_alu,))
     gated_alu_store = UOp(Ops.STORE, dtypes.void, (a.index(lidx0, if_uop), val))
     sink = UOp(Ops.SINK, dtypes.void, (gated_alu_store,))
-    uops = full_rewrite(sink, Device[Device.DEFAULT].renderer)
-    ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
+    prg = get_program(sink, Device[Device.DEFAULT].renderer)
+    ret = _test_uop_result([], prg, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
@@ -10,7 +10,7 @@ from tinygrad.device import is_dtype_supported
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.renderer.ptx import PTXRenderer
 from tinygrad.renderer.nir import NIRRenderer
-from tinygrad.codegen import full_rewrite
+from tinygrad.engine.realize import get_program
 from tinygrad.dtype import DType
 
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -869,9 +869,8 @@ class TestIdxUpcast(unittest.TestCase):
     for s in schedule:
       if s.ast.op is Ops.SINK:
         renderer = Device[s.bufs[0].device].renderer
-        uops = full_rewrite(s.ast, renderer)
-        renderer.render(uops)
-        return uops
+        prg = get_program(s.ast, renderer)
+        return prg.uops
 
   def _assert(self, dtype: DType, a: Tensor):
     uops = self._schedule_render(a)
@@ -7,30 +7,27 @@ from tinygrad.dtype import dtypes, DType, AddrSpace
 from tinygrad.device import Buffer, Device
 from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
 from tinygrad.uop.spec import shared_spec
-from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
 from tinygrad.engine.schedule import ExecItem
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.symbolic import sym
 from tinygrad.device import is_dtype_supported
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.renderer.ptx import PTXRenderer
+from test.helpers import get_uops
+from dataclasses import replace
 
 def to_uops_list(u:list[UOp], ren=None) -> list[UOp]:
   sink = UOp.group(*u)
   for r in sink.ranges: sink = sink.end(r)
   # we strip the SINK here for legacy reasons
-  ret = full_rewrite(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
+  ret = get_uops(sink.sink(arg=KernelInfo(opts_to_apply=())), ren)
   assert ret[-1].op is Ops.SINK
   return ret[:-1]
 
 def _uops_to_prg(uops_list):
-  uops = full_rewrite(ast:=UOp.sink(*uops_list), ren=Device[Device.DEFAULT].renderer)
-  src = Device[Device.DEFAULT].renderer.render(uops)
-  has_local = Device[Device.DEFAULT].renderer.has_local
-  return CompiledRunner(ProgramSpec(uops[-1].arg.name if uops[-1].arg is not None else "test", src, Device.DEFAULT, ast, uops=uops,
-                                    global_size=[1,1,1] if has_local else None, local_size=[1,1,1] if has_local else None))
+  prg = get_program(UOp.sink(*uops_list), Device[Device.DEFAULT].renderer)
+  return CompiledRunner(replace(prg, device=Device.DEFAULT))
 
 def uop(uops:list[UOp], uop:Ops, dtype:Optional[DType], src:tuple[UOp, ...], arg:Any=None) -> UOp:
   uops.append(UOp(uop, dtype, tuple(src), arg))
@@ -4,7 +4,6 @@ from tinygrad.helpers import getenv, GlobalCounters, EMULATE
 from tinygrad.engine.realize import get_program
 from tinygrad.renderer import ProgramSpec
 from tinygrad.renderer import Estimates
-from tinygrad.codegen import full_rewrite
 from tinygrad.uop.ops import Ops, UOp
 from tinygrad.dtype import dtypes
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
@@ -146,7 +145,7 @@ class TestUOpsStats(unittest.TestCase):
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MUL, dtypes.int, (u1,u2))
     u5 = UOp(Ops.ADD, dtypes.int, (u4,u3))
-    uops = full_rewrite(u5.sink())
+    uops = list(u5.toposort())
 
     globl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
     o1 = UOp(Ops.CONST, dtypes.int, tuple(), 1)
@@ -155,7 +154,7 @@ class TestUOpsStats(unittest.TestCase):
     u2 = globl.index(o2)
     u3 = UOp(Ops.CONST, dtypes.int, tuple(), 3)
     u4 = UOp(Ops.MULACC, dtypes.int, (u1,u2,u3))
-    uops_fma = full_rewrite(u4.sink())
+    uops_fma = list(u4.toposort())
 
     self.assertEqual(flops_mem(uops), flops_mem(uops_fma))
 
@@ -3,8 +3,8 @@ import unittest, pickle, functools, math
 import z3
 
 from tinygrad.dtype import dtypes, ConstType, DType, Invalid
-from tinygrad.codegen import full_rewrite
 from tinygrad.helpers import Context
+from test.helpers import get_uops
 from tinygrad.uop.ops import UOp, Ops, graph_rewrite, sym_infer
 from tinygrad.uop.symbolic import sym, commutative, pm_simplify_valid
 from tinygrad.uop.validate import uops_to_z3
@@ -747,7 +747,7 @@ class TestSymbolic(unittest.TestCase):
 
     # TODO: copied from render, render does not support cast
     glbl = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
-    uops = full_rewrite(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
+    uops = get_uops(UOp(Ops.STORE, dtypes.void, (glbl.index(UOp.const(dtypes.int, 0)), expr)).sink())
     rewritten_uop = [uop for uop in uops if uop.op is Ops.STORE][0].src[1]
 
     self.assertEqual(rewritten_uop, cond.where(a.cast(dtypes.half), b.cast(dtypes.half)))
@@ -141,21 +141,3 @@ def full_rewrite_to_program(sink:UOp, ren:Renderer) -> UOp:
   full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
   sink = UOp(Ops.PROGRAM, src=(full_sink,))
   return graph_rewrite(sink, pm_to_program, ctx=ren, name="linearize/render")
-
-def full_rewrite(sink:UOp, ren:Renderer|None=None) -> list[UOp]:
-  """
-  Function to transform the Kernel UOp graph into a linearized program.
-
-  Args:
-    sink: The Ops.SINK rooting the Kernel graph.
-    ren: The Renderer (can change how things are processed, fix this).
-
-  Returns:
-    Linear program in UOps.
-  """
-
-  full_sink = full_rewrite_to_sink(sink, ren, optimize=sink.tag is None)
-  assert len(full_sink.ranges) == 0, f"all ranges must end by the sink, {full_sink.ranges}"
-  lst = line_rewrite(linearize(full_sink), pm_linearize_cleanups)
-  if SPEC: type_verify(lst, program_spec)
-  return lst
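With full_rewrite deleted here, the call sites across this diff reduce to two forms: get_program when a rendered, runnable program is wanted, or the get_uops test helper when only the linearized UOps matter. A before/after sketch assembled from lines in this diff; obtaining the sink from a schedule is an illustrative assumption, as above:

from dataclasses import replace
from tinygrad import Tensor, Device
from tinygrad.uop.ops import Ops
from tinygrad.engine.realize import CompiledRunner, get_program

renderer = Device[Device.DEFAULT].renderer
sink = next(s.ast for s in Tensor.ones(8).contiguous().schedule() if s.ast.op is Ops.SINK)

# before (removed above): uops = full_rewrite(sink, renderer); src = renderer.render(uops)
# after: one call returns a program spec carrying both the rendered source and the UOps
prg = get_program(sink, renderer)
runner = CompiledRunner(replace(prg, device=Device.DEFAULT))  # packaging as in the updated tests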