diff --git a/docs/developer/developer.md b/docs/developer/developer.md index 14c0131ac6..2d4eb78241 100644 --- a/docs/developer/developer.md +++ b/docs/developer/developer.md @@ -25,7 +25,7 @@ The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedu The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with -::: tinygrad.engine.realize.run_schedule +::: tinygrad.engine.realize.run_linear There's a ton of complexity hidden behind this, see the `codegen/` directory. diff --git a/extra/gemm/triton_nv_matmul.py b/extra/gemm/triton_nv_matmul.py index 14be54a6ab..f6ee932641 100644 --- a/extra/gemm/triton_nv_matmul.py +++ b/extra/gemm/triton_nv_matmul.py @@ -73,8 +73,9 @@ if __name__ == "__main__": A, B = Tensor.normal(M, K, std=1e-1, dtype=dtypes.float16).realize(), Tensor.normal(K, N, std=1e-1, dtype=dtypes.float16).realize() C = A.matmul(B) - sched = C.schedule() - si = sched[-1] + from tinygrad.schedule import linear_to_schedule + linear, var_vals = C.linear_with_vars() + si = linear_to_schedule(linear)[-1] src = compiled.asm["ptx"] # specify the shared memory here so we don't need to do it dynamically @@ -97,10 +98,10 @@ if __name__ == "__main__": # check correctness if getenv("VERIFY"): - from tinygrad.engine.realize import run_schedule + from tinygrad.engine.realize import run_linear triton_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N) print(triton_buf) - run_schedule(sched) + run_linear(linear, var_vals) tinygrad_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N) print(tinygrad_buf) np.testing.assert_allclose(triton_buf, tinygrad_buf) diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 3dac0698cf..5816a7c5a9 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -191,34 +191,10 @@ class ExecItem: self.prg.first_run = False return et -# **************** main run function **************** +# **************** run linear **************** capturing: list = [] # put classes with an add_linear method in here -def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True): - while len(schedule): - ei = schedule.pop(0).lower() - sink = ei.ast - if VALIDATE_WITH_CPU and sink.op is Ops.SINK: - # copy in allocated buffers from the GPU - bufs = [b for b in ei.bufs if b is not None] - nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs] - for cpu_b, gpu_b in zip(nb, bufs): - if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_memoryview()) - - # run on GPU - ei.run(var_vals, do_update_stats=do_update_stats) - - # validate the output buffers match (NOTE: this is assuming the output is buffer 0) - ExecItem(sink, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats) - import numpy as np - assert nb[0] is not None - np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3) - else: - ei.run(var_vals, do_update_stats=do_update_stats) - -# **************** run linear **************** - @dataclass class ExecContext: var_vals: dict[str, int] = field(default_factory=dict)