diff --git a/test/device/__init__.py b/test/device/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/unit/test_dtype_spec.py b/test/unit/test_dtype_spec.py index a3dfaea9e0..96c4e05089 100644 --- a/test/unit/test_dtype_spec.py +++ b/test/unit/test_dtype_spec.py @@ -8,7 +8,7 @@ from hypothesis import given, settings, strategies as strat import numpy as np import torch -settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) +settings.register_profile("my_profile", max_examples=50, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) settings.load_profile("my_profile") core_dtypes = list(DTYPES_DICT.values()) diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py index fa9bf6d7c4..38f6877890 100644 --- a/test/unit/test_llm_server.py +++ b/test/unit/test_llm_server.py @@ -27,8 +27,8 @@ class TestLLMServer(unittest.TestCase): from tinygrad.apps.llm import Handler from tinygrad.helpers import TCPServerWithReuse - cls.port = 11435 - cls.server = TCPServerWithReuse(('127.0.0.1', cls.port), Handler) + cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler) + cls.port = cls.server.server_address[1] cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True) cls.server_thread.start() time.sleep(0.1) diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index d34d9cddc7..6d21cfbf40 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -11,8 +11,8 @@ class SimpleTokenizer: self._byte_decoder = {chr(b): b for b in bs} | {chr(256+i): b for i,b in enumerate(b for b in range(256) if b not in bs)} # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L286 - # TODO: ucat_range is slow - def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(sys.maxunicode + 1) if unicodedata.category(chr(cp)).startswith(pre)) + # 0x323b0 is one past the max codepoint in unicode categories L/N/Z (0x323af is max L) + def ucat_range(pre: str): return "".join(re.escape(chr(cp)) for cp in range(0x323b0) if unicodedata.category(chr(cp)).startswith(pre)) r_ws, r_p_N, r_p_L = r"\t\n\x0b\x0c\r\x85" + ucat_range("Z"), ucat_range("N"), ucat_range("L") self._split_to_word = re.compile("(?i:'s|'t|'re|'ve|'m|'ll|'d)|" + \ f"[^\\r\\n{r_p_N}{r_p_L}]?[{r_p_L}]+|[{r_p_N}]{{1,3}}| ?[^{r_ws}{r_p_N}{r_p_L}]+[\\r\\n]*|[{r_ws}]*[\\r\\n]+|[{r_ws}]+(?![^{r_ws}])|[{r_ws}]+") diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a8c4a6845e..201077be2b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -90,23 +90,29 @@ from tinygrad.engine.memory import memory_planner from tinygrad.schedule.rangeify import get_rangeify_map from tinygrad.schedule.multi import get_multi_map -def replace_input_buffer(ctx:dict[UOp, UOp], b:UOp): - if (ret:=ctx.get(b, None)) is None: +def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp): + if (ret:=ctx[0].get(b, None)) is None: if b.op is Ops.BUFFER: - ctx[b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx)), b.src[1])) + ctx[0][b] = ret = b.replace(src=(UOp(Ops.LUNIQUE, arg=len(ctx[0])), b.src[1])) else: # TODO: flip args in CONST assert b.op is Ops.CONST - ctx[b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx)))) + ctx[0][b] = ret = b.replace(src=(b.src[0], UOp(Ops.LUNIQUE, arg=len(ctx[0])))) return ret +def strip_bind(ctx:tuple[dict[UOp, UOp], dict[str, int]], b:UOp): + var, val = b.src[0], b.src[1].arg + assert var.expr not in ctx[1] or ctx[1][var.expr] == val, f"bind mismatch on {var}, {ctx[1][var.expr]} != {val}" + ctx[1][var.expr] = val + return ctx[0].setdefault(b, b.replace(src=(b.src[0],))) + pm_pre_sched_cache = PatternMatcher([ # replace input buffers (UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), replace_input_buffer), # remove unique consts (UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE)), name="b"), replace_input_buffer), # strip value from BIND for cache key normalization, so different values hit same cache - (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), lambda ctx,b: ctx.setdefault(b, b.replace(src=(b.src[0],)))), + (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST)), name="b"), strip_bind), ]) def replace_input_buffer_back(ctx:dict[UOp, UOp], b:UOp): @@ -129,9 +135,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li # big_sink srcs are all the Tensors st = time.perf_counter() - # replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key + # replace all UNIQUE buffers with LUNIQUE, strip BIND values for cache key, extract var_vals input_buffers: dict[UOp, UOp] = {} - big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=input_buffers, name="rewrite for sched cache") + var_vals: dict[str, int] = {} + big_sink_cache = graph_rewrite(big_sink, pm_pre_sched_cache, ctx=(input_buffers, var_vals), name="rewrite for sched cache") sched_cache_key = big_sink_cache.key if (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: @@ -139,7 +146,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li if SPEC: type_verify(big_sink, tensor_spec) # hack to preserve metadata - graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx={}, name="preserve metadata") + graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}), name="preserve metadata") # tensor map is what we return tensor_map: dict[UOp, UOp] = {} @@ -191,17 +198,8 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars)) with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) - # extract var_vals from BINDs that were stripped (only if there are kernels) - var_vals: dict[str, int] = {} - if schedule: - for u in input_buffers: - if u.op is Ops.BIND: - var, val = u.unbind() - assert var.expr not in var_vals or var_vals[var.expr] == val, f"bind mismatch on {var}, {var_vals[var.expr]} != {val}" - var_vals[var.expr] = val - if (DEBUG >= 1 and len(schedule) > 1) or DEBUG >= 3: print(f"scheduled {len(schedule):4d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\ f" | {' cache hit' if sc_ret is not None else 'CACHE MISS'} {sched_cache_key.hex()[:8]}"+\ f" | {len(UOpMetaClass.ucache)} uops in cache") - return tensor_map, schedule, var_vals + return tensor_map, schedule, var_vals if schedule else {}