mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
test fixups / speedups / var_vals refactor (#13812)
* no PYTHONPATH + llm server port 0 * llm tok speedup * refactor var_vals
This commit is contained in:
0
test/device/__init__.py
Normal file
0
test/device/__init__.py
Normal file
@@ -8,7 +8,7 @@ from hypothesis import given, settings, strategies as strat
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
|
||||
core_dtypes = list(DTYPES_DICT.values())
|
||||
|
||||
@@ -27,8 +27,8 @@ class TestLLMServer(unittest.TestCase):
|
||||
from tinygrad.apps.llm import Handler
|
||||
from tinygrad.helpers import TCPServerWithReuse
|
||||
|
||||
cls.port = 11435
|
||||
cls.server = TCPServerWithReuse(('127.0.0.1', cls.port), Handler)
|
||||
cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler)
|
||||
cls.port = cls.server.server_address[1]
|
||||
cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
|
||||
cls.server_thread.start()
|
||||
time.sleep(0.1)
|
||||
|
||||
@@ -11,8 +11,8 @@ class SimpleTokenizer:
|
||||
self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)}
|
||||
|
||||
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286
|
||||
# TODO: ucat_range is slow
|
||||
def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre))
|
||||
# 0x323b0 is one past the max codepoint in unicode categories L/N/Z (0x323af is max L)
|
||||
def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(0x323b0) if unicodedata.category(chr(cp)).startswith(pre))
|
||||
r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L")
|
||||
self._split_to_word = re.compile("(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \
|
||||
f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+")
|
||||
|
||||
@@ -90,23 +90,29 @@ from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map
|
||||
from tinygrad.schedule.multi import get_multi_map
|
||||
|
||||
def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp):
|
||||
if (ret:=ctx.get(b, None)) is None:
|
||||
def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
|
||||
if (ret:=ctx[0].get(b, None)) is None:
|
||||
if b.op is Ops.BUFFER:
|
||||
ctx[b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx)), b.src[1]))
|
||||
ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1]))
|
||||
else:
|
||||
# TODO: flip args in CONST
|
||||
assert b.op is Ops.CONST
|
||||
ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx))))
|
||||
ctx[0][b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx[0]))))
|
||||
return ret
|
||||
|
||||
def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp):
|
||||
var, val = b.src[0], b.src[1].arg
|
||||
assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}"
|
||||
ctx[1][var.expr] = val
|
||||
return ctx[0].setdefault(b, b.replace(src=(b.src[0],)))
|
||||
|
||||
pm_pre_sched_cache = PatternMatcher([
|
||||
# replace input buffers
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer),
|
||||
# remove unique consts
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer),
|
||||
# strip value from BIND for cache key normalization, so different values hit same cache
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), lambda ctx,b: ctx.setdefault(b, b.replace(src=(b.src[0],)))),
|
||||
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind),
|
||||
])
|
||||
|
||||
def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp):
|
||||
@@ -129,9 +135,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
# big_sink srcs are all the Tensors
|
||||
st = time.perf_counter()
|
||||
|
||||
# replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key
|
||||
# replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key, extract var_vals
|
||||
input_buffers: dict[UOp, UOp] = {}
|
||||
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache")
|
||||
var_vals: dict[str, int] = {}
|
||||
big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals), name="rewrite for sched cache")
|
||||
sched_cache_key = big_sink_cache.key
|
||||
|
||||
if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
@@ -139,7 +146,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
|
||||
# hack to preserve metadata
|
||||
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx={}, name="preserve metadata")
|
||||
graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}), name="preserve metadata")
|
||||
|
||||
# tensor map is what we return
|
||||
tensor_map: dict[UOp, UOp] = {}
|
||||
@@ -191,17 +198,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
|
||||
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
|
||||
|
||||
# extract var_vals from BINDs that were stripped (only if there are kernels)
|
||||
var_vals: dict[str, int] = {}
|
||||
if schedule:
|
||||
for u in input_buffers:
|
||||
if u.op is Ops.BIND:
|
||||
var, val = u.unbind()
|
||||
assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}"
|
||||
var_vals[var.expr] = val
|
||||
|
||||
if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3:
|
||||
print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache)} uops in cache")
|
||||
return tensor_map, schedule, var_vals
|
||||
return tensor_map, schedule, var_vals if schedule else {}
|
||||
|
||||
Reference in New Issue
Block a user