mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00

remove ScheduleItem and merge it with ExecItem (#13759)

* remove ExecItem and merge it with ScheduleItem
* less diff
* fix issues
* min diff
* don't change bufs in _lower
* min diff
* update
* revert
* fixes
* diff
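In short: ExecItem now takes the kernel AST first and carries the metadata that ScheduleItem used to hold, with the runner moved to an optional prg= keyword; items from Tensor.schedule() are lowered in place via ei.lower(). A minimal before/after sketch distilled from the call sites below (names like ast, bufs, and runner are placeholders, not from the commit):

    # before: separate ScheduleItem (ast, bufs, metadata) and ExecItem (runner first)
    ei = ExecItem(runner, bufs, metadata)
    # after: one ExecItem holding the ast, with the runner as a keyword argument
    ei = ExecItem(ast, bufs, metadata, prg=runner)
    ei.lower()         # replaces lower_schedule(); attaches ei.prg
    ei.run(wait=True)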
@@ -9,7 +9,6 @@ import xml.etree.ElementTree as ET
 from tinygrad import nn, Tensor, Device
 from tinygrad.helpers import get_single_element
-from tinygrad.engine.realize import lower_schedule
 from tinygrad.runtime.support.elf import elf_loader
 from tinygrad.runtime.ops_amd import ProfileSQTTEvent
 from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
@@ -118,7 +117,8 @@ if __name__ == "__main__":
   root = ET.fromstring(xml_str)

   a = Tensor.empty(16)+1
-  for si, ei in lower_schedule(a.schedule()):
+  for ei in a.schedule():
+    ei.lower()
     # get text
     _, hdr, _ = elf_loader(ei.prg.lib)
     text = get_single_element([x for x in hdr if x.name==".text"]).content

@@ -36,7 +36,7 @@ if __name__ == "__main__":
   for _ in range(run_count): tc = (a@b).realize()

   GlobalCounters.reset()
-  ei = ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer])
+  ei = ExecItem(ast, [a.uop.buffer, b.uop.buffer, c.uop.buffer], prg=runner)
   with Context(DEBUG=2):
     for _ in range(run_count): ei.run(wait=True)
   print(f"custom {(c-tc).square().mean().item()}")

@@ -147,7 +147,7 @@ def test_matmul(sink:UOp, N=N):
   hc = Tensor.empty(N, N)
   Tensor.realize(a, b, hc)

-  ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
+  ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))

   ets = []
   with Context(DEBUG=2):

@@ -3,7 +3,6 @@ from tinygrad import dtypes, Tensor
 from tinygrad.helpers import getenv, get_single_element
 from tinygrad.dtype import _to_np_dtype
 from tinygrad.codegen.opt import OptOps
-from tinygrad.engine.realize import lower_schedule

 dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
             dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
@@ -40,8 +39,8 @@ if __name__ == "__main__":

   if getenv("SHOULD_USE_TC"):
     sched = a.matmul(b, dtype=acc_dtype).schedule()
-    lowered = list(lower_schedule(sched))
-    ei = get_single_element(lowered)[1]
+    ei = get_single_element(sched)
+    ei.lower()
     assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}"

   ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)

@@ -33,5 +33,5 @@ if __name__ == "__main__":
     new_src = prg.src
     # can mod source here
     prg = replace(prg, src=new_src)
-    ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+    ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
     for i in range(5): ei.run(wait=True)

@@ -88,7 +88,7 @@ if __name__ == "__main__":
   prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
                     global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
                     mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
-  ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
+  ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
   tflops = []
   for i in range(5):
     tm = ei.run(wait=True)

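The two hunks above preserve the custom-kernel workflow: lower a schedule item, rebuild its program spec, and wrap it in a fresh CompiledRunner. A hedged end-to-end sketch under the new signature (illustrative only; the tensor and the no-op source edit are placeholders):

    from dataclasses import replace
    from tinygrad import Tensor
    from tinygrad.engine.realize import CompiledRunner

    for ei in (Tensor.empty(16)+1).schedule():
      ei.lower()                                  # attaches ei.prg, a CompiledRunner
      spec = replace(ei.prg.p, src=ei.prg.p.src)  # kernel source could be modified here
      ei = replace(ei, prg=CompiledRunner(spec))  # ExecItem is a dataclass; swap the runner
      ei.run(wait=True)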
@@ -1,128 +0,0 @@
-import numpy as np
-import ctypes
-from tinygrad import Tensor, GlobalCounters, Context
-from tinygrad.engine.realize import lower_schedule, CompiledRunner
-from tinygrad.device import CPUProgram
-from dataclasses import replace
-from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN
-
-# only the memory access, over 100 GB/s! (sometimes)
-reduce_asm = """
-movi v0.2d, #0000000000000000
-mov w9, #0x30
-mov w10, #0x20
-mov x8, #-0x10
-movi v1.2d, #0000000000000000
-movk w9, #0x300, lsl #16
-movi v2.2d, #0000000000000000
-movk w10, #0x200, lsl #16
-movi v3.2d, #0000000000000000
-mov w11, #0x1000000
-mov w12, #0x3ffff0
-loop:
-ldp q4, q5, [x1]
-add x13, x1, x11
-add x15, x1, x10
-add x14, x1, x9
-add x8, x8, #0x10
-cmp x8, x12
-ldp q6, q7, [x1, #0x20]
-add x1, x1, #0x40
-ldp q4, q5, [x13]
-ldp q6, q7, [x13, #0x20]
-ldp q4, q5, [x15, #-0x20]
-ldp q6, q7, [x15]
-ldp q4, q5, [x14, #-0x30]
-ldp q6, q7, [x14, #-0x10]
-b.lo loop
-fadd v0.4s, v1.4s, v0.4s
-fadd v0.4s, v2.4s, v0.4s
-fadd v0.4s, v3.4s, v0.4s
-dup v1.4s, v0.s[1]
-dup v2.4s, v0.s[2]
-fadd v1.4s, v0.4s, v1.4s
-dup v0.4s, v0.s[3]
-fadd v1.4s, v2.4s, v1.4s
-fadd v0.4s, v0.4s, v1.4s
-str s0, [x0]
-ret
-"""
-
-ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN)
-arm_bytecode, _ = ks.asm(reduce_asm)
-arm_bytecode = bytes(arm_bytecode)
-
-reduce_src = """
-// data1 is 16M inputs
-typedef float float4 __attribute__((aligned(32),vector_size(16)));
-void reduce(float* restrict data0, float* restrict data1) {
-  float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f};
-  float* data1_1 = data1+4194304;
-  float* data1_2 = data1+(4194304*2);
-  float* data1_3 = data1+(4194304*3);
-  for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) {
-    float4 val0 = *(float4*)((data1+(ridx0+0)));
-    float4 val1 = *(float4*)((data1+(ridx0+4)));
-    float4 val2 = *(float4*)((data1+(ridx0+8)));
-    float4 val3 = *(float4*)((data1+(ridx0+12)));
-    acc0 += val0;
-    acc1 += val1;
-    acc2 += val2;
-    acc3 += val3;
-    val0 = *(float4*)((data1_1+(ridx0+0)));
-    val1 = *(float4*)((data1_1+(ridx0+4)));
-    val2 = *(float4*)((data1_1+(ridx0+8)));
-    val3 = *(float4*)((data1_1+(ridx0+12)));
-    acc4 += val0;
-    acc5 += val1;
-    acc6 += val2;
-    acc7 += val3;
-    val0 = *(float4*)((data1_2+(ridx0+0)));
-    val1 = *(float4*)((data1_2+(ridx0+4)));
-    val2 = *(float4*)((data1_2+(ridx0+8)));
-    val3 = *(float4*)((data1_2+(ridx0+12)));
-    acc0 += val0;
-    acc1 += val1;
-    acc2 += val2;
-    acc3 += val3;
-    val0 = *(float4*)((data1_3+(ridx0+0)));
-    val1 = *(float4*)((data1_3+(ridx0+4)));
-    val2 = *(float4*)((data1_3+(ridx0+8)));
-    val3 = *(float4*)((data1_3+(ridx0+12)));
-    acc4 += val0;
-    acc5 += val1;
-    acc6 += val2;
-    acc7 += val3;
-  }
-  float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7;
-  *(data0+0) = out[0]+out[1]+out[2]+out[3];
-}
-"""
-
-if __name__ == "__main__":
-  a = Tensor(np_array:=(np.random.default_rng().random((4096, 4096), dtype=np.float32)-0.5)).realize()
-  with Context(SPLIT_REDUCEOP=0):
-    # TODO: make it easy to alter the OptOps for a ScheduleItem
-    GlobalCounters.reset()
-    out = a.sum()
-    sis = out.schedule()
-    for i,(_,ei) in enumerate(lower_schedule(sis)):
-      if i == 0:
-        # change the source code
-        prg_spec = ei.prg.p
-        prg_spec = replace(prg_spec, name="reduce", src=reduce_src)
-        prg = CompiledRunner(prg_spec)
-        # change the assembly
-        #prg._prg = CPUProgram(prg_spec.name, arm_bytecode)
-        print("buffer at:",hex(ctypes.addressof(ei.bufs[1]._buf)))
-        ei = replace(ei, prg=prg)
-      ei.run()
-    print(out.item())
-    np.testing.assert_allclose(out.item(), np_array.sum(), atol=1, rtol=1e-4)

@@ -1,100 +0,0 @@
-import itertools
-import numpy as np
-from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
-from tinygrad.device import Buffer
-from tinygrad.engine.realize import capturing, lower_schedule_item
-from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
-from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
-from tinygrad.uop.ops import Ops, UOp
-from tinygrad.tensor import Tensor, _to_np_dtype
-
-ctx_vars = { MULTIOUTPUT: (0, 1) }
-FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
-
-def fuzz_schedule(outs:List[UOp]):
-  # find toposorts across all tunable params
-  unique_ts: Dict[Tuple[LBScheduleItem, ...], Dict[str, int]] = {}
-  for combination in itertools.product(*ctx_vars.values()):
-    for var, val in zip(ctx_vars, combination): var.value = val
-    ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
-    graph, in_degree, _ = _graph_schedule(outs)
-    for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
-  toposorts = list(unique_ts.items())
-  if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
-
-  # setup ground truth
-  ground_truth: Dict[UOp, memoryview] = {}
-  assign_targets: Dict[UOp, UOp] = {}
-  # IMPORTANT: freeze prerealized bufs before ScheduleItem exec
-  prerealized: Dict[UOp, memoryview] = {}
-  seed = Tensor._seed
-  ts,_ = toposorts[0]
-  for lsi in ts:
-    for out in lsi.outputs:
-      # freeze assign state before exec
-      if out.op is Ops.ASSIGN:
-        prerealized[out] = out.buffer.as_buffer()
-        assign_targets[out.srcs[1]] = out
-    for x in lsi.inputs:
-      if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
-    _exec_si(si, seed)
-    for out in lsi.outputs:
-      ground_truth[out] = out.buffer.as_buffer()
-      del out.srcs # only schedule the LazyBuffer in this fuzz run
-
-  # exec and validate each permutation with new Buffers
-  for i, (ts, ctx) in enumerate(toposorts[1:]):
-    if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
-    rawbufs: Dict[UOp, Buffer] = {}
-    for lsi in ts:
-      for out in lsi.outputs:
-        base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
-        rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
-        if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
-      for x in lsi.inputs:
-        if x not in rawbufs:
-          # override the assign_target after ASSIGN
-          if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
-          elif x.device == "NPY": rawbufs[x] = x.buffer
-          # copy the pre realized input
-          else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
-      si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
-      _exec_si(si, seed)
-      for out in lsi.outputs:
-        outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
-        try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
-        except Exception as e:
-          print(f"FAILED FOR {out}")
-          raise e
-
-def _exec_si(si:ScheduleItem, seed:int):
-  ei = lower_schedule_item(si)
-  if len(capturing): capturing[0].add(ei)
-  ei.run()
-
-T = TypeVar("T")
-def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
-  visited: Set[T] = set()
-  ret: List[Tuple[T, ...]] = []
-  path: List[T] = []
-
-  def recurse_paths(path:List[T]):
-    for v, d in in_degree.items():
-      if d != 0 or v in visited: continue
-      for u in graph[v]: in_degree[u] -= 1
-      path.append(v)
-      visited.add(v)
-      recurse_paths(path)
-      if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
-      # backtrack
-      for u in graph[v]: in_degree[u] += 1
-      path.pop()
-      visited.remove(v)
-    if len(path) == len(in_degree): ret.append(tuple(path))
-  recurse_paths(path)
-
-  if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
-  # verify all paths are unique
-  assert len(ret) == len(set(ret))
-  return ret