Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-06 21:53:53 -05:00
remove ScheduleItem and merge it with ExecItem (#13759)
* remove ExecItem and merge it with ScheduleItem
* less diff
* fix issues
* min diff
* don't change bufs in _lower
* min diff
* update
* revert
* fixes
* diff
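In rough terms, the change looks like this (an illustrative sketch inferred from the hunks below, not an exact excerpt; `tensor` stands for any lazy Tensor):

```python
# before: the scheduler produced ScheduleItems (ast + bufs); lower_schedule_item
# turned each one into a separate ExecItem (prg + bufs) that could then be run:
#   ei = lower_schedule_item(si); ei.run()

# after: the scheduler emits ExecItems directly; each carries its ast and
# lowers itself on demand, filling in .prg
for ei in tensor.schedule():   # schedule() now returns list[ExecItem]
    ei.lower()                 # optional: run() lowers lazily when prg is None
    ei.run()
```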
@@ -88,7 +88,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
 
 ## Debugging Tips
 
 1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
-2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
+2. **Check schedule**: `tensor.schedule()` returns list of ExecItems
 3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
 4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
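To make tip 2 concrete, a minimal sketch of inspecting the schedule by hand, assuming the post-change API shown in this diff:

```python
from tinygrad import Tensor

t = Tensor.ones(10).sum()
for ei in t.schedule():   # each ExecItem is one kernel (or copy/view)
    ei.lower()            # populates ei.prg with a Runner
    print(ei.ast.op, [b.size for b in ei.bufs if b is not None])
    ei.run()
```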
@@ -38,25 +38,19 @@ optim.schedule_step() # this will step the optimizer without running realize
 # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
 # l1.uop and l2.uop define a computation graph
 
-from tinygrad.engine.schedule import ScheduleItem
-schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
+from tinygrad.engine.schedule import ExecItem
+schedule: List[ExecItem] = Tensor.schedule(l1, l2)
 
 print(f"The schedule contains {len(schedule)} items.")
 for si in schedule: print(str(si)[:80])
 
 # *****
-# 4. Lower a schedule.
+# 4. Lower and run the schedule.
 
-from tinygrad.engine.realize import lower_schedule_item, ExecItem
-lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)]
+for si in tqdm(schedule): si.run()
 
-# *****
-# 5. Run the schedule
-
-for ei in tqdm(lowered): ei.run()
-
 # *****
-# 6. Print the weight change
+# 5. Print the weight change
 
 print("first weight change\n", l1.numpy()-l1n)
 print("second weight change\n", l2.numpy()-l2n)
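If lowering and running are still wanted as separate passes (as the old step 4/step 5 split did), a sketch under the new API, where `lower()` fills in `prg` and `run()` reuses it:

```python
# hypothetical continuation of the tutorial above
for si in tqdm(schedule): si.lower()   # compile every kernel first
for si in tqdm(schedule): si.run()     # then execute them in order
```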
@@ -17,15 +17,15 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al
 
 ## Scheduling
 
-The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
 
-::: tinygrad.engine.schedule.ScheduleItem
+::: tinygrad.engine.schedule.ExecItem
 
 ## Lowering
 
-The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
+The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with
 
-::: tinygrad.engine.realize.lower_schedule
 ::: tinygrad.engine.realize.run_schedule
 
 There's a ton of complexity hidden behind this, see the `codegen/` directory.
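As a concrete illustration of scheduling plus lowering under the new API (a hedged sketch; the names come from the hunks in this diff):

```python
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

out = Tensor.ones(4, 4).sum()
schedule = out.schedule()   # list[ExecItem], one kernel (or copy) per item
run_schedule(schedule)      # each item is lowered (prg populated), then run
print(out.item())           # 16.0, the buffer is now realized
```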
@@ -9,7 +9,6 @@ import xml.etree.ElementTree as ET
|
||||
|
||||
from tinygrad import nn, Tensor, Device
|
||||
from tinygrad.helpers import get_single_element
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
@@ -118,7 +117,8 @@ if __name__ == "__main__":
|
||||
root = ET.fromstring(xml_str)
|
||||
|
||||
a = Tensor.empty(16)+1
|
||||
for si, ei in lower_schedule(a.schedule()):
|
||||
for ei in a.schedule():
|
||||
ei.lower()
|
||||
# get text
|
||||
_, hdr, _ = elf_loader(ei.prg.lib)
|
||||
text = get_single_element([x for x in hdr if x.name==".text"]).content
|
||||
|
||||
@@ -36,7 +36,7 @@ if __name__ == "__main__":
|
||||
for _ in range(run_count): tc = (a@b).realize()
|
||||
|
||||
GlobalCounters.reset()
|
||||
ei = ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer])
|
||||
ei = ExecItem(ast, [a.uop.buffer, b.uop.buffer, c.uop.buffer], prg=runner)
|
||||
with Context(DEBUG=2):
|
||||
for _ in range(run_count): ei.run(wait=True)
|
||||
print(f"custom {(c-tc).square().mean().item()}")
|
||||
|
||||
@@ -147,7 +147,7 @@ def test_matmul(sink:UOp, N=N):
|
||||
hc = Tensor.empty(N, N)
|
||||
Tensor.realize(a, b, hc)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))
|
||||
|
||||
ets = []
|
||||
with Context(DEBUG=2):
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad import dtypes, Tensor
|
||||
from tinygrad.helpers import getenv, get_single_element
|
||||
from tinygrad.dtype import _to_np_dtype
|
||||
from tinygrad.codegen.opt import OptOps
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
|
||||
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
|
||||
dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
|
||||
@@ -40,8 +39,8 @@ if __name__ == "__main__":
|
||||
|
||||
if getenv("SHOULD_USE_TC"):
|
||||
sched = a.matmul(b, dtype=acc_dtype).schedule()
|
||||
lowered = list(lower_schedule(sched))
|
||||
ei = get_single_element(lowered)[1]
|
||||
ei = get_single_element(sched)
|
||||
ei.lower()
|
||||
assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}"
|
||||
|
||||
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
|
||||
@@ -33,5 +33,5 @@ if __name__ == "__main__":
|
||||
new_src = prg.src
|
||||
# can mod source here
|
||||
prg = replace(prg, src=new_src)
|
||||
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
|
||||
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
|
||||
for i in range(5): ei.run(wait=True)
|
||||
|
||||
@@ -88,7 +88,7 @@ if __name__ == "__main__":
|
||||
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
|
||||
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
|
||||
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
|
||||
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
|
||||
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
|
||||
tflops = []
|
||||
for i in range(5):
|
||||
tm = ei.run(wait=True)
|
||||
|
||||
@@ -1,128 +0,0 @@
|
||||
import numpy as np
|
||||
import ctypes
|
||||
from tinygrad import Tensor, GlobalCounters, Context
|
||||
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
||||
from tinygrad.device import CPUProgram
|
||||
from dataclasses import replace
|
||||
from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN
|
||||
|
||||
# only the memory access, over 100 GB/s! (sometimes)
|
||||
reduce_asm = """
|
||||
movi v0.2d, #0000000000000000
|
||||
mov w9, #0x30
|
||||
mov w10, #0x20
|
||||
mov x8, #-0x10
|
||||
movi v1.2d, #0000000000000000
|
||||
movk w9, #0x300, lsl #16
|
||||
movi v2.2d, #0000000000000000
|
||||
movk w10, #0x200, lsl #16
|
||||
movi v3.2d, #0000000000000000
|
||||
mov w11, #0x1000000
|
||||
mov w12, #0x3ffff0
|
||||
loop:
|
||||
ldp q4, q5, [x1]
|
||||
add x13, x1, x11
|
||||
add x15, x1, x10
|
||||
add x14, x1, x9
|
||||
add x8, x8, #0x10
|
||||
cmp x8, x12
|
||||
ldp q6, q7, [x1, #0x20]
|
||||
add x1, x1, #0x40
|
||||
ldp q4, q5, [x13]
|
||||
ldp q6, q7, [x13, #0x20]
|
||||
ldp q4, q5, [x15, #-0x20]
|
||||
ldp q6, q7, [x15]
|
||||
ldp q4, q5, [x14, #-0x30]
|
||||
ldp q6, q7, [x14, #-0x10]
|
||||
b.lo loop
|
||||
fadd v0.4s, v1.4s, v0.4s
|
||||
fadd v0.4s, v2.4s, v0.4s
|
||||
fadd v0.4s, v3.4s, v0.4s
|
||||
dup v1.4s, v0.s[1]
|
||||
dup v2.4s, v0.s[2]
|
||||
fadd v1.4s, v0.4s, v1.4s
|
||||
dup v0.4s, v0.s[3]
|
||||
fadd v1.4s, v2.4s, v1.4s
|
||||
fadd v0.4s, v0.4s, v1.4s
|
||||
str s0, [x0]
|
||||
ret
|
||||
"""
|
||||
|
||||
ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN)
|
||||
arm_bytecode, _ = ks.asm(reduce_asm)
|
||||
arm_bytecode = bytes(arm_bytecode)
|
||||
|
||||
reduce_src = """
|
||||
// data1 is 16M inputs
|
||||
typedef float float4 __attribute__((aligned(32),vector_size(16)));
|
||||
void reduce(float* restrict data0, float* restrict data1) {
|
||||
float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f};
|
||||
float* data1_1 = data1+4194304;
|
||||
float* data1_2 = data1+(4194304*2);
|
||||
float* data1_3 = data1+(4194304*3);
|
||||
for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) {
|
||||
float4 val0 = *(float4*)((data1+(ridx0+0)));
|
||||
float4 val1 = *(float4*)((data1+(ridx0+4)));
|
||||
float4 val2 = *(float4*)((data1+(ridx0+8)));
|
||||
float4 val3 = *(float4*)((data1+(ridx0+12)));
|
||||
acc0 += val0;
|
||||
acc1 += val1;
|
||||
acc2 += val2;
|
||||
acc3 += val3;
|
||||
val0 = *(float4*)((data1_1+(ridx0+0)));
|
||||
val1 = *(float4*)((data1_1+(ridx0+4)));
|
||||
val2 = *(float4*)((data1_1+(ridx0+8)));
|
||||
val3 = *(float4*)((data1_1+(ridx0+12)));
|
||||
acc4 += val0;
|
||||
acc5 += val1;
|
||||
acc6 += val2;
|
||||
acc7 += val3;
|
||||
val0 = *(float4*)((data1_2+(ridx0+0)));
|
||||
val1 = *(float4*)((data1_2+(ridx0+4)));
|
||||
val2 = *(float4*)((data1_2+(ridx0+8)));
|
||||
val3 = *(float4*)((data1_2+(ridx0+12)));
|
||||
acc0 += val0;
|
||||
acc1 += val1;
|
||||
acc2 += val2;
|
||||
acc3 += val3;
|
||||
val0 = *(float4*)((data1_3+(ridx0+0)));
|
||||
val1 = *(float4*)((data1_3+(ridx0+4)));
|
||||
val2 = *(float4*)((data1_3+(ridx0+8)));
|
||||
val3 = *(float4*)((data1_3+(ridx0+12)));
|
||||
acc4 += val0;
|
||||
acc5 += val1;
|
||||
acc6 += val2;
|
||||
acc7 += val3;
|
||||
}
|
||||
float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7;
|
||||
*(data0+0) = out[0]+out[1]+out[2]+out[3];
|
||||
}
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
a = Tensor(np_array:=(np.random.default_rng().random((4096, 4096), dtype=np.float32)-0.5)).realize()
|
||||
with Context(SPLIT_REDUCEOP=0):
|
||||
# TODO: make it easy to alter the OptOps for a ScheduleItem
|
||||
GlobalCounters.reset()
|
||||
out = a.sum()
|
||||
sis = out.schedule()
|
||||
for i,(_,ei) in enumerate(lower_schedule(sis)):
|
||||
if i == 0:
|
||||
# change the source code
|
||||
prg_spec = ei.prg.p
|
||||
prg_spec = replace(prg_spec, name="reduce", src=reduce_src)
|
||||
prg = CompiledRunner(prg_spec)
|
||||
# change the assembly
|
||||
#prg._prg = CPUProgram(prg_spec.name, arm_bytecode)
|
||||
print("buffer at:",hex(ctypes.addressof(ei.bufs[1]._buf)))
|
||||
ei = replace(ei, prg=prg)
|
||||
ei.run()
|
||||
print(out.item())
|
||||
np.testing.assert_allclose(out.item(), np_array.sum(), atol=1, rtol=1e-4)
|
||||
@@ -1,100 +0,0 @@
|
||||
import itertools
|
||||
import numpy as np
|
||||
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.realize import capturing, lower_schedule_item
|
||||
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
|
||||
from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
|
||||
ctx_vars = { MULTIOUTPUT: (0, 1) }
|
||||
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
|
||||
|
||||
def fuzz_schedule(outs:List[UOp]):
|
||||
# find toposorts across all tunable params
|
||||
unique_ts: Dict[Tuple[LBScheduleItem, ...], Dict[str, int]] = {}
|
||||
for combination in itertools.product(*ctx_vars.values()):
|
||||
for var, val in zip(ctx_vars, combination): var.value = val
|
||||
ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
|
||||
graph, in_degree, _ = _graph_schedule(outs)
|
||||
for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
|
||||
toposorts = list(unique_ts.items())
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
|
||||
|
||||
# setup ground truth
|
||||
ground_truth: Dict[UOp, memoryview] = {}
|
||||
assign_targets: Dict[UOp, UOp] = {}
|
||||
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
|
||||
prerealized: Dict[UOp, memoryview] = {}
|
||||
seed = Tensor._seed
|
||||
ts,_ = toposorts[0]
|
||||
for lsi in ts:
|
||||
for out in lsi.outputs:
|
||||
# freeze assign state before exec
|
||||
if out.op is Ops.ASSIGN:
|
||||
prerealized[out] = out.buffer.as_buffer()
|
||||
assign_targets[out.srcs[1]] = out
|
||||
for x in lsi.inputs:
|
||||
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
|
||||
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
|
||||
_exec_si(si, seed)
|
||||
for out in lsi.outputs:
|
||||
ground_truth[out] = out.buffer.as_buffer()
|
||||
del out.srcs # only schedule the LazyBuffer in this fuzz run
|
||||
|
||||
# exec and validate each permutation with new Buffers
|
||||
for i, (ts, ctx) in enumerate(toposorts[1:]):
|
||||
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
|
||||
rawbufs: Dict[UOp, Buffer] = {}
|
||||
for lsi in ts:
|
||||
for out in lsi.outputs:
|
||||
base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
|
||||
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
|
||||
if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
|
||||
for x in lsi.inputs:
|
||||
if x not in rawbufs:
|
||||
# override the assign_target after ASSIGN
|
||||
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
|
||||
elif x.device == "NPY": rawbufs[x] = x.buffer
|
||||
# copy the pre realized input
|
||||
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
|
||||
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
|
||||
_exec_si(si, seed)
|
||||
for out in lsi.outputs:
|
||||
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
|
||||
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
|
||||
except Exception as e:
|
||||
print(f"FAILED FOR {out}")
|
||||
raise e
|
||||
|
||||
def _exec_si(si:ScheduleItem, seed:int):
|
||||
ei = lower_schedule_item(si)
|
||||
if len(capturing): capturing[0].add(ei)
|
||||
ei.run()
|
||||
|
||||
T = TypeVar("T")
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
|
||||
visited: Set[T] = set()
|
||||
ret: List[Tuple[T, ...]] = []
|
||||
path: List[T] = []
|
||||
|
||||
def recurse_paths(path:List[T]):
|
||||
for v, d in in_degree.items():
|
||||
if d != 0 or v in visited: continue
|
||||
for u in graph[v]: in_degree[u] -= 1
|
||||
path.append(v)
|
||||
visited.add(v)
|
||||
recurse_paths(path)
|
||||
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
|
||||
# backtrack
|
||||
for u in graph[v]: in_degree[u] += 1
|
||||
path.pop()
|
||||
visited.remove(v)
|
||||
if len(path) == len(in_degree): ret.append(tuple(path))
|
||||
recurse_paths(path)
|
||||
|
||||
if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
|
||||
# verify all paths are unique
|
||||
assert len(ret) == len(set(ret))
|
||||
return ret
|
||||
@@ -25,4 +25,4 @@ if __name__ == "__main__":
|
||||
#k.apply_opt(Opt(OptOps.GROUP, 0, 32))
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem
|
||||
run = CompiledRunner(prg:=get_program(k.ast, k.opts, k.applied_opts))
|
||||
ExecItem(run, si.bufs).run()
|
||||
ExecItem(k.ast, list(si.bufs), prg=run).run()
|
||||
|
||||
test/external/external_test_hsa_driver.py (vendored)
@@ -1,10 +1,12 @@
|
||||
import ctypes, unittest
|
||||
from tinygrad.helpers import init_c_struct_t
|
||||
from tinygrad.device import Device, Buffer, BufferXfer
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.support.hsa import AQLQueue
|
||||
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.engine.realize import BufferXfer
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
def get_hsa_inc_prog(dev, inc=1):
|
||||
prg = f"""
|
||||
@@ -102,7 +104,8 @@ class TestHSADriver(unittest.TestCase):
|
||||
test_buf1.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf2.copyin(memoryview(bytearray(1*4)))
|
||||
|
||||
jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
|
||||
jit_cache = [ExecItem(UOp(Ops.NOOP), [test_buf0, test_buf2], prg=BufferXfer(test_buf0.nbytes, test_buf0.device, test_buf2.device)),
|
||||
ExecItem(UOp(Ops.NOOP), [test_buf2, test_buf1], prg=BufferXfer(test_buf2.nbytes, test_buf2.device, test_buf1.device))]
|
||||
graph = HSAGraph(jit_cache, [], {})
|
||||
|
||||
for i in range(10000):
|
||||
|
||||
test/external/fuzz_graph.py (vendored)
@@ -4,7 +4,9 @@ from tinygrad.device import Buffer, Device
|
||||
from tinygrad.helpers import Context, getenv, from_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner
|
||||
from tinygrad.engine.realize import BufferXfer, get_runner
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.engine.jit import apply_graph_to_jit
|
||||
|
||||
BUF_LEN = getenv("BUF_LEN", 128)
|
||||
@@ -35,13 +37,13 @@ def gen_kernel_ji(device, deps):
|
||||
assert len(deps) >= 2
|
||||
out = alloc_rawbuffer(device)
|
||||
prg = gen_prg(device, len(deps))
|
||||
return ExecItem(prg, [out] + deps)
|
||||
return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg)
|
||||
|
||||
def gen_copy_ji(device, deps):
|
||||
assert len(deps) == 1
|
||||
out = alloc_rawbuffer(device)
|
||||
prg = BufferXfer(deps[0].nbytes, device, deps[0].device)
|
||||
return ExecItem(prg, [out] + deps)
|
||||
return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg)
|
||||
|
||||
def gen_graph():
|
||||
input_buffers = []
|
||||
|
||||
@@ -3,7 +3,8 @@ import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
|
||||
from tinygrad.helpers import Context, getenv
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
@@ -14,7 +15,7 @@ class TestArange(unittest.TestCase):
|
||||
sched = tensor.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
|
||||
ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
|
||||
ExecItem(sched[-1].ast, [tensor.uop.buffer], prg=CompiledRunner(p)).run()
|
||||
np.testing.assert_equal(tensor.numpy(), desired)
|
||||
return p.estimates.ops
|
||||
|
||||
|
||||
@@ -2,13 +2,12 @@ import unittest, io
|
||||
from contextlib import redirect_stdout
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.helpers import OSX, CPU_LLVM, CPU_LVP
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import get_program
|
||||
|
||||
class TestCompileFailures(unittest.TestCase):
|
||||
def compile(self, out:Tensor):
|
||||
for _ in lower_schedule(out.schedule()): pass
|
||||
for si in out.schedule(): si.lower()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}")
|
||||
def test_interpolate_atari(self):
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.engine.realize import lower_schedule_item, run_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
class TestFusionOp(unittest.TestCase):
|
||||
def test_contiguous_add(self):
|
||||
@@ -26,9 +26,9 @@ class TestFusionOp(unittest.TestCase):
|
||||
a = Tensor([1,2,3,4])
|
||||
for _ in range(24): a = a + a
|
||||
sched = a.schedule()
|
||||
ei = lower_schedule_item(sched[-1])
|
||||
sched[-1].lower()
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
assert len(ei.prg.p.src.splitlines()) < 250
|
||||
assert len(sched[-1].prg.p.src.splitlines()) < 250
|
||||
|
||||
def test_recursive_add_cmp(self):
|
||||
st = time.perf_counter()
|
||||
|
||||
@@ -6,7 +6,9 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.helpers import Context, dedup, from_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.jit import MultiGraphRunner
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner
|
||||
from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
|
||||
from test.helpers import needs_second_gpu
|
||||
|
||||
@@ -27,11 +29,11 @@ def helper_exec_op(device, outbuf, inbufs):
|
||||
prg = get_runner(device, si.ast)
|
||||
cached_prgs[(device, len(inbufs))] = prg
|
||||
|
||||
return ExecItem(cached_prgs[(device, len(inbufs))], [outbuf] + inbufs)
|
||||
return ExecItem(UOp(Ops.NOOP), [outbuf] + inbufs, prg=cached_prgs[(device, len(inbufs))])
|
||||
|
||||
def helper_copy_op(device, dest, src):
|
||||
prg = BufferXfer(dest.nbytes, device, src.device)
|
||||
return ExecItem(prg, [dest, src])
|
||||
return ExecItem(UOp(Ops.NOOP), [dest, src], prg=prg)
|
||||
|
||||
def helper_alloc_rawbuffer(device, fill=False):
|
||||
rawbuf = Buffer(device, BUF_SIZE, dtypes.int).ensure_allocated()
|
||||
@@ -69,7 +71,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
|
||||
ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
|
||||
|
||||
# Build graphs
|
||||
gr_ji = [ExecItem(graph_impl(graph, [], {}), []) for graph in graphs]
|
||||
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs]
|
||||
|
||||
for _ in range(runs):
|
||||
test_bufs = helper_run_jit(gr_ji, bufs, out_buffers)
|
||||
|
||||
@@ -3,7 +3,6 @@ import numpy as np
|
||||
from tinygrad import Device, dtypes, Tensor, Context
|
||||
from tinygrad.device import LRUAllocator, is_dtype_supported
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.helpers import prod, unwrap
|
||||
from test.helpers import REAL_DEV
|
||||
|
||||
@@ -129,8 +128,8 @@ class TestImageDType(unittest.TestCase):
|
||||
loss = x.image_dot(w1).image_dot(w2).float().max()
|
||||
loss.backward()
|
||||
sched = unwrap(w1.grad).schedule()
|
||||
for s,(_,ei) in zip(sched, lower_schedule(sched[:])):
|
||||
ei.run()
|
||||
for s in sched:
|
||||
s.run()
|
||||
if s.bufs[0].dtype == dtypes.float:
|
||||
lst = s.bufs[0].as_buffer().cast("f").tolist()
|
||||
print(lst)
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.codegen.gpudims import get_grouped_dims
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
|
||||
from tinygrad.device import Device, Buffer, is_dtype_supported
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
|
||||
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
|
||||
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
|
||||
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
@@ -25,9 +25,9 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
|
||||
lowered = [x[1] for x in lower_schedule(c.schedule())]
|
||||
for ei in lowered: ei.run()
|
||||
rawbufs = lowered[-1].bufs
|
||||
sched = c.schedule()
|
||||
for si in sched: si.run()
|
||||
rawbufs = sched[-1].bufs
|
||||
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
|
||||
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
|
||||
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -4,7 +4,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
from tinygrad.helpers import getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
|
||||
from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
from test.helpers import not_support_multi_device, needs_second_gpu, slow
|
||||
@@ -123,9 +123,10 @@ class TestMultiTensor(unittest.TestCase):
|
||||
out = (X + X)
|
||||
sched = out.schedule()
|
||||
names = []
|
||||
for si, ei in lower_schedule(sched):
|
||||
if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
|
||||
ei.run()
|
||||
for si in sched:
|
||||
si.lower()
|
||||
if isinstance(si.prg, CompiledRunner): names.append(si.prg.p.name)
|
||||
si.run()
|
||||
self.assertEqual(len(set(names)), 1, "function was relinearized")
|
||||
|
||||
@unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
|
||||
|
||||
@@ -3,7 +3,8 @@ import unittest
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.helpers import get_single_element
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
|
||||
class TestOptGemm(unittest.TestCase):
|
||||
@classmethod
|
||||
@@ -18,7 +19,7 @@ class TestOptGemm(unittest.TestCase):
|
||||
# TODO: this should be a generic test helper
|
||||
si = get_single_element(t.schedule())
|
||||
run = CompiledRunner(get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts))
|
||||
ExecItem(run, si.bufs).run()
|
||||
ExecItem(si.ast, list(si.bufs), prg=run).run()
|
||||
test = si.bufs[0].numpy().reshape(self.res.shape)
|
||||
np.testing.assert_allclose(self.res, test, atol=1e-4)
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ from dataclasses import replace
|
||||
from tinygrad import Tensor, Context, Device, dtypes
|
||||
from tinygrad.uop.ops import Ops
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
|
||||
N = 512
|
||||
|
||||
@@ -42,8 +43,8 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
|
||||
if replace_src is not None:
|
||||
old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
|
||||
prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))
|
||||
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
|
||||
for _ in range(run_count): ei.run(wait=True)
|
||||
new_si = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
|
||||
for _ in range(run_count): new_si.run(wait=True)
|
||||
|
||||
def get_quantized_model(sz):
|
||||
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
|
||||
@@ -74,8 +75,8 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
|
||||
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
|
||||
with Context(QUANTIZE=1):
|
||||
sched = run_onnx({"input":inp})["output"].schedule()
|
||||
ei = lower_schedule_item(sched[-2])
|
||||
daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_REG]
|
||||
sched[-2].lower()
|
||||
daccs = [u for u in sched[-2].prg.p.uops if u.op is Ops.DEFINE_REG]
|
||||
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
|
||||
|
||||
@@ -4,7 +4,7 @@ from functools import partial
|
||||
from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable
|
||||
from tinygrad.helpers import getenv, CI, OSX
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
||||
from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.renderer.ptx import PTXRenderer
|
||||
from tinygrad.renderer.nir import NIRRenderer
|
||||
from test.helpers import not_support_multi_device, needs_second_gpu
|
||||
@@ -103,10 +103,12 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "PTX and NIR use pointer arithmetic")
|
||||
def test_threefry_doesnt_use_long(self):
|
||||
for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
for u in ei.prg.p.uops:
|
||||
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
|
||||
sched = Tensor.rand(20).schedule()
|
||||
for si in sched:
|
||||
si.lower()
|
||||
if isinstance(si.prg, CompiledRunner):
|
||||
for u in si.prg.p.uops:
|
||||
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {si.prg.p.name}")
|
||||
|
||||
def test_threefry_against_reference_full(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
@@ -12,7 +12,6 @@ from tinygrad.uop.ops import UOp, Ops, python_alu
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
|
||||
def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
|
||||
for x in inputs: x.realize()
|
||||
@@ -76,8 +75,8 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
for _ in range(5): ret = python_alu[op](ret, Tensor.empty(1, dtype=dtype))
|
||||
schedule = ret.schedule()
|
||||
assert len(schedule) == 1
|
||||
ei = lower_schedule_item(schedule[0])
|
||||
src = ei.prg.p.src
|
||||
schedule[0].lower()
|
||||
src = schedule[0].prg.p.src
|
||||
self.assertEqual("("*5 not in src, should_strip_paren)
|
||||
|
||||
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
|
||||
@@ -24,8 +24,9 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
|
||||
else:
|
||||
assert isinstance(t, UOp), f"can't schedule {t}"
|
||||
sched = Tensor(t).schedule()
|
||||
# test lowering all the ScheduleItems to ExecItems
|
||||
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
|
||||
# test lowering all the ExecItems
|
||||
for si in sched: si.lower()
|
||||
kernel_cnt = len([si for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink])
|
||||
if kernel_cnt != allowed:
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
|
||||
if DEBUG >= 3:
|
||||
@@ -174,7 +175,7 @@ class TestSchedule(unittest.TestCase):
|
||||
child.realize()
|
||||
assert a.uop.is_realized
|
||||
|
||||
# NOTE: because empty does not have an ExecItem if realize is called on a childless empty, it never gets allocated.
|
||||
# NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
|
||||
def test_childless_empty_never_allocates(self):
|
||||
a = Tensor.empty(10)
|
||||
a.realize()
|
||||
|
||||
@@ -3,7 +3,6 @@ import numpy as np
|
||||
from tinygrad import Tensor, GlobalCounters, Context, Device
|
||||
from tinygrad.dtype import DTypeLike, dtypes
|
||||
from tinygrad.helpers import DEBUG, get_single_element
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
|
||||
@@ -27,7 +26,7 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te
|
||||
out = e.div(ss).reshape(x_in.shape)
|
||||
return out
|
||||
|
||||
def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run()
|
||||
def run_one_schedule_item(out): get_single_element(out.schedule()).run()
|
||||
|
||||
class TestFuse(unittest.TestCase):
|
||||
def _test_fuse(self, fxn, *args, atol=1e-6, allow_multiple=False, **kwargs):
|
||||
|
||||
@@ -9,7 +9,8 @@ from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
|
||||
from tinygrad.uop.spec import shared_spec
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
|
||||
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.device import is_dtype_supported
|
||||
@@ -568,7 +569,7 @@ class TestZeroRange(unittest.TestCase):
|
||||
|
||||
class TestUOpPrograms(unittest.TestCase):
|
||||
def _run(self, prog:UOp, *tensors:Tensor):
|
||||
ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True)
|
||||
ExecItem(prog, [t.uop.buffer for t in tensors], prg=get_runner(Device.DEFAULT, prog)).run(wait=True)
|
||||
|
||||
def test_simple(self):
|
||||
out = Tensor.empty(10,10,dtype=dtypes.int)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import getenv, GlobalCounters, EMULATE
|
||||
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
|
||||
from tinygrad.engine.realize import get_program
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.ops import Ops, UOp
|
||||
@@ -17,9 +18,8 @@ def flops_mem(uops, ignore_indexing=False):
|
||||
# **************** new FlopCounter ****************
|
||||
|
||||
def get_stats(x:Tensor):
|
||||
si = x.schedule()[-1]
|
||||
ei = lower_schedule_item(si)
|
||||
return ei.prg.estimates.ops, ei.prg.estimates.mem
|
||||
si = x.schedule()[-1].lower()
|
||||
return si.prg.estimates.ops, si.prg.estimates.mem
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does extra load/store for packed types")
|
||||
class TestMemoryCount(unittest.TestCase):
|
||||
|
||||
@@ -2,7 +2,8 @@ import unittest, math, time
|
||||
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.engine.realize import ExecItem, get_runner
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.helpers import CI
|
||||
import numpy as np
|
||||
@@ -64,7 +65,7 @@ class TestTK(unittest.TestCase):
|
||||
c = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b, c)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
c = c.float()
|
||||
|
||||
@@ -113,7 +114,7 @@ class TestTK(unittest.TestCase):
|
||||
c = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b, c)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
c = c.float()
|
||||
|
||||
@@ -149,7 +150,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -188,7 +189,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -230,7 +231,7 @@ class TestTK(unittest.TestCase):
|
||||
c = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b, c)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, c, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, c, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
c = c.float()
|
||||
@@ -270,7 +271,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -307,7 +308,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -352,7 +353,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -397,7 +398,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, M, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -442,7 +443,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -487,7 +488,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, M, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -547,7 +548,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, BLOCK_SIZE, N, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -607,7 +608,7 @@ class TestTK(unittest.TestCase):
|
||||
b = Tensor.empty(1, 1, N, BLOCK_SIZE, dtype="float32")
|
||||
Tensor.realize(a, b)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5): ei.run(wait=True)
|
||||
b = b.float()
|
||||
|
||||
@@ -717,7 +718,7 @@ class TestTK(unittest.TestCase):
|
||||
out = Tensor.empty(B, N, H, D, dtype=dtypes.bfloat16)
|
||||
Tensor.realize(q, k, v, out)
|
||||
|
||||
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
|
||||
ei = ExecItem(sink, [t.uop.buffer for t in (out, q, k, v)], prg=get_runner(Device.DEFAULT, sink))
|
||||
for _ in range(5):
|
||||
et = ei.run(wait=True)
|
||||
attn_flops = 2 * B * H * N * N * D + \
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, Buf
|
||||
from tinygrad.engine.memory import _internal_memory_planner
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.schedule.rangeify import mop_cleanup
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
class GraphException(Exception): pass
|
||||
@@ -31,7 +31,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
|
||||
graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
|
||||
# clear jit inputs to allow their memory to be freed/reused
|
||||
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
|
||||
graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
|
||||
graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner))
|
||||
max_batch_size *= 2
|
||||
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
|
||||
except GraphException as e:
|
||||
@@ -89,6 +89,7 @@ class GraphRunner(Runner):
|
||||
|
||||
estimates = Estimates()
|
||||
for j,ji in enumerate(jit_cache):
|
||||
assert ji.prg is not None
|
||||
estimates += ji.prg.estimates
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]
|
||||
@@ -103,6 +104,7 @@ class GraphRunner(Runner):
|
||||
self.w_dependency_map: dict[int, Any] = {}
|
||||
self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)
|
||||
|
||||
assert jit_cache[0].prg is not None
|
||||
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
|
||||
|
||||
def updated_vars(self, var_vals: dict[str, int]):
|
||||
@@ -186,7 +188,7 @@ class CapturedJit(Generic[ReturnType]):
|
||||
def replan_buffers_memory_layout(self):
|
||||
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
|
||||
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
|
||||
self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
|
||||
self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
|
||||
for old, new in asgn.items():
|
||||
if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
|
||||
self.__post_init__()
|
||||
@@ -249,7 +251,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
return ret
|
||||
|
||||
def add(self, ei:ExecItem):
|
||||
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
|
||||
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
|
||||
|
||||
def reset(self):
|
||||
assert self.fxn is not None, "can't reset without function"
|
||||
@@ -320,8 +322,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
# Exclude buffers involved in transfer ops to preserve parallelism.
|
||||
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
|
||||
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
|
||||
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
|
||||
item.metadata, item.fixedvars) for item in jit_cache]
|
||||
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
|
||||
|
||||
input_replace = get_input_replace(jit_cache, input_buffers)
|
||||
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
|
||||
|
||||
@@ -1,6 +1,6 @@
 from typing import cast
 from collections import defaultdict
-from tinygrad.engine.schedule import ScheduleItem
+from tinygrad.engine.realize import ExecItem
 from tinygrad.device import Device, Buffer
 from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
 from tinygrad.uop.ops import Ops
@@ -63,8 +63,8 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
 
   return assigned
 
-def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
+def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
   # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
-  assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
-                                      noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
+  assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
+                                      noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
+  return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]
@@ -1,4 +1,4 @@
-from typing import cast, Generator, Callable
+from typing import cast, Callable
 import time, pprint, random, itertools, math
 from dataclasses import dataclass, replace, field
 from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
@@ -7,7 +7,6 @@ from tinygrad.helpers import unwrap
 from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
-from tinygrad.engine.schedule import ScheduleItem
 from tinygrad.codegen import full_rewrite
 from tinygrad.codegen.opt import Opt
 
@@ -171,15 +170,43 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
 
 # **************** lowering functions ****************
 
-@dataclass(frozen=True)
+# NOTE: ctx is the buffers
+si_lowerer = PatternMatcher([
+  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
+  (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
+  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
+      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
+      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
+  (UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: EncDec(encdec, ctx[0].nbytes, ctx[1].device)),
+])
+
+@dataclass
 class ExecItem:
-  prg: Runner
-  bufs: list[Buffer|None]
-  metadata: tuple[Metadata, ...]|None = None
+  ast: UOp
+  bufs: list[Buffer|None] = field(default_factory=list)
+  metadata: tuple[Metadata, ...] = ()
   fixedvars: dict[str, int] = field(default_factory=dict)
+  prg: Runner|None = None
 
+  def lower(self):
+    """Populate self.prg by lowering the AST."""
+    if self.prg is not None: return self
+    try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
+    except Exception as e:
+      if DEBUG >= 2:
+        print(f"error lowering {self.ast.op}")
+        print("tensor operations:")
+        pprint.pprint(self.metadata, indent=2)
+      raise e
+    return self
+
   def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
+    if self.prg is None: self.lower()
+    assert self.prg is not None
     var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
-    bufs = [unwrap(x) for x in self.bufs] if jit else [unwrap(x).ensure_allocated() for x in self.bufs]
+    # reorder bufs to match program globals if needed
+    _bufs = [self.bufs[i] for i in self.prg.p.globals] if isinstance(self.prg, CompiledRunner) else self.bufs
+    bufs = cast(list[Buffer], [unwrap(x) for x in _bufs] if jit else [unwrap(x).ensure_allocated() for x in _bufs])
     if PROFILE:
       payload = {"metadata":self.metadata, "var_vals":var_vals, "bufs":[b.trace_num for b in bufs], "name":self.prg.display_name}
       payload["outputs"], payload["inputs"] = (self.prg.p.outs, self.prg.p.ins) if isinstance(self.prg, CompiledRunner) else ([0], [1])
@@ -206,48 +233,28 @@ class ExecItem:
       self.prg.first_run = False
     return et
 
-# NOTE: ctx is the buffers
-si_lowerer = PatternMatcher([
-  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
-  (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
-  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
-      if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
-      else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
-  (UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: ((EncDec(encdec, ctx[0].nbytes, ctx[1].device)), list(ctx))),
-])
-def lower_schedule_item(si:ScheduleItem) -> ExecItem:
-  return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
-
-def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
-  while len(schedule):
-    si = schedule.pop(0)
-    try: yield (si, lower_schedule_item(si))
-    except Exception as e:
-      if DEBUG >= 2:
-        print(f"error lowering {si.ast.op}")
-        print("tensor operations:")
-        pprint.pprint(si.metadata, indent=2)
-      raise e
-
 # **************** main run function ****************
 
 capturing: list = [] # put classes with an add method in here
 
-def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
-  for si, ei in lower_schedule(schedule):
+def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
+  while len(schedule):
+    ei = schedule.pop(0).lower()
     if len(capturing) and CAPTURING: capturing[0].add(ei)
-    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
+    if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
       # copy in allocated buffers from the GPU
-      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
-      for cpu_b, gpu_b in zip(nb, si.bufs):
-        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
+      bufs = [b for b in ei.bufs if b is not None]
+      nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
+      for cpu_b, gpu_b in zip(nb, bufs):
+        if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
 
       # run on GPU
       ei.run(var_vals, do_update_stats=do_update_stats)
 
       # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
-      with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
+      with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
       import numpy as np
-      np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
+      assert nb[0] is not None
+      np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
     else:
       ei.run(var_vals, do_update_stats=do_update_stats)
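Many of the test updates in this diff follow the same pattern; here is a hedged sketch of wrapping a custom program in an `ExecItem` by hand (`out` is any lazy Tensor; the call signatures mirror the hunks above):

```python
from dataclasses import replace
from tinygrad import Tensor, Device
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program

out = Tensor.ones(16, 16).sum()
si = out.schedule()[-1]                # the last ExecItem from the scheduler
prg = get_program(si.ast, renderer=Device[Device.DEFAULT].renderer)
new_src = prg.src                      # the kernel source could be modified here
ei = ExecItem(si.ast, list(si.bufs), si.metadata, prg=CompiledRunner(replace(prg, src=new_src)))
ei.run(wait=True)
```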
@@ -1,25 +1,17 @@
 from __future__ import annotations
 import time
 from typing import cast
-from dataclasses import dataclass, field
 from collections import deque
 from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
 from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
-from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
-
-# **** ScheduleItem return type
-
-@dataclass(frozen=True)
-class ScheduleItem:
-  ast: UOp
-  bufs: tuple[Buffer, ...] = ()
-  metadata: tuple[Metadata, ...] = ()
-  fixedvars: dict[str, int] = field(default_factory=dict)
+from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
+from tinygrad.engine.realize import ExecItem
 
 # **** schedule linearizer
 
-def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
+def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
   with cpu_profile(TracingKey("toposort sched_sink")):
     # construct the KERNEL children graph based on assigns
     children: dict[UOp, list[UOp]] = {}
@@ -70,7 +62,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
       if in_degree[x] == 0: queue.append(x)
 
   with cpu_profile(TracingKey("expand ranges")):
-    pre_schedule: list[ScheduleItem] = []
+    pre_schedule: list[ExecItem] = []
     buf_uops_list: list[UOp] = []
     sched_ptr = 0
     in_ranges: dict[UOp, int] = {}
@@ -89,7 +81,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
       else:
         ast, buf_uops, metadata, fixedvars, bound_ranges = si
         fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
-        pre_schedule.append(ScheduleItem(ast, (), metadata, fixedvars))
+        pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
         buf_uops_list.append(UOp.sink(*buf_uops))
         sched_ptr += 1
   return pre_schedule, UOp.sink(*buf_uops_list)
@@ -131,9 +123,9 @@ pm_post_sched_cache = PatternMatcher([
   (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
 ])
 
-schedule_cache: dict[bytes, tuple[list[ScheduleItem], UOp]] = {}
+schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
 @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
-def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
+def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
   # big_sink srcs are all the Tensors
   st = time.perf_counter()
 
@@ -180,7 +172,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
   tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
 
   # add bufs to pre_schedule
-  schedule: list[ScheduleItem] = []
+  schedule: list[ExecItem] = []
   for i, si in enumerate(pre_schedule):
     buf_uops = buf_uops_sink.src[i].src
     # create subbuffers if needed
@@ -193,10 +185,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
       assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
       dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
       for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
-        schedule.append(ScheduleItem(si.ast, bufs, si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
+        schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
     else:
       # ONE -> ONE
-      schedule.append(ScheduleItem(si.ast, cast(tuple[Buffer, ...], ubufs), si.metadata, si.fixedvars))
+      schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
   with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
 
   # extract var_vals from BINDs that were stripped (only if there are kernels)
@@ -12,7 +12,7 @@ from tinygrad.gradient import compute_gradient
 from tinygrad.mixin import OpMixin
 from tinygrad.mixin.movement import _align_left
 from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
-from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars
+from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
 from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 
@@ -236,7 +236,7 @@ class Tensor(OpMixin):
     """
     return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
 
-  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]:
+  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
     """
     Creates the schedule needed to realize these Tensor(s), with Variables.
 
@@ -249,7 +249,7 @@ class Tensor(OpMixin):
     _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return schedule, var_vals
 
-  def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
+  def schedule(self, *lst:Tensor) -> list[ExecItem]:
     """Creates the schedule needed to realize these Tensor(s)."""
     schedule, var_vals = self.schedule_with_vars(*lst)
     assert len(var_vals) == 0
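A short sketch of the `schedule_with_vars` contract shown above (hedged; plain tensors bind no Variables, so `var_vals` comes back empty and `schedule()` can assert on it):

```python
from tinygrad import Tensor

t = Tensor.ones(8).sum()
schedule, var_vals = t.schedule_with_vars()
assert var_vals == {}   # no Variables were bound in this graph
print(len(schedule))    # number of ExecItems (kernels) to run
```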