diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index b1b3461f0d..3fdd1b8518 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -45,6 +45,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: Run Stable Diffusion run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt - name: Run Stable Diffusion with fp16 @@ -148,6 +150,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: Run model inference benchmark run: NV=1 NOCLANG=1 python3 test/external/external_model_benchmark.py - name: Test speed vs torch @@ -252,6 +256,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: Fuzz Padded Tensor Core GEMM (NV) run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py - name: Fuzz Padded Tensor Core GEMM (PTX) @@ -318,6 +324,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: Show off tinybox run: /opt/rocm/bin/rocm-bandwidth-test # TODO: unstable on AMD @@ -420,6 +428,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: Train MNIST run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt - name: Run 10 CIFAR training steps @@ -471,6 +481,8 @@ jobs: run: | echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal + - name: reset process replay + run: test/external/process_replay/reset.py - name: openpilot compile 0.9.4 run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt - name: openpilot compile 0.9.7 @@ -485,6 +497,8 @@ jobs: run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_4.txt - name: benchmark openpilot w IMAGE=2 0.9.7 run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt + - name: openpilot compile3 0.9.7 + run: PYTHONPATH="." QCOM=1 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b00ffb0faf..d34cac1a66 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -197,7 +197,7 @@ jobs: - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model compile and size run: | - PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py + PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000' - if: ${{ matrix.task == 'optimage' }} name: Test openpilot model correctness (float32) diff --git a/examples/openpilot/compile2.py b/examples/openpilot/compile2.py index f907097a60..811a5fc240 100644 --- a/examples/openpilot/compile2.py +++ b/examples/openpilot/compile2.py @@ -6,6 +6,7 @@ sys.path.insert(0, str(pathlib.Path(__file__).parents[1])) if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" +if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1" OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx" diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index a04c3c2652..5dbefc8398 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -3,6 +3,7 @@ import numpy as np if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1" if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2" if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1" +if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1" from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters from tinygrad.helpers import OSX, DEBUG, Timing diff --git a/examples/sdxl.py b/examples/sdxl.py index 5196430b83..131525b6ac 100644 --- a/examples/sdxl.py +++ b/examples/sdxl.py @@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin from examples.stable_diffusion import ResnetBlock, Mid import numpy as np -from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union +from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type import argparse, tempfile from abc import ABC, abstractmethod from pathlib import Path @@ -282,19 +282,29 @@ class SDXL: return self.first_stage_model.decode(1.0 / 0.13025 * x) -class VanillaCFG: +class Guider(ABC): def __init__(self, scale:float): self.scale = scale - def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]: + @abstractmethod + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: + pass + +class VanillaCFG(Guider): + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: c_out = {} for k in c: assert k in ["vector", "crossattn", "concat"] c_out[k] = Tensor.cat(uc[k], c[k], dim=0) - return Tensor.cat(x, x), Tensor.cat(s, s), c_out - def __call__(self, x:Tensor) -> Tensor: - x_u, x_c = x.chunk(2) + x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2) + x_pred = x_u + self.scale*(x_c - x_u) + return x_pred + +class SplitVanillaCFG(Guider): + def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor: + x_u = denoiser(x, s, uc) + x_c = denoiser(x, s, c) x_pred = x_u + self.scale*(x_c - x_u) return x_pred @@ -302,13 +312,12 @@ class VanillaCFG: # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21 # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287 class DPMPP2MSampler: - def __init__(self, cfg_scale:float): + def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG): self.discretization = LegacyDDPMDiscretization() - self.guider = VanillaCFG(cfg_scale) + self.guider = guider_cls(cfg_scale) def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]: - denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc)) - denoised = self.guider(denoised) + denoised = self.guider(denoiser, x, sigma, c, uc) t, t_next = sigma.log().neg(), next_sigma.log().neg() h = t_next - t @@ -329,7 +338,7 @@ class DPMPP2MSampler: return x, denoised def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor: - sigmas = self.discretization(num_steps) + sigmas = self.discretization(num_steps).to(x.device) x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0) num_sigmas = len(sigmas) diff --git a/extra/datasets/sops.gz b/extra/datasets/sops.gz index cf75102480..5de066b1a7 100644 Binary files a/extra/datasets/sops.gz and b/extra/datasets/sops.gz differ diff --git a/extra/models/unet.py b/extra/models/unet.py index fad41443cb..92d4496320 100644 --- a/extra/models/unet.py +++ b/extra/models/unet.py @@ -7,7 +7,7 @@ import math # https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207 def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): half = dim // 2 - freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp() + freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp() args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16) diff --git a/extra/onnx.py b/extra/onnx.py index 51d4600c65..de5925ce27 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto): return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims)) if len(inp.raw_data) > 0: data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy() - return Tensor(data, requires_grad=False).reshape(tuple(inp.dims)) + return Tensor(data.reshape(tuple(inp.dims)), requires_grad=False) return Tensor(None, requires_grad=False) def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]: diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py index 86ac60a92a..febf91d6fe 100644 --- a/test/external/external_benchmark_schedule.py +++ b/test/external/external_benchmark_schedule.py @@ -1,7 +1,7 @@ from typing import List from extra.models.resnet import ResNet50 from tinygrad import Tensor, Device -from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen +from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen, _CURRENT_KERNEL from tinygrad.ops import UOps from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.lowerer import ast_to_uop @@ -43,6 +43,7 @@ if __name__ == "__main__": rewritten_uops = [] for i,(k,u) in enumerate(zip(kernels, uops)): with Timing(f"rewrite {i:2d} {k.name}{' '*(50-ansilen(k.name))}", enabled=getenv("VERBOSE", 0)): + if getenv("VIZ"): _CURRENT_KERNEL.set(k.name) rewritten_uops.append(full_graph_rewrite(u, k.opts)) uops = rewritten_uops if getenv("LINEARIZE", 1): diff --git a/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5_float16.npy b/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5_float16.npy deleted file mode 100644 index b8382df3d0..0000000000 Binary files a/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5_float16.npy and /dev/null differ diff --git a/test/external/process_replay/diff_schedule.py b/test/external/process_replay/diff_schedule.py index 1639f32b3a..3cdbdffd72 100644 --- a/test/external/process_replay/diff_schedule.py +++ b/test/external/process_replay/diff_schedule.py @@ -17,7 +17,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List if not os.path.isfile(fp): shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py", allow_caching=False), fp) # create the reference graph - ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs) + ref_graph, ref_in_degree, _ = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs) # compare diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)]) @@ -26,7 +26,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]] for _,in_degree in s: for lsi in in_degree: for buf in lsi.outputs: - si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)) + si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), tuple(lsi.metadata))) changed = 0 seen_diffs: Set[bytes] = set() for buf,si in si_for_buf.items(): diff --git a/test/external/process_replay/helpers.py b/test/external/process_replay/helpers.py index 399e797895..b8b9880cfb 100644 --- a/test/external/process_replay/helpers.py +++ b/test/external/process_replay/helpers.py @@ -1,5 +1,7 @@ -import difflib, logging -from tinygrad.helpers import colored, getenv +from dataclasses import dataclass +import difflib, logging, traceback, subprocess +from typing import Dict, Optional +from tinygrad.helpers import ContextVar, colored, getenv def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)): if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -10,3 +12,17 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)): import ocdiff diff = ocdiff.console_diff(str(s0), str(s1)) logging.info(diff) + +@dataclass(frozen=True) +class ProcessReplayContext: + ctx_vars: Dict[str, int] + loc: str = "" + head_sha: str = "" + run_id: Optional[int] = None +def get_process_replay_ctx() -> ProcessReplayContext: + stack = filter(lambda x: "tinygrad" in x.filename and not any(n in x.filename for n in ["engine/schedule.py", "engine/realize.py", \ + "codegen/kernel.py", "unittest"]), traceback.extract_stack()[:-1]) + loc = "\n".join(traceback.format_list(stack)) + try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode() + except Exception: head_sha = "" + return ProcessReplayContext({k:v.value for k,v in ContextVar._cache.items()}, loc, head_sha, getenv("GITHUB_RUN_ID") or None) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 9345ee87df..2bcb724a01 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -12,8 +12,7 @@ from test.external.process_replay.helpers import print_diff PAGE_SIZE = 100 REF = os.getenv("GITHUB_REF_NAME", "") MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20) -RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD") -TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}" +TABLE_NAME = f"process_replay_{VERSION}" os.environ["RUN_PROCESS_REPLAY"] = "0" early_stop = multiprocessing.Event() logging.basicConfig(level=logging.INFO, format="%(message)s") @@ -21,6 +20,7 @@ logging.basicConfig(level=logging.INFO, format="%(message)s") # user config ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k))) SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "") +COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", 1) if REF == "master": SKIP_PROCESS_REPLAY = True # *** differs @@ -43,7 +43,7 @@ def diff_kernel(offset:int) -> bool: if early_stop.is_set(): return True conn = db_connection() cur = conn.cursor() - cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) + cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)) changed = 0 for row in cur.fetchall(): # try unpickle @@ -54,7 +54,7 @@ def diff_kernel(offset:int) -> bool: continue # try linearize try: - with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}): + with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}): k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) # NOTE: replay with the captured renderer, not the one in master @@ -72,6 +72,7 @@ def diff_kernel(offset:int) -> bool: logging.info("PROCESS REPLAY DETECTED CHANGE") logging.info(ast) logging.info(applied_opts) + logging.info(ctx.loc) print_diff(good_src, compare_src) if ASSERT_DIFF: return True if changed > MAX_DIFF_PCT: @@ -112,9 +113,9 @@ def process_replay_schedule() -> None: def process_replay_kernel() -> None: conn = db_connection() cur = conn.cursor() - try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] + try: row_count = cur.execute(f"select count(*) from 'kernel_{TABLE_NAME}'").fetchone()[0] except sqlite3.OperationalError: - logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") + logging.warning(f"kernel_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?") return None conn.commit() cur.close() @@ -127,11 +128,12 @@ if __name__ == "__main__": logging.info("skipping process replay.") exit(0) - logging.info("***** schedule diff") - try: process_replay_schedule() - except Exception as e: - if ASSERT_DIFF: raise e - logging.error(f"schedule diff err {e}") + if COMPARE_SCHEDULE: + logging.info("***** schedule diff") + try: process_replay_schedule() + except Exception as e: + if ASSERT_DIFF: raise e + logging.error(f"schedule diff err {e}") logging.info("***** kernel diff") try: process_replay_kernel() diff --git a/test/external/process_replay/reset.py b/test/external/process_replay/reset.py index 552cec748c..9f1913bd5d 100755 --- a/test/external/process_replay/reset.py +++ b/test/external/process_replay/reset.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -from tinygrad.helpers import db_connection, VERSION, getenv, os +from tinygrad.helpers import db_connection, VERSION, os cur = db_connection() -cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}") -cur.execute(f"drop table if exists schedule_diff_{VERSION}") +cur.execute(f"drop table if exists kernel_process_replay_{VERSION}") +cur.execute(f"drop table if exists schedule_process_replay_{VERSION}") if os.path.exists(fp:=__file__.replace("reset", "master_schedule")): os.system(f"rm -rf {fp}") diff --git a/test/external/process_replay/restore.py b/test/external/process_replay/restore.py deleted file mode 100644 index 16079c5f2b..0000000000 --- a/test/external/process_replay/restore.py +++ /dev/null @@ -1,22 +0,0 @@ -# restore a specific benchmark process replay -import pickle, os -from tinygrad.device import Device -from tinygrad.helpers import db_connection, VERSION, tqdm - -cur = db_connection() -RUN_ID = os.environ["GITHUB_RUN_ID"] -ATTEMPT = os.environ["GITHUB_RUN_ATTEMPT"] -TABLE_NAME = f"process_replay_{RUN_ID}_{ATTEMPT}_{VERSION}" -PAGE_SIZE = 100 -row_cnt = cur.execute(f"select count(*) from {TABLE_NAME}").fetchone()[0] -for offset in tqdm(range(0, row_cnt, PAGE_SIZE)): - rows = cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall() - for row in rows: - ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0]) - try: Device[opts.device].compiler.compile(compare_src) - except Exception: - print("FAILED TO COMPILE") - print(ast) - print(applied_opts) - print(compare_src) - continue diff --git a/test/external/process_replay/test_process_replay.py b/test/external/process_replay/test_process_replay.py index 20c4bb0e8f..2490289981 100644 --- a/test/external/process_replay/test_process_replay.py +++ b/test/external/process_replay/test_process_replay.py @@ -1,4 +1,6 @@ import unittest +import contextlib, sqlite3 +from test.external.process_replay.helpers import ProcessReplayContext from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel from tinygrad.codegen.kernel import Kernel @@ -8,16 +10,17 @@ from tinygrad.renderer.cstyle import ClangRenderer from tinygrad.tensor import Tensor def helper_append_replay(ast:UOp, name:str, src:str) -> int: - diskcache_put(TABLE_NAME.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {})) + name = f"kernel_{TABLE_NAME}" + diskcache_put(name.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({}))) conn = db_connection() - row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0] + row_count = conn.execute(f"select count(*) from '{name}'").fetchone()[0] return row_count class TestProcessReplay(unittest.TestCase): def tearDown(self): conn = db_connection() cur = conn.cursor() - cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'") + with contextlib.suppress(sqlite3.OperationalError): cur.execute(f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'") conn.commit() cur.close() diff --git a/test/helpers.py b/test/helpers.py index b441359cd1..ab313d42a3 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -65,7 +65,7 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None: def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp: if st_src is None: st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),) - return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype)) + return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0)) T = TypeVar("T") def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]: diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 00ef6a5b9f..db8aa86a67 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -876,7 +876,7 @@ class TestLinearizer(unittest.TestCase): lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0] ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE] # LOAD -> RANGE -> LOAD -> ASSIGN - assert lin.uops[ranges[0]-2].op is UOps.LOAD + assert len([x for x in lin.uops[:ranges[0]] if x.op is UOps.LOAD]) == 1 def test_range_outer_op_before_phi_nested_range(self): a = Tensor.randn(2, ).realize() @@ -1715,7 +1715,7 @@ class TestHandCodedOpts(unittest.TestCase): k = Kernel(si.ast) k.hand_coded_optimizations() if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel) - if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end) + if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end) upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len])) assert len(upcasts) == 3 # 3 transformation matrices assert len(wino_schedule) <= 4 # 4 kernels diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b0e9fcca96..abccdfe335 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase): assert ast.op is UOps.STORE assert ast.src[2].arg is BinaryOps.ADD assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0] - assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1 + assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1 t = 2 * t for si in t.schedule(): ast = si.ast.src[0] assert ast.op is UOps.STORE assert ast.src[2].arg is BinaryOps.MUL - assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2 + assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2 assert ast.src[2].src[1].op is UOps.LOAD t = t + t.full_like(3) for si in t.schedule(): @@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase): assert ast.op is UOps.STORE assert ast.src[2].arg is BinaryOps.ADD assert ast.src[2].src[0].op is UOps.LOAD - assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3 + assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3 def test_shard_memory(self): devices = (d0, d1, d2, d3) diff --git a/test/test_ops.py b/test/test_ops.py index 742d40e058..d9713781f3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -889,6 +889,8 @@ class TestOps(unittest.TestCase): def test_sum_simple(self): helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]]) + # NOTE: simple test for locals + # FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full def test_sum_full(self): helper_test_op([(16384)], lambda x: x.sum()) def test_sum_relu(self): diff --git a/test/test_pickle.py b/test/test_pickle.py index 050cb4cef3..00d5fc5d3e 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -2,7 +2,12 @@ import unittest, pickle import numpy as np from test.helpers import assert_equiv_uops from tinygrad import Tensor, TinyJit, Variable +from tinygrad.codegen.kernel import Kernel +from tinygrad.dtype import PtrDType, dtypes from tinygrad.engine.schedule import create_schedule +from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps +from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.view import View class TestPickle(unittest.TestCase): def test_pickle_realized_tensor(self): @@ -66,6 +71,32 @@ class TestPickle(unittest.TestCase): sched_pk = pickle.loads(pk) assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast) + def test_pickle_define_var(self): + ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( + UOp(UOps.STORE, dtypes.void, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 + UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( + UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(UOps.LOAD, dtypes.float, arg=None, src=( + UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501 + UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( + UOp(UOps.CAST, dtypes.float, arg=None, src=( + UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=( + UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=( + x2,)), + UOp.define_var("i", dtypes.int, 1, 10), + x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)), + UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=( + x12, + UOp(UOps.CONST, dtypes.int, arg=3, src=()), + x14,)),)),)),)),)),)),)) + p = Kernel(ast).to_program(name_override="test") + ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test") + self.assertEqual(ps.src, p.src) + class TestPickleJIT(unittest.TestCase): @classmethod def setUpClass(cls): diff --git a/test/test_randomness.py b/test/test_randomness.py index 159584a92a..7fe5e15677 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase): N = 128 x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16) assert x.dtype == dtypes.bfloat16 - # TODO: fix this property for bfloat16 random - # x = x.numpy() - # ones = np.take(x, np.where(x == 1)) - # zeros = np.take(x, np.where(x == 0)) - # self.assertTrue(ones.size == 0) - # self.assertTrue(zeros.size > 0) + if THREEFRY.value: + nx = x.numpy() + assert nx[nx == 1].size == 0 + assert nx[nx == 0].size > 0 equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) def test_randn(self): diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 45e5b07012..51570b21a8 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -162,10 +162,10 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(nout.src[1].arg, 3.0) def test_consts_go_last(self): - a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) - b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) - c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) - d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + a = UOp.define_var('a', dtypes.int, 0, 1) + b = UOp.define_var('b', dtypes.int, 0, 1) + c = UOp.define_var('c', dtypes.int, 0, 1) + d = UOp.define_var('d', dtypes.int, 0, 1) outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: sink = graph_rewrite(out, constant_folder) @@ -186,7 +186,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.arg, 3.0) def test_where_same_fold(self): - v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + v = UOp.define_var('tmp', dtypes.int, 0, 1) c0 = UOp(UOps.CONST, dtypes.int, arg=0) vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) @@ -280,7 +280,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) + acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -289,7 +289,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) + acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -356,7 +356,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) + v = UOp.define_var("tmp", dtypes.int, 0, 1) c2 = UOp(UOps.CONST, dtypes.int, arg=2) c4 = UOp(UOps.CONST, dtypes.int, arg=4) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) @@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase): def test_fold_gated_load_local(self): glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1)) + smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) barrier = UOp(UOps.BARRIER, dtypes.void, (st, )) @@ -610,8 +610,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_dont_fold_different_gated(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) - gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) + gate = UOp.define_var("g1", dtypes.bool, False, True) + gate2 = UOp.define_var("g2", dtypes.bool, False, True) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink) @@ -626,7 +626,7 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_fold_gate(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) + gate = UOp.define_var("g1", dtypes.bool, False, True) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) @@ -637,8 +637,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_dont_fold(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) - gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True))) + gate = UOp.define_var("g1", dtypes.bool, False, True) + gate2 = UOp.define_var("g2", dtypes.bool, False, True) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) @@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander + class TestIFUOps(unittest.TestCase): def test_create_ifs(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4)) + sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) gate = valid&(lidx.ne(2)) @@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase): def test_expand_ifs_one_gate(self): gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16)) + sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) gate = valid&(lidx.ne(2)) diff --git a/test/test_uops.py b/test/test_uops.py index c3c547ba8e..b191fac730 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_basic(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16)) st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st,)) sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr)) @@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16)) st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2))) st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2)) diff --git a/test/unit/test_image_valid.py b/test/unit/test_image_valid.py index 2111a125d6..1637fd296e 100644 --- a/test/unit/test_image_valid.py +++ b/test/unit/test_image_valid.py @@ -19,7 +19,9 @@ def render(image_shape, valid:UOp, idx:UOp) -> str: return fxn.split("float4 val0 = ")[1].split(";")[0] def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax)) -def Variable(expr, nmin, nmax): return UOp(UOps.DEFINE_VAR, dtypes.int, (), (expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax))) +def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax) +def Range(n, nmax): + return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),)) class TestHelpers(unittest.TestCase): def test_is_increasing(self): @@ -97,46 +99,69 @@ class TestValidSimplification(unittest.TestCase): # empty self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)) - @unittest.expectedFailure # TODO: FIXME def test_openpilot_conv1(self): # first conv in openpilot # kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6 idx1 = Special("idx1", 32) idx2 = Special("idx2", 64) - ridx0 = Variable("ridx0", 0, 5) - ridx1 = Variable("ridx1", 0, 2) - ridx2 = Variable("ridx2", 0, 2) + # ridx0 = Variable("ridx0", 0, 5) + # ridx1 = Variable("ridx1", 0, 2) + # ridx2 = Variable("ridx2", 0, 2) + ridx0 = Range(0, 6) + ridx1 = Range(1, 3) + ridx2 = Range(2, 3) alu1 = ((idx2*2)+ridx1) alu4 = ((idx1*48)+(ridx2*6)+ridx0) - valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0)) + valid = (((idx2*2)+(ridx1)).lt(1).ne(True))&(((idx1*8)+(ridx2)).lt(1).ne(True)) shape = (128, 1536, 4) idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))) - # (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp, - # (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f)) self.assertEqual(render(shape, valid, idx), - "read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))") + "read_imagef(data0, smp, (int2)(((idx1*48)+(ridx2*6)+ridx0+(-6)),((idx2*2)+ridx1+(-1))))") - @unittest.expectedFailure # TODO: FIXME def test_openpilot_conv2(self): # conv in test/external/external_test_valid_remove.py idx1 = Special("idx1", 32) idx2 = Special("idx2", 64) - ridx0 = Variable("ridx0", 0, 2) - ridx1 = Variable("ridx1", 0, 2) - ridx2 = Variable("ridx2", 0, 2) + # ridx0 = Variable("ridx0", 0, 2) + # ridx1 = Variable("ridx1", 0, 2) + # ridx2 = Variable("ridx2", 0, 2) + ridx0 = Range(0, 3) + ridx1 = Range(1, 3) + ridx2 = Range(2, 3) alu1 = ((idx2*2)+ridx1) alu3 = ((idx1*24)+(ridx2*3)+ridx0) - valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0)) + valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True)) shape = (128, 768, 4) idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2))) self.assertEqual(render(shape, valid, idx), - "read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))") + "read_imagef(data0, smp, (int2)(((idx1*24)+(ridx2*3)+ridx0+(-3)),((idx2*2)+ridx1+(-1))))") + + def test_openpilot_conv3(self): + # in openpilot 0.9.7 + idx0 = Special("idx0", 64) + idx1 = Special("idx1", 2) + idx2 = Special("idx2", 4) + ridx0 = Range(0, 7) + ridx1 = Range(1, 7) + + alu2 = ((idx2*2)+ridx0) + alu4 = ((idx1*8)+ridx1) + alu6 = ((idx1*512)+(ridx1*64)+idx0) + + valid = alu2.lt(11)&(alu4.lt(3).ne(True)) + shape = (8, 1024, 4) + idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)/8)+1)/2)+(-4)))) + + # TODO: simplify idx + # alu0 = ((idx2*2)+ridx0) + self.assertEqual(render(shape, valid, idx), + "(((alu0<11)&((((idx1*8)+ridx1)<3)!=1))?read_imagef(data0, smp, (int2)((((idx1*512)+(ridx1*64)+idx0+832)%1024),(alu0+(-4)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501 def test_simplify1(self): # idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1) @@ -188,5 +213,20 @@ class TestValidSimplification(unittest.TestCase): self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))), "((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))") + def test_simplify5(self): + # openpilot 0.9.7, chunk replacement to simplify + shape = (10, 384, 4) + idx0 = Special("idx0", 16) + idx1 = Special("idx1", 24) + alu0 = idx0*4 + alu1 = (idx1*256)+alu0 + alu2 = idx1//3 + alu3 = ((alu1+1)%768) + idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10)) + valid = alu3.lt(640) + + self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)), + "((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))") + if __name__ == '__main__': unittest.main() diff --git a/test/unit/test_pattern_matcher.py b/test/unit/test_pattern_matcher.py index dedddac9fc..1354a0cb34 100644 --- a/test/unit/test_pattern_matcher.py +++ b/test/unit/test_pattern_matcher.py @@ -11,6 +11,20 @@ class TestPatternMatcher(unittest.TestCase): self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c2), None) + def test_match_sz_0(self): + match_cnt = 0 + def fxn(x): + nonlocal match_cnt + match_cnt += 1 + assert len(x.src) == 0 + return UOp(UOps.CONST, src=(UOp(UOps.CONST),)) + matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)]) + c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) + # second rewrite shouldn't match anything + c1 = matcher.rewrite(c1) + c1 = matcher.rewrite(c1) + self.assertEqual(match_cnt, 1) + def test_uop(self): matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)]) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 27f897b621..5d3d6143a0 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -118,7 +118,9 @@ class TestRealDoesntSimplify(unittest.TestCase): self.assertEqual(self.st.real_strides(), (None, 18, -3, -1)) class TestRealStrides(unittest.TestCase): + @unittest.expectedFailure def test_1(self): + # TODO: find the correct rewrite rule to fix this self.st = ShapeTracker(( View.create((2048,), (1,), 0, ((0, 512),)), View.create((16, 32, 4), (128, 4, 1), 0, None))) @@ -489,6 +491,14 @@ class TestComplexShapeTracker(unittest.TestCase): print(self.st.views) assert self.st.contiguous +class TestShapeTrackerEquality(unittest.TestCase): + def test_simple_equals(self): + self.assertEqual(ShapeTracker.from_shape((10,10)), ShapeTracker.from_shape((10,10))) + def test_other_equals(self): + st1 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True))) + st2 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True))) + self.assertEqual(st1, st2) + class TestSingleShapeTracker(unittest.TestCase): def setUp(self): self.st = CheckingShapeTracker((7,4)) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 34da16f6fe..ad22261247 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -27,9 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]: def NumNode(val): return UOp.const(dtypes.int, val) def Variable(expr, nmin, nmax): - vmin = UOp.const(dtypes.int, nmin) - vmax = UOp.const(dtypes.int, nmax) if isinstance(nmax, int) else nmax - return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, vmin, vmax)) + return UOp.define_var(expr, dtypes.int, nmin, nmax if isinstance(nmax, int) else nmax.arg) class Node: @staticmethod def sum(ops): return functools.reduce(lambda x,y: x+y, ops) @@ -450,6 +448,19 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)") self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))") + def test_simplex_lt(self): + a = Variable("a", 0, 3) + b = Variable("b", 0, 3) + c = Variable("c", 0, 3) + d = Variable("d", -3, 3) + self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)") + self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)") + self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)") + self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*(-3))+(b*4))<1)!=1)") # negative coeff, should not be simplified + self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified + self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") + self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") + @unittest.skip("not supported on uops yet") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index 26ab85e6b6..3b7ab3ce1b 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -13,7 +13,7 @@ from tinygrad.shape.view import View class InvalidASTException(Exception): pass def helper_test_verify_ast(*stores:UOp) -> Kernel: - sink = UOp(UOps.SINK, None, stores) + sink = UOp(UOps.SINK, dtypes.void, stores) if DEBUG >= 3: for op in stores: print(op) try: verify_ast(sink) @@ -50,7 +50,7 @@ class TestVerifyAST(unittest.TestCase): bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop())) b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,))) - st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b)) + st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b)) with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) def test_shrink_ok(self): @@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase): uop_sts = verify_ast(a.schedule()[-1].ast) store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0] self.assertEqual(store_st, ShapeTracker.from_shape((4, 4))) - const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0] + const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0] self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4))) if __name__ == '__main__': diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index b4357b9cf3..597726b160 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -4,12 +4,13 @@ from dataclasses import dataclass from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict -from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify +from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, \ + graph_rewrite, PatternMatcher from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, Program from tinygrad.dtype import ImageDType, PtrDType from tinygrad.helpers import _CURRENT_KERNEL, all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, AMX, round_up, all_int, \ - get_contraction, to_function_name, diskcache_put, ContextVar + get_contraction, to_function_name, diskcache_put from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, sint from tinygrad.shape.view import strides_for_shape @@ -640,7 +641,7 @@ class Kernel: # for locals, we use the ShapeTracker that's in the srcs st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)] st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop() - if op.op is UOps.CONST: return op.replace(src=(st_uop,)) + if op.op is UOps.VALID: return op.replace(src=(st_uop,)) if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st))) return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]])) if op.op is UOps.REDUCE_AXIS: @@ -702,7 +703,7 @@ class Kernel: st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD] local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape)) st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop() - membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) + membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn) srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store))) else: @@ -711,7 +712,14 @@ class Kernel: # MUL/SUM instead of WMMA ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1])) else: - ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg) + # real WMMA, use CONTRACT/EXPAND to get the vectorization right + wmma_upcast_axes = wmma_arg[-2] + wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes] + wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=( + UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]), + UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]), + UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg) + ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2]) new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes) return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret if self.group_for_reduces: @@ -725,7 +733,7 @@ class Kernel: for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \ (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) st_uop = ShapeTracker.from_shape(local_shape).to_uop() - local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) + local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start))) grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis)) if op is self.reduceops[-1]: return grouped_reduce @@ -735,7 +743,8 @@ class Kernel: elif op.op is UOps.SINK: arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals) return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg) - return fixup_ast(self.ast) + # NOTE: rewrite with an empty PatternMatcher to dedup UOps + return graph_rewrite(fixup_ast(self.ast), PatternMatcher([])) # **** this is the lowerer **** @@ -763,8 +772,8 @@ class Kernel: src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops) if getenv("RUN_PROCESS_REPLAY"): - table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}" - diskcache_put(table_name, str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()})) + from test.external.process_replay.helpers import get_process_replay_ctx + diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx())) # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes # TODO: these max and min don't work on symbolic, and results are very wrong. @@ -778,30 +787,28 @@ class Kernel: def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None: if not uop.has_st or uop in sts: return - op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg # restore globals from the two stage reduce - if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL: - _assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts) + if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL: + _assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts) sts[uop] = sts[local_reduce] return - for x in src: _assert_valid_uop(x, st, sts) + for x in uop.src: _assert_valid_uop(x, st, sts) # only reduceuop is allowed to change shape, limited to turning n to 1 - if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1])) - elif op is UOps.SWIZZLE: st = arg + if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1])) + # movementops are pushed to SHAPETRACKER and SWIZZLE + elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg + # everything else inherits shape else: - assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}" - # movementops are pushed to the edges with SHAPETRACKER - # elementwise inherits shape - st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]] - for x in src: - if x.has_st and sts[x].shape != st.shape: - if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}") - raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}") + assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, UOps.ASSIGN, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}" + st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0] + if not all_same(shapes:=[x.shape for x in src_sts]): + if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}") + raise AssertionError(f"found implicit expand {sizes}") sts[uop] = st def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]: assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK" - assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size" + assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size" sts: Dict[UOp, ShapeTracker] = {} for out in ast.src: _assert_valid_uop(out, out.st_arg, sts) shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])] diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 3494286e9d..5d3c52ee9c 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -1,13 +1,17 @@ +# the job of the lowerer is to do indexing from __future__ import annotations import functools -from typing import List, Tuple, cast, Optional, Dict +from dataclasses import dataclass +from typing import List, Tuple, cast, Optional from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop from tinygrad.shape.symbolic import sint from tinygrad.dtype import dtypes -from tinygrad.ops import KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps +from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten +# ***** indexing ***** + def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]): # TODO: symbolic shape if not all_int(dims): return dims @@ -34,99 +38,91 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int idx //= dims[c] return ret[::-1] if reverse else ret -class IndependentLowerer: - def lower(self, ast:UOp, opts:Renderer) -> UOp: - self.output_count = len(ast.src) +@dataclass(frozen=True) +class IndexContext: + idxs: List[UOp] + ridxs: List[UOp] - ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo() - # NOTE: assumes the shape is - full_shape = ast.full_shape - first_upcasted = len(full_shape)-ki.upcasted - first_output_st: ShapeTracker = ast.src[0].st_arg - # if there's no reduce, this is first_upcasted - first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True) - local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL] - # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces - group_for_reduces = sum([any(j!=y for j in x) for x,y in zip( - [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)], - first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0 - global_dims = first_reduce-ki.local_dims +def get_index(ast:UOp, opts:Renderer) -> IndexContext: + ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo() + # NOTE: assumes the shape is + full_shape = ast.full_shape + first_upcasted = len(full_shape)-ki.upcasted + first_output_st: ShapeTracker = ast.src[0].st_arg + # if there's no reduce, this is first_upcasted + first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True) + local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL] + # NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces + group_for_reduces = sum([any(j!=y for j in x) for x,y in zip( + [[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)], + first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0 + global_dims = first_reduce-ki.local_dims - if opts.has_local: - if ki.dont_use_locals: - assert ki.local_dims == 0, "can't use locals if there's no local dims" - self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True) - else: - # define indexes for GPU-like execution - self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \ - get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max) + if opts.has_local: + if ki.dont_use_locals: + assert ki.local_dims == 0, "can't use locals if there's no local dims" + idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True) else: - # all loops are RANGES - self.idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False)) - for i,g in enumerate(full_shape[:first_reduce])] + # define indexes for GPU-like execution + idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \ + get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max) + else: + # all loops are RANGES + idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False)) + for i,g in enumerate(full_shape[:first_reduce])] - # reduce loops - self.idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True)) - for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)] + # reduce loops + idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True)) + for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)] - # upcast loops - for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted): - assert isinstance(g, int), "needs to be int to upcast/unroll" - self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),))) + # upcast loops + for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted): + assert isinstance(g, int), "needs to be int to upcast/unroll" + idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),))) - # late indexes (group for reduce) - self.ridxs = self.idxs[:] - for a in range(first_reduce, first_reduce+group_for_reduces): - self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True)) + # late indexes (group for reduce) + ridxs = idxs[:] + for a in range(first_reduce, first_reduce+group_for_reduces): + ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True)) - self.uop_cache: Dict[UOp, UOp] = {} - return self.to_uop(ast) + return IndexContext(idxs, ridxs) - def to_uop(self, x:UOp) -> UOp: - if uop:=self.uop_cache.get(x, None): return uop - ret = self._to_uop(x) - self.uop_cache[x] = ret - return ret +# ***** lowering (given index) ***** - def _to_uop(self, x:UOp) -> UOp: - if x.op in BUFFER_UOPS: - idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs) - # TODO: check has_valid in UPat, not here - has_valid = valid.op is not UOps.CONST or valid.arg is not True - if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0)) - buf = x.src[0] - if x.op is UOps.LOAD: - barrier = (UOp(UOps.BARRIER, dtypes.void, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else () - return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier) - # NOTE: only store the local reduceop in the threads that are actually doing the reduce - store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \ - x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL - # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes - if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)]) - if x.src[0].op is UOps.DEFINE_GLOBAL or store_back: - for oidx, ridx in zip(self.idxs, self.ridxs): - if oidx != ridx: valid = valid * oidx.eq(0) - has_valid = valid.op is not UOps.CONST or valid.arg is not True - return UOp(UOps.STORE, dtypes.void, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ())) +def lower_reduce_axis(ctx: IndexContext, x: UOp): + # NOTE: always using ridxs is fine here + reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE) + alu_op: BinaryOps = x.arg[0] + ret = x.src[0] + if len(contract_axis:=flatten(x.arg for x in reduce_expand)): + ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis)) + ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)]) + return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret - in_uops = tuple(self.to_uop(y) for y in x.src) - if x.op is UOps.WMMA: - upcast_axes = x.arg[-2] - wmma_sz = [prod(x[1] for x in l) for l in upcast_axes] - ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=( - UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]), - UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]), - UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg) - return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2]) - if x.op is UOps.REDUCE_AXIS: - # NOTE: always using ridxs is fine here - reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE) - alu_op: BinaryOps = x.arg[0] - ret = in_uops[0] - if len(contract_axis:=flatten(x.arg for x in reduce_expand)): - ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis)) - ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)]) - return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret - return x if x.src == in_uops else UOp(x.op, x.dtype, in_uops, x.arg) +def lower_load_store(ctx: IndexContext, x: UOp): + idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs) + # TODO: check has_valid in UPat, not here + has_valid = valid.op is not UOps.CONST or valid.arg is not True + buf = x.src[0] + if x.op is UOps.LOAD: + barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else () + return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier) + # NOTE: only store the local reduceop in the threads that are actually doing the reduce + store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \ + x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL + # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes + if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs]) + if x.src[0].op is UOps.DEFINE_GLOBAL or store_back: + for oidx, ridx in zip(ctx.idxs, ctx.ridxs): + if oidx != ridx: valid = valid * oidx.eq(0) + has_valid = valid.op is not UOps.CONST or valid.arg is not True + return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ())) -def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts) +pm_lowerer = PatternMatcher([ + (UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis), + (UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]), + # rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed + (UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.SHAPETRACKER)), allow_any_len=True, name="x"), lower_load_store), +]) + +def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts)) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index ce1558123d..7fc6a9afe7 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -143,7 +143,6 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo def lt_folding(x:UOp, c:int) -> Optional[UOp]: - if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1]) return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None def fold_unrolled_divs(divs:UOp): @@ -163,16 +162,31 @@ def fold_unrolled_divs(divs:UOp): # ***** image load valid simplification ***** +def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE) + +def canonicalize_simplex(X:UOp) -> Optional[UOp]: + # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. + # returns x0 + x1 + ... in such case, or None if not + changed, ret = False, [] + for u in _get_chain(X, BinaryOps.ADD): + # assumed the const is the last src of MUL + if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0: + changed = True + u = u.src[0] + if not (is_irreducible(u) and u.vmin >= 0): return None + ret.append(u) + return functools.reduce(operator.add, ret) if changed else None + def is_increasing(f:UOp): # is f a monotonically increasing function regards its input - if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True + if f.op is UOps.CONST or is_irreducible(f): return True if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure def replace_uop(uop:UOp, old:UOp, new:UOp): # replace all `old` in `uop` to `new` - return new if uop is old else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg) + return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg) def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X <= c, returns X, True, c @@ -196,30 +210,50 @@ def simplify_valid_image_load(load:UOp, buf:UOp): expr, is_upper, c = parse_valid(stmt) bounds[expr][int(is_upper)] = c + # simplify idx given that valid is True for uop,v in bounds.items(): # some expr has lower bound > upper bound -> valid is an empty set if v[0] is not None and v[1] is not None and v[0] > v[1]: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False))) - bound = uop.const_like(uop.vmin if v[0] is None else v[0]), uop.const_like(uop.vmax if v[1] is None else v[1]) - new = UOp(UOps.DEFINE_VAR, uop.dtype, (), ("fake", bound[0], bound[1])) - newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop) - if newidx.key != idx.key: idx = newidx + if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): + # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output + newidxs: List[List[UOp]] = [[], []] + for variable in _get_chain(uop, BinaryOps.ADD): + new = UOp(UOps.DEFINE_VAR, variable.dtype, (), ("fake", 1, variable.vmax)) + newidx = replace_uop(graph_rewrite(replace_uop(idx, variable, new), constant_folder), new, variable) + newidxs[0].append(newidx.src[0]) + newidxs[1].append(newidx.src[1]) + + if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1])) + if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0])) + + else: + new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]) + newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop) + if newidx.key != idx.key: idx = newidx + + # can drop valid if idx is out of bound when valid is False drop_stmt = [] for stmt in _get_chain(valid, BinaryOps.AND): - X, is_upper, c = parse_valid(stmt) - if is_upper: - # X <= c, check if it's out of bound when X = c+1 - for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])): - if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c+1)), constant_folder).vmin >= b: - drop_stmt.append(stmt) - break - else: - # X >= c, check if it's negative when X = c-1 - for i in idx.src: - if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c-1)), constant_folder).vmax < 0: - drop_stmt.append(stmt) - break + X, is_upper_bound, c = parse_valid(stmt) + + # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i + if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \ + all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)): + testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx) + testidx = graph_rewrite(testidx, constant_folder) + if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0: + drop_stmt.append(stmt) + continue + + # if X <= c, check if it's out of bound when X = c+1 + # if X >= c, check if it's out of bound when X = c-1 + test_value = c + 1 if is_upper_bound else c - 1 + for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])): + if is_increasing(i): + rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder) + if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt) if drop_stmt or idx.key != start_idx.key: new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None @@ -312,6 +346,8 @@ constant_folder = PatternMatcher([ (UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)), # self ASSIGN is just self (UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x), + # ASSIGN to global is just self + (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x), # VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal (UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None), @@ -419,6 +455,9 @@ constant_folder = PatternMatcher([ # generic lt folding (UPat.var("x").lt(UPat.cvar("c", vec=False)), lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None), + # canonicalize a simplex with positive coefficients > 0 + # not x < 1 -> X > 0 + (UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None), # ** div ** # # div folding (UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c: @@ -466,6 +505,9 @@ constant_folder = PatternMatcher([ # ** move add consts to end (NOTE: this is still happening before constant folding) ** (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None), (UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1), + # ** move mul consts to end (NOTE: this is still happening before constant folding) ** + (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None), + (UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1), ]) # *** uop expander *** @@ -706,8 +748,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP} range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE} - queue:List[Tuple[int, UOp]] = [] - def push(u:UOp): + # assign priorities + def get_priority(u:UOp): priority = 0 # prefer ranges that depend on the least number of independent ranges if u.op is UOps.RANGE and u.arg[1]: @@ -717,7 +759,20 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # prefer uops that are loop children else: priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss]) - heapq.heappush(queue, (priority, u)) + return priority + priorities:Dict[UOp, int] = {u:get_priority(u) for u in children} + + # prevent priority inversion + @functools.lru_cache(None) + def fix_priority(u:UOp, lowest_priority): + if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}: + priorities[u] = min(priorities[u], lowest_priority) + if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here) + for x in u.src: fix_priority(x, priorities[u]) + fix_priority(sink, 0) + + queue:List[Tuple[int, UOp]] = [] + def push(u:UOp): heapq.heappush(queue, (priorities[u], u)) for u in children: if in_degree[u] == 0: push(u) @@ -726,7 +781,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: _uops: List[UOp] = [] while queue: p,x = heapq.heappop(queue) - if DEBUG >= 7: print(f"{p:5d}",x) + if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg) if x in scope_children: scope_end[x] = x if x.op is UOps.DEFINE_ACC: idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE]) @@ -745,12 +800,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]: # sanity checks (NOTE: these can cause things to be skipped in BEAM) if not skip_check: - bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.REDUCE_AXIS, UOps.SHAPETRACKER}]) try: type_verify(_uops) assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}" - assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}" - assert not any(x.dtype == dtypes.pyint for x in _uops), "can't return UOp with pyint" # TODO: this should be enabled, and the valid clause should be removed # NOTE: multiple identical stores to DEFINE_LOCAL is okay # NOTE: for PTX you have to propogate through some the calculations to determine if it is a store to DEFINE_LOCAL diff --git a/tinygrad/device.py b/tinygrad/device.py index 4f0a3d7ebe..4de4d5557d 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -108,7 +108,8 @@ class Buffer: (">" if self.options is None else f" {self.options=}>") def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview: # zero copy with as_buffer (disabled by default due to use after free) - if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf) + if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None): + return self.allocator.as_buffer(self._buf) assert not force_zero_copy, "force zero copy was passed, but copy is required" return self.copyout(memoryview(bytearray(self.nbytes))) def copyin(self, mv:memoryview): diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index ca92e6adee..3440b6c7a4 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -25,17 +25,20 @@ class DType: class ImageDType(DType): shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape base: DType + local: bool = False # images are never local def scalar(self): return self.base def vec(self, sz:int): return self.base.vec(sz) def __repr__(self): return f"dtypes.{self.name}({self.shape})" # @dataclass(frozen=True, init=False, repr=False, eq=False) class PtrDType(DType): - def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) + def __init__(self, dt:DType, local=False): + self.base, self.local = dt, local + super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) def __hash__(self): return super().__hash__() def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __ne__(self, dt): return not (self == dt) - def __repr__(self): return f"PtrDType({super().__repr__()})" + def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})" class dtypes: @staticmethod diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 16527898f5..05069d5899 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -181,7 +181,7 @@ class ExecItem: lds_est = sym_infer(self.prg.lds_estimate, var_vals) mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else "" - print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 + print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501 (str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501 f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}")) self.prg.first_run = False diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 7b03c2f9f9..92f0fa9650 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -110,7 +110,7 @@ reduceop_fusor = PatternMatcher([ # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes) - (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), + (UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) @@ -139,7 +139,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. val, var_val = val.unbind() var_vals[val] = var_val else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}" - return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val) + return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0)) # otherwise, it's a load and we add it to the inputs if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \ ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))): @@ -157,9 +157,10 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. # elementwise ops pass shapetracker in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs) - if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}: + if buf.op is MetaOps.CONTIGUOUS: assert buf in outputs, f"{buf.op} must be writable" return in_uops[0] + if buf.op is MetaOps.ASSIGN: return cache.setdefault((buf, st), UOp(UOps.ASSIGN, dtype, (in_uops[1].src[0], in_uops[0]))) if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops)) if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ddca7ef386..bba1884c1f 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -103,19 +103,23 @@ class UOps(FastEnum): VALID = auto() SPECIAL = auto() NOOP = auto() - GEP = auto() - - # math ops - CAST = auto() - BITCAST = auto() - VECTORIZE = auto() - ALU = auto() REDUCE = auto() REDUCE_AXIS = auto() + + # helper ops + GEP = auto() + VECTORIZE = auto() + CAST = auto() + BITCAST = auto() + + # loads before math + LOAD = auto() + + # math ops + ALU = auto() WMMA = auto() - # memory/assignment ops - LOAD = auto() + # assignment ops STORE = auto() ASSIGN = auto() @@ -128,7 +132,7 @@ class UOps(FastEnum): ENDRANGE = auto() ENDIF = auto() -BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST} +BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID} COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR} END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)} @@ -144,7 +148,7 @@ class UOp(MathTrait): def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None): return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg) @property - def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL} + def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR} @functools.cached_property def st(self) -> Optional[ShapeTracker]: if not self.has_st: return None @@ -170,16 +174,14 @@ class UOp(MathTrait): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg # *** uop syntactic sugar @property - def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1 - @property def st_arg(self) -> ShapeTracker: assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}" - ret = self.src[self.st_loc] + ret = self.src[0 if self.op is UOps.VALID else 1] assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}" return ret.arg def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs) def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st) - def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UOp.const(self.dtype, b) + def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -215,25 +217,22 @@ class UOp(MathTrait): if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore @staticmethod - def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): - return UOp(UOps.DEFINE_VAR, dtype, arg=(name, UOp.const(dtype, min_val), UOp.const(dtype, max_val))) + def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @staticmethod def range(dtype:DType, start:ConstType, end:ConstType, idx:int): return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,)) - def reduce(self, op, *rng): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op) + def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op) @functools.cached_property - def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}} + def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}} @property # parents with self def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None} @functools.cached_property def full_shape(self) -> Tuple[sint, ...]: - if self.op is UOps.SHAPETRACKER: return self.arg.shape - # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape - return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}])) + return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st])) def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) def variables(self) -> List[Variable]: st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] - return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr) + return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr) def const_factor(self) -> int: """largest known int that divides self""" if self.op is UOps.CONST: return self.arg @@ -259,8 +258,7 @@ class UOp(MathTrait): @functools.cached_property def _min_max(self) -> Tuple[ConstType, ConstType]: # NOTE: returned UOp is assumed to be CONST - if self.op is UOps.DEFINE_VAR and self.arg: - return self.arg[1].arg, self.arg[2].arg if self.arg[2].op is UOps.CONST else dtypes.max(self.dtype) + if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax if self.op is UOps.EXPAND: return min(x.vmin for x in self.src), max(x.vmax for x in self.src) # TODO: UOps.SPECIAL is UOps.DEFINE_VAR @@ -331,7 +329,7 @@ def exec_alu(op:Op, dtype:DType, operands): def uop_alu_resolve(u:UOp) -> sint: if u.op is UOps.CONST: return u.arg - if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg) + if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2]) if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src))) raise RuntimeError(f"ALU resolve fail @ {u.op}") @@ -380,8 +378,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: def get_location() -> Tuple[str, int]: frm = sys._getframe(1) - # find the real frame in the file that has the UPat - while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}): + # find the real frame in the file that has the UPat, TODO: is there a better way to do this? + while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @functools.lru_cache(None) @@ -406,7 +404,7 @@ class UPat(MathTrait): # repeat if it's a UPat elif isinstance(src, UPat): self.src = [itertools.repeat(src)] - self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src) + self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src) self.location = location or get_location() if custom_early_reject is not None: self.early_reject = custom_early_reject @@ -459,7 +457,7 @@ class UPat(MathTrait): (self.dtype is not None and uop.dtype not in self.dtype) or \ (self.arg is not None and self.arg != uop.arg) or \ (self.op is not None and uop.op not in self.op) or \ - (self.allowed_len != 0 and len(uop.src) != self.allowed_len): return [] + (self.allowed_len != -1 and len(uop.src) != self.allowed_len): return [] if self.src is None: return [store] res: List[Dict[str, UOp]] = [] for vp in self.src: @@ -488,11 +486,11 @@ class PatternMatcher: @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns) - def rewrite(self, uop:UOp) -> Optional[UOp]: + def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]): if not early_reject.issubset(ler): continue - if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match + if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret return None # *** tracking pattern matcher *** @@ -512,7 +510,7 @@ class TrackedPatternMatcher(PatternMatcher): for p,_ in self.patterns: if p not in match_stats: match_stats[p] = [0,0,0.0,0.0] - def rewrite(self, uop:UOp) -> Optional[UOp]: + def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]: ret = None ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))]) for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]): @@ -521,7 +519,7 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][2] += time.perf_counter()-st continue match_stats[p][1] += 1 - if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: + if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: match_stats[p][0] += 1 match_stats[p][2] += (et:=time.perf_counter()-st) match_stats[p][3] += et @@ -536,25 +534,26 @@ if TRACK_MATCH_STATS: import atexit, pickle @atexit.register def print_match_stats(): - ret = [0,0,0.0,0.0] - for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): - loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" - if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) - ret = [x+y for x,y in zip(ret, v)] - print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") if TRACK_MATCH_STATS >= 2: with open("/tmp/rewrites.pkl", "wb") as f: print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl") pickle.dump(contexts, f) if getenv("VIZ"): import viz.serve - viz.serve.main() + return viz.serve.main() + ret = [0,0,0.0,0.0] + for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]): + loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}" + if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable()) + ret = [x+y for x,y in zip(ret, v)] + print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL") # *** simple graph rewrite engine *** class RewriteContext: - def __init__(self, pm): + def __init__(self, pm, ctx): self.pm: PatternMatcher = pm + self.ctx = ctx self.nodes: Dict[Tuple, UOp] = {} self.replace: Dict[UOp, UOp] = {} def rewrite(self, n:UOp) -> UOp: @@ -563,33 +562,36 @@ class RewriteContext: if found := self.nodes.get(replace_source): self.replace[n] = found else: x = UOp(*replace_source) if new_src != n.src else n - self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x + self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x return found -def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp: +def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp: if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get())) - return RewriteContext(pm).rewrite(sink) + return RewriteContext(pm, ctx).rewrite(sink) # ***** uop type spec ***** # this is the matcher for the final rendered UOps # matcher functions returns True or False (or None to not match) spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [ - (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))), - (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)), + (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local), + (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local), (UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True), lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype), - (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], UOp) and isinstance(x.arg[2], UOp)), + (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)), (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype), (UPat(UOps.SPECIAL, src=()), lambda: True), + # no pyint allowed here! + (UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False), + # TODO: confirm the args of both of these are shapetrackers (UPat(UOps.SHAPETRACKER, src=()), lambda: True), (UPat(UOps.SWIZZLE, src=(UPat(),)), lambda: True), - (UPat(UOps.CONST, name="x"), - lambda x: x.dtype == x.dtype.scalar() and (isinstance(x.arg, Variable) and x.src) or (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), + (UPat(UOps.VALID, dtypes.bool, (UPat(UOps.SHAPETRACKER),)), lambda: True), + (UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), # early LOAD has a (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER))), lambda: True), @@ -616,13 +618,13 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False), (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)), - (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True), + (UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True), (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True), - # early WMMA has 2 args, - (UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True), - # late WMMA has 3 args, + # all WMMA has 3 args, (UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True), + (UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)), + (UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)), # if has a (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True), @@ -636,7 +638,7 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b # NOTE: for testing, we let sinks be anything #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True), - (UPat(UOps.SINK), lambda: True), + (UPat(UOps.SINK, dtypes.void), lambda: True), # PTX LOAD/STORE (UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True), diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 8879956691..312f94b43a 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -34,7 +34,7 @@ class Program: if not self._ran_post_init and self.uops is not None: # single pass through the uops for u in self.uops: - if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)) + if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2])) if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg) if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL]) if u.op is UOps.SPECIAL: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index ddf3d6643c..fe58b95d24 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -56,14 +56,14 @@ class CStyleLanguage(Renderer): return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val) # returns a str expression of the loaded value with the output type - def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str: + def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str: if isinstance(buf_dtype, ImageDType): assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}" return f"read_imagef({buf_name}, smp, {idx})" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16: return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})" if output_dtype.count > 1: - return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501 + return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501 return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" def get_kernel_modifier(self, uops:List[UOp]) -> str: return "" @@ -78,14 +78,14 @@ class CStyleLanguage(Renderer): return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" # returns a str statement that does the store - def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str: + def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str: if isinstance(buf_dtype, ImageDType): assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}" return f"write_imagef({buf_name}, {idx}, {var_name});" if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16: return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});" if var_dtype.count > 1: - prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix + prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};" return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};" @@ -124,8 +124,9 @@ class CStyleLanguage(Renderer): kk("}") elif uop is UOps.STORE: # mark DEFINE_GLOBAL buf as writable + assert isinstance(src[0].dtype, (ImageDType, PtrDType)) if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True)) - rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) + rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]])) kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store) else: if uop is UOps.RANGE: @@ -150,7 +151,7 @@ class CStyleLanguage(Renderer): bufs[u] = (args[0], (dtype,False)) r[u] = args[0] elif uop is UOps.LOAD: - val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) + val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]])) # NOTE: this relies on the load not happening if it's in the unselected branch if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype) kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};") @@ -226,6 +227,11 @@ class ClangRenderer(CStyleLanguage): class OpenCLRenderer(CStyleLanguage): device = "GPU" + code_for_op = {**CStyleLanguage().code_for_op, + #UnaryOps.SQRT: lambda x,dtype: f"native_sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"native_recip({x})", + #UnaryOps.EXP2: lambda x,dtype: f"native_exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"native_log2({x})", + UnaryOps.SIN: lambda x,dtype: f"native_sin({x})"} if getenv("NATIVE_MATH") else CStyleLanguage().code_for_op + # language options kernel_prefix = "__kernel " buffer_prefix = "__global " diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 0794da519a..f4d2fe5c47 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -1,7 +1,7 @@ from __future__ import annotations import os, ctypes, functools, mmap, struct, array, decimal, math from types import SimpleNamespace -from typing import Tuple, List, Dict, Any +from typing import Tuple, List, Any, cast from tinygrad.device import BufferOptions, HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState, hcq_command from tinygrad.runtime.autogen import kgsl, adreno, libc from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice @@ -9,6 +9,8 @@ from tinygrad.renderer.cstyle import QCOMRenderer from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, DEBUG, fromimport if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import +BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2 + def _qreg_exec(reg, __val=0, **kwargs): for k, v in kwargs.items(): __val |= (getattr(adreno, f'{reg[4:]}_{k.upper()}') if v else 0) if type(v) is bool else (v << getattr(adreno, f'{reg[4:]}_{k.upper()}__SHIFT')) @@ -123,39 +125,38 @@ class QCOMComputeQueue(HWComputeQueue): qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total)) self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT, - state_block=adreno.SB6_CS_SHADER, num_unit=prg.kernargs_alloc_size // 4), + state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4), *data64_le(args_state.ptr)) self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT, state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128), *data64_le(prg.lib_gpu.va_addr)) - self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, - qreg.a6xx_hlsq_cs_cntl(constlen=prg.kernargs_alloc_size // 4, enabled=True)) + self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, qreg.a6xx_hlsq_cs_cntl(constlen=1024 // 4, enabled=True)) self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_hw_stack_offset(prg.hw_stack_offset)) self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, qreg.a6xx_sp_cs_instrlen(prg.image_size // 4)) - if hasattr(args_state, 'samplers_ptr'): + if args_state.prg.samp_cnt > 0: self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT, - state_block=adreno.SB6_CS_TEX, num_unit=args_state.samplers_cnt), - *data64_le(args_state.samplers_ptr.va_addr)) - self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.samplers_ptr.va_addr)) + state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt), + *data64_le(args_state.ptr + args_state.prg.samp_off)) + self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off)) self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.device._border_color_base())) - if hasattr(args_state, 'descriptors_ptr'): + if args_state.prg.tex_cnt > 0: self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT, - state_block=adreno.SB6_CS_TEX, num_unit=args_state.descriptors_cnt), - *data64_le(args_state.descriptors_ptr.va_addr)) - self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.descriptors_ptr.va_addr)) + state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt), + *data64_le(args_state.ptr + args_state.prg.tex_off)) + self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off)) - if hasattr(args_state, 'ibos_ptr'): + if args_state.prg.ibo_cnt > 0: self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT, - state_block=adreno.SB6_CS_SHADER, num_unit=args_state.ibos_cnt), - *data64_le(args_state.ibos_ptr.va_addr)) - self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ibos_ptr.va_addr)) + state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt), + *data64_le(args_state.ptr + args_state.prg.ibo_off)) + self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ptr + args_state.prg.ibo_off)) self.reg(adreno.REG_A6XX_SP_CS_CONFIG, - qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.samplers_cnt, ntex=args_state.descriptors_cnt, nibo=args_state.ibos_cnt)) + qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt)) self.cmd(adreno.CP_RUN_OPENCL, 0) self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False) @@ -175,46 +176,27 @@ class QCOMComputeQueue(HWComputeQueue): class QCOMArgsState(HCQArgsState): def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): super().__init__(ptr, prg, bufs, vals=vals) - self.ibos_cnt, self.descriptors_cnt, self.samplers_cnt = 0, 0, 0 - ctypes.memset(ptr, 0, 1024) + ctypes.memset(self.ptr, 0, 1024) + + if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}') + + self.buf_info, self.args_info = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):] - if len(bufs) + len(vals) != len(prg.buffs_info): raise RuntimeError(f'incorrect args size given={len(bufs)} != want={len(prg.buffs_info)}') - self.boffs, self.aoffs = prg.buffs_info[:len(bufs)], prg.buffs_info[len(bufs):] - for i, v in enumerate(vals): self.update_var(i, v) for cnst_val, cnst_off, cnst_sz in prg.consts_info: ctypes.memmove(self.ptr + cnst_off, (ctypes.c_int8 * cnst_sz).from_buffer_copy(cnst_val.to_bytes(cnst_sz, byteorder='little')), cnst_sz) - samplers: List[Any] = [] - descriptors: List[Any] = [] - ibos: List[Any] = [] - self.i2descr: Dict[int, int] = {} - self.i2ibo: Dict[int, int] = {} - for i, b in enumerate(bufs): - if not hasattr(b, 'samplers') and not hasattr(b, 'descriptor') and not hasattr(b, 'ibo'): self.update_buffer(i, b) - elif self.boffs[i][1]: ibos, self.i2ibo = [*ibos, *getattr(b, 'ibo')], {**self.i2ibo, i: len(ibos)} - else: - samplers, descriptors = [*samplers, *getattr(b, 'samplers')], [*descriptors, *getattr(b, 'descriptor')] - self.i2descr[i] = len(descriptors) - 1 - - def alloc_tex_gpu(data, chunk_size) -> Tuple[HCQBuffer, int]: - tex_gpu = self.prg.device.allocator.alloc(len(data) * 4, BufferOptions(nolru=True, cpu_access=True)) - to_mv(tex_gpu.va_addr, len(data) * 4).cast('I')[:] = array.array('I', data) - return tex_gpu, len(data) // chunk_size - - if len(samplers): self.samplers_ptr, self.samplers_cnt = alloc_tex_gpu(samplers, 4) - if len(descriptors): self.descriptors_ptr, self.descriptors_cnt = alloc_tex_gpu(descriptors, 16) - if len(ibos): self.ibos_ptr, self.ibos_cnt = alloc_tex_gpu(ibos, 16) - - def __del__(self): - for ptr in ('samplers_ptr', 'descriptors_ptr', 'ibos_ptr'): - if hasattr(self, ptr): self.prg.device.allocator.free((x:=getattr(self, ptr)), x.size, BufferOptions(nolru=True, cpu_access=True)) + if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers) + for i, b in enumerate(cast(List[QCOMBuffer], bufs)): + if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc) + elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo) + else: self.update_buffer(i, b) + for i, v in enumerate(vals): self.update_var(i, v) def update_buffer(self, index:int, buf:HCQBuffer): - if (descr:=self.i2descr.get(index, None)) is not None: to_mv(self.descriptors_ptr.va_addr + 16 * descr + 4 * 4, 8).cast('Q')[0] = buf.va_addr - elif (ibo:=self.i2ibo.get(index, None)) is not None: to_mv(self.ibos_ptr.va_addr + 16 * ibo + 4 * 4, 8).cast('Q')[0] = buf.va_addr - else: to_mv(self.ptr + self.boffs[index][0], 8).cast('Q')[0] = buf.va_addr + if self.buf_info[index].type is not BUFTYPE_BUF: to_mv(self.ptr+self.buf_info[index].offset+0x10, 8).cast('Q')[0] = buf.va_addr + else: to_mv(self.ptr + self.buf_info[index].offset, 8).cast('Q')[0] = buf.va_addr - def update_var(self, index:int, val:int): to_mv(self.ptr + self.aoffs[index][0], 8).cast('Q')[0] = val + def update_var(self, index:int, val:int): to_mv(self.ptr + self.args_info[index].offset, 8).cast('Q')[0] = val class QCOMProgram(HCQProgram): def __init__(self, device: QCOMDevice, name: str, lib: bytes): @@ -232,7 +214,7 @@ class QCOMProgram(HCQProgram): self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128) device._ensure_stack_size(self.hw_stack_offset * 4) - super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=1024) + super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10) def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False): if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch") @@ -253,19 +235,34 @@ class QCOMProgram(HCQProgram): self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8) # Fill up constants and buffers info - self.buffs_info, self.consts_info = [], [] + self.buf_info, self.consts_info = [], [] - samplers_count = _read_lib(image_desc_off + 0xdc) - bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samplers_count - while (bdoff + 16 <= len(self.lib)): - length, _, _, offset_words = struct.unpack("I" * 4, self.lib[bdoff:bdoff+16]) + # Collect sampler info. + self.samp_cnt = _read_lib(image_desc_off + 0xdc) + assert self.samp_cnt <= 1, "Up to one sampler supported" + if self.samp_cnt: + self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode), + qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0] + + # Collect kernel arguments (buffers) info. + bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt + while bdoff + 32 <= len(self.lib): + length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32]) if length == 0: break - self.buffs_info.append((offset_words * 4, struct.unpack("I", self.lib[bdoff+0x3c:bdoff+0x40])[0] == 0x0)) + self.buf_info.append(SimpleNamespace(offset=offset_words * 4, type=typ)) bdoff += length + # Setting correct offsets to textures/ibos. + self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info) + self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt + cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off + for x in self.buf_info: + if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40 + elif x.type is BUFTYPE_TEX: x.offset, cur_tex_off = cur_tex_off, cur_tex_off + 0x40 + if _read_lib(0xb0) != 0: # check if we have constants. cdoff = _read_lib(0xac) - while (cdoff + 40 <= image_offset): + while cdoff + 40 <= image_offset: cnst, offset_words, _, is32 = struct.unpack("I", self.lib[cdoff:cdoff+4])[0], *struct.unpack("III", self.lib[cdoff+16:cdoff+28]) self.consts_info.append((cnst, offset_words * (sz_bytes:=(2 << is32)), sz_bytes)) cdoff += 40 @@ -277,6 +274,10 @@ class QCOMProgram(HCQProgram): def __del__(self): if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True)) +class QCOMBuffer(HCQBuffer): + def __init__(self, va_addr:int, size:int, desc=None, ibo=None, pitch=None, real_stride=None): + self.va_addr, self.size, self.desc, self.ibo, self.pitch, self.real_stride = va_addr, size, desc, ibo, pitch, real_stride + class QCOMAllocator(HCQAllocator): def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: if options.image is not None: @@ -286,33 +287,37 @@ class QCOMAllocator(HCQAllocator): granularity = 128 if options.image.itemsize == 4 else 256 pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0 - pitch = round_up(imgw * 4 * options.image.itemsize, 1 << pitchalign) + pitch_add + pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add - texture = self.device._gpu_alloc(pitch * round_up(imgh, 16), kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True) + texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True) # Extend HCQBuffer with texture-related info. - texture.samplers, texture.descriptor, texture.ibo = [0] * 4, [0] * 16, [0] * 16 - - # Compiled sampler (always the same in tinygrad). - texture.samplers[0] = qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode) - texture.samplers[1] = qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True) + texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16 tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT - texture.descriptor[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt) - texture.descriptor[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh) - texture.descriptor[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6) - texture.descriptor[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)] - texture.ibo = [texture.descriptor[0] & (~0xffff), *texture.descriptor[1:len(texture.descriptor)]] + texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt) + texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh) + texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6) + texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)] + texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]] return texture return self.device._gpu_alloc(size, map_to_cpu=True) - def copyin(self, dest:HCQBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes) + def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0): + while src_off < src_size: + ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size) + src_off, dest_off = src_off+src_stride, dest_off+dest_stride + + def copyin(self, dest:HCQBuffer, src:memoryview): + if hasattr(qd:=cast(QCOMBuffer, dest), 'pitch'): self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch) + else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes) def copyout(self, dest:memoryview, src:HCQBuffer): self.device.synchronize() - ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes) + if hasattr(qs:=cast(QCOMBuffer, src), 'pitch'): self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride) + else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes) def as_buffer(self, src:HCQBuffer) -> memoryview: self.device.synchronize() diff --git a/viz/index.html b/viz/index.html index 8f4f65f28f..583c5ad260 100644 --- a/viz/index.html +++ b/viz/index.html @@ -62,13 +62,11 @@ stroke-width: 1.5px; } .graph { - grid-column: span 10; + width: 70%; position: relative; } .main-container { - display: grid; - grid-template-columns: repeat(14, 1fr); - gap: 8px; + display: flex; padding: 12px; width: 100%; height: 100%; @@ -79,12 +77,16 @@ background-color: #111111; border-radius: 8px; padding: 8px; + position: relative; } .container > * + * { margin-top: 12px; } + .main-container > * + * { + margin-left: 8px; + } .kernel-list { - grid-column: span 1; + width: 10%; display: flex; flex-direction: column; overflow-y: auto; @@ -96,9 +98,18 @@ margin-top: 4px; } .metadata { - grid-column: span 3; + width: 20%; overflow-y: auto; } + #resize-handle { + position: absolute; + left: 0; + top: 0; + bottom: 0; + width: 20px; + cursor: w-resize; + background-color: transparent; + } .rewrite-list { display: flex; flex-wrap: wrap; @@ -137,12 +148,17 @@ transform: rotate(360deg); } } + .status { + color: #EC5D5E; + font-weight: bold; + }
+
@@ -150,12 +166,16 @@
diff --git a/viz/serve.py b/viz/serve.py index 55ad89f774..93c3376ada 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from __future__ import annotations from typing import Dict, List, Tuple import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib -from dataclasses import dataclass +from dataclasses import dataclass, asdict from urllib.parse import parse_qs, urlparse from http.server import HTTPServer, BaseHTTPRequestHandler from tinygrad import Device @@ -11,6 +12,37 @@ from tinygrad.engine.graph import uops_colors, word_wrap from tinygrad.engine.realize import get_runner from tinygrad.engine.schedule import full_ast_rewrite +# **** /graph - detailed UOp + rewrites + +@dataclass(frozen=True) +class UOpRet: + loc: str + graphs: List[UOp] # snapshot of the entire AST after each rewrite + diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite + extra: List[List[str]] # these become code blocks in the UI + additions: List[List[int]] + @staticmethod + def from_ctx(ctx:TrackedRewriteContext) -> UOpRet: + uops: List[UOp] = [ctx.sink] + diffs: List[Tuple[str, Tuple[str, int], List[str]]] = [] + extra: List[List[str]] = [[str(ctx.sink)]] + additions: List[List[int]] = [[]] + seen_replaces: Dict[bytes, UOp] = {} + for i, (first, rewritten, pattern) in enumerate(ctx.rewrites): + if pattern.location[0].split("/")[-1] == "ops.py": continue + # first, rewrite this UOp with the current rewrite + all the seen rewrites before this + seen_replaces[first.key] = rewritten + new_sink = replace_uop(uops[-1], {**seen_replaces}) + # sanity check + assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" + # update ret data + additions.append([id(x) for x in rewritten.sparents]) + diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) + uops.append(new_sink) + extra.append([str(new_sink)]) + return UOpRet(ctx.loc, uops, diffs, extra, additions) + def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))} + def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: assert isinstance(x, UOp) graph: Dict[int, Tuple[str, str, List[int], str, str]] = {} @@ -22,51 +54,33 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph -@dataclass(frozen=True) -class UOpRet: - loc: str - graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite - diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite - extra: List[List[str]] # these become code blocks in the UI - def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: if (found:=replaces.get(base.key)): return found new_srcs = tuple(replace_uop(x, replaces) for x in base.src) replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base return ret -def create_graph(ctx:TrackedRewriteContext) -> UOpRet: - uops: List[UOp] = [ctx.sink] - graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)] - diffs: List[Tuple[str, Tuple[str, int], List[str]]] = [] - extra: List[List[str]] = [[str(ctx.sink)]] - seen_replaces: Dict[bytes, UOp] = {} - for i, (first, rewritten, pattern) in enumerate(ctx.rewrites): - if pattern.location[0].split("/")[-1] == "ops.py": continue - # first, rewrite this UOp with the current rewrite + all the seen rewrites before this - seen_replaces[first.key] = rewritten - new_sink = replace_uop(uops[-1], {**seen_replaces}) - # sanity check - assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" - # update ret data - diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines())))) - graphs.append((new_sink, uops[-1], rewritten, first)) - uops.append(new_sink) - extra.append([str(new_sink)]) - return UOpRet(ctx.loc, graphs, diffs, extra) +# **** /kernels - Overview of the kernel -def get_ctx_groups(contexts:List[TrackedRewriteContext]) -> Dict[str, Tuple[List[TrackedRewriteContext], str]]: - ctx_groups: Dict[str, Tuple[List[TrackedRewriteContext], str]] = {} +@dataclass(frozen=True) +class KernelRet: + name: str + code: str + ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext] + def to_json(self) -> Dict: + return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]} + +def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]: + ret: Dict[str, KernelRet] = {} kernel_name = "" code = "" for ctx in contexts: if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py": with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src - elif ctx.kernel_name is not None: kernel_name = ctx.kernel_name - if ctx_groups.get(k:=to_function_name(kernel_name)) is None: ctx_groups[k] = ([], code) - # TODO: make ansi play nice with css - ctx_groups[to_function_name(kernel_name)][0].append(ctx) - return ctx_groups + elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, "" + if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {}) + ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx + return list(ret.values()) class Handler(BaseHTTPRequestHandler): def do_GET(self): @@ -87,20 +101,18 @@ class Handler(BaseHTTPRequestHandler): self.send_header("Content-type", "application/json") self.end_headers() with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) - ctx_groups = get_ctx_groups(contexts) - ret = json.dumps({k:[x.loc for x in v[0]] for k,v in ctx_groups.items()}).encode() + kernels = load_kernels(contexts) + ret = json.dumps([x.to_json() for x in kernels]).encode() elif url.path == "/graph": query = parse_qs(url.query) self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f) - ctx_groups = get_ctx_groups(contexts) - group, code = ctx_groups[list(ctx_groups.keys())[int(query["kernel_idx"][0])]] - g = create_graph(group[int(query["uop_idx"][0])]) - rest = [x.loc for x in group] - ret = json.dumps(({"loc": g.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in g.graphs], - "diffs": g.diffs, "extra": g.extra, "code": code}, rest)).encode() + kernels = load_kernels(contexts) + k = kernels[int(query["kernel_idx"][0])] + g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])]) + ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode() else: self.send_response(404) ret = b"" diff --git a/viz/test_viz.py b/viz/test_viz.py index 5b76e1c6cf..9a0f8df5dd 100644 --- a/viz/test_viz.py +++ b/viz/test_viz.py @@ -6,10 +6,10 @@ from tinygrad import Tensor from tinygrad.engine.realize import lower_schedule from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps from tinygrad.dtype import dtypes, PtrDType -from tinygrad.helpers import CI, all_same, DEBUG, colored, getenv +from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding from test.external.process_replay.helpers import print_diff -from viz.serve import create_graph, get_ctx_groups +from viz.serve import UOpRet, load_kernels class TestViz(unittest.TestCase): def tearDown(self) -> None: @@ -19,12 +19,11 @@ class TestViz(unittest.TestCase): def assert_valid_ctx(self, contexts): assert len(contexts) != 0 for i,ctx in enumerate(contexts): - try: ret = create_graph(ctx) + try: ret = UOpRet.from_ctx(ctx) except Exception as e: print(colored(f"failed to create graph for ctx {i}", "red")) raise e - rewrites = [x[0] for x in ret.graphs] - for j,(x,y) in enumerate(zip(rewrites, rewrites[1:])): + for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])): if x.key == y.key: raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}") @@ -45,10 +44,10 @@ class TestViz(unittest.TestCase): schedule2 = Tensor.randn(4, 4).contiguous().schedule() list(lower_schedule(schedule1)) list(lower_schedule(schedule2)) - ret = get_ctx_groups(contexts) - assert len(ret.keys()) == 2 - assert all(len([x for x in ctxs if "schedule" in x.loc]) != 0 for ctxs,_ in ret.values()) - assert all(len([x for x in ctxs if "uopgraph" in x.loc]) != 0 for ctxs,_ in ret.values()) + ret = load_kernels(contexts) + assert len(ret) == 2 + assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret) + assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret) def test_gemm_diff(self): x = Tensor.empty(64, 64).realize() @@ -67,8 +66,8 @@ class TestViz(unittest.TestCase): ]) ret = graph_rewrite(sink, pm) if DEBUG >= 4: print_diff(sink, ret) - g = create_graph(contexts[0]) - assert g.graphs[-1][0].key == ret.key + g = UOpRet.from_ctx(contexts[0]) + assert g.graphs[-1].key == ret.key self.assert_valid_ctx(contexts) def test_devectorize_viz(self): @@ -115,5 +114,28 @@ class TestViz(unittest.TestCase): simple_pm.rewrite(UOp.const(dtypes.int, 2)) self.assertEqual(len(contexts), 0) + def test_dedup_ast(self): + contexts.clear() + a = Tensor.randn(4, 4)+2 + b = Tensor.randn(4, 4)+2 + Tensor.schedule(a, b) + kernels = load_kernels(contexts) + self.assertEqual(len(kernels), 1) + schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"] + self.assertEqual(len(schedule_ctxs), 1) + + def test_no_dedup_different_opts(self): + contexts.clear() + a = Tensor.empty(4, 4)+Tensor.empty(4, 4) + s = a.schedule() + with Context(NOOPT=1): list(lower_schedule(s.copy())) + with Context(NOOPT=0): list(lower_schedule(s.copy())) + kernels = load_kernels(contexts) + self.assertEqual(len(kernels), 2) + schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"] + self.assertEqual(len(schedule_ctxs), 1) + schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"] + self.assertEqual(len(schedule_ctxs), 0) + if __name__ == "__main__": unittest.main()