mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
rm run_schedule (#15847)
This commit is contained in:
@@ -25,7 +25,7 @@ The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/schedu
|
||||
|
||||
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with
|
||||
|
||||
::: tinygrad.engine.realize.run_schedule
|
||||
::: tinygrad.engine.realize.run_linear
|
||||
|
||||
There's a ton of complexity hidden behind this, see the `codegen/` directory.
|
||||
|
||||
|
||||
@@ -73,8 +73,9 @@ if __name__ == "__main__":
|
||||
|
||||
A, B = Tensor.normal(M, K, std=1e-1, dtype=dtypes.float16).realize(), Tensor.normal(K, N, std=1e-1, dtype=dtypes.float16).realize()
|
||||
C = A.matmul(B)
|
||||
sched = C.schedule()
|
||||
si = sched[-1]
|
||||
from tinygrad.schedule import linear_to_schedule
|
||||
linear, var_vals = C.linear_with_vars()
|
||||
si = linear_to_schedule(linear)[-1]
|
||||
|
||||
src = compiled.asm["ptx"]
|
||||
# specify the shared memory here so we don't need to do it dynamically
|
||||
@@ -97,10 +98,10 @@ if __name__ == "__main__":
|
||||
|
||||
# check correctness
|
||||
if getenv("VERIFY"):
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.realize import run_linear
|
||||
triton_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N)
|
||||
print(triton_buf)
|
||||
run_schedule(sched)
|
||||
run_linear(linear, var_vals)
|
||||
tinygrad_buf = np.frombuffer(si.bufs[0].as_memoryview(), np.float16).reshape(M,N)
|
||||
print(tinygrad_buf)
|
||||
np.testing.assert_allclose(triton_buf, tinygrad_buf)
|
||||
|
||||
@@ -191,34 +191,10 @@ class ExecItem:
|
||||
self.prg.first_run = False
|
||||
return et
|
||||
|
||||
# **************** main run function ****************
|
||||
# **************** run linear ****************
|
||||
|
||||
capturing: list = [] # put classes with an add_linear method in here
|
||||
|
||||
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
|
||||
while len(schedule):
|
||||
ei = schedule.pop(0).lower()
|
||||
sink = ei.ast
|
||||
if VALIDATE_WITH_CPU and sink.op is Ops.SINK:
|
||||
# copy in allocated buffers from the GPU
|
||||
bufs = [b for b in ei.bufs if b is not None]
|
||||
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
|
||||
for cpu_b, gpu_b in zip(nb, bufs):
|
||||
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_memoryview())
|
||||
|
||||
# run on GPU
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
|
||||
ExecItem(sink, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
|
||||
import numpy as np
|
||||
assert nb[0] is not None
|
||||
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
|
||||
else:
|
||||
ei.run(var_vals, do_update_stats=do_update_stats)
|
||||
|
||||
# **************** run linear ****************
|
||||
|
||||
@dataclass
|
||||
class ExecContext:
|
||||
var_vals: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
Reference in New Issue
Block a user