mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
14
.github/workflows/benchmark.yml
vendored
14
.github/workflows/benchmark.yml
vendored
@@ -45,6 +45,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Run Stable Diffusion
|
||||
run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run Stable Diffusion with fp16
|
||||
@@ -148,6 +150,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Run model inference benchmark
|
||||
run: NV=1 NOCLANG=1 python3 test/external/external_model_benchmark.py
|
||||
- name: Test speed vs torch
|
||||
@@ -252,6 +256,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Fuzz Padded Tensor Core GEMM (NV)
|
||||
run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Fuzz Padded Tensor Core GEMM (PTX)
|
||||
@@ -318,6 +324,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Show off tinybox
|
||||
run: /opt/rocm/bin/rocm-bandwidth-test
|
||||
# TODO: unstable on AMD
|
||||
@@ -420,6 +428,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
- name: Run 10 CIFAR training steps
|
||||
@@ -471,6 +481,8 @@ jobs:
|
||||
run: |
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: openpilot compile 0.9.4
|
||||
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt
|
||||
- name: openpilot compile 0.9.7
|
||||
@@ -485,6 +497,8 @@ jobs:
|
||||
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_4.txt
|
||||
- name: benchmark openpilot w IMAGE=2 0.9.7
|
||||
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt
|
||||
- name: openpilot compile3 0.9.7
|
||||
run: PYTHONPATH="." QCOM=1 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx
|
||||
- name: Run process replay tests
|
||||
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
|
||||
- uses: actions/upload-artifact@v4
|
||||
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -197,7 +197,7 @@ jobs:
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model compile and size
|
||||
run: |
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
|
||||
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
|
||||
- if: ${{ matrix.task == 'optimage' }}
|
||||
name: Test openpilot model correctness (float32)
|
||||
|
||||
@@ -6,6 +6,7 @@ sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
|
||||
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
|
||||
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
|
||||
if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1"
|
||||
|
||||
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import numpy as np
|
||||
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
|
||||
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
|
||||
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
|
||||
if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1"
|
||||
|
||||
from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters
|
||||
from tinygrad.helpers import OSX, DEBUG, Timing
|
||||
|
||||
@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
|
||||
from examples.stable_diffusion import ResnetBlock, Mid
|
||||
import numpy as np
|
||||
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
|
||||
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
|
||||
import argparse, tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -282,19 +282,29 @@ class SDXL:
|
||||
return self.first_stage_model.decode(1.0 / 0.13025 * x)
|
||||
|
||||
|
||||
class VanillaCFG:
|
||||
class Guider(ABC):
|
||||
def __init__(self, scale:float):
|
||||
self.scale = scale
|
||||
|
||||
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
|
||||
@abstractmethod
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
pass
|
||||
|
||||
class VanillaCFG(Guider):
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
c_out = {}
|
||||
for k in c:
|
||||
assert k in ["vector", "crossattn", "concat"]
|
||||
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
|
||||
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
x_u, x_c = x.chunk(2)
|
||||
x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2)
|
||||
x_pred = x_u + self.scale*(x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
class SplitVanillaCFG(Guider):
|
||||
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
|
||||
x_u = denoiser(x, s, uc)
|
||||
x_c = denoiser(x, s, c)
|
||||
x_pred = x_u + self.scale*(x_c - x_u)
|
||||
return x_pred
|
||||
|
||||
@@ -302,13 +312,12 @@ class VanillaCFG:
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
|
||||
class DPMPP2MSampler:
|
||||
def __init__(self, cfg_scale:float):
|
||||
def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG):
|
||||
self.discretization = LegacyDDPMDiscretization()
|
||||
self.guider = VanillaCFG(cfg_scale)
|
||||
self.guider = guider_cls(cfg_scale)
|
||||
|
||||
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
|
||||
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
|
||||
denoised = self.guider(denoised)
|
||||
denoised = self.guider(denoiser, x, sigma, c, uc)
|
||||
|
||||
t, t_next = sigma.log().neg(), next_sigma.log().neg()
|
||||
h = t_next - t
|
||||
@@ -329,7 +338,7 @@ class DPMPP2MSampler:
|
||||
return x, denoised
|
||||
|
||||
def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
|
||||
sigmas = self.discretization(num_steps)
|
||||
sigmas = self.discretization(num_steps).to(x.device)
|
||||
x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
num_sigmas = len(sigmas)
|
||||
|
||||
|
||||
Binary file not shown.
@@ -7,7 +7,7 @@ import math
|
||||
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
|
||||
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto):
|
||||
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
|
||||
if len(inp.raw_data) > 0:
|
||||
data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
|
||||
return Tensor(data, requires_grad=False).reshape(tuple(inp.dims))
|
||||
return Tensor(data.reshape(tuple(inp.dims)), requires_grad=False)
|
||||
return Tensor(None, requires_grad=False)
|
||||
|
||||
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
|
||||
|
||||
3
test/external/external_benchmark_schedule.py
vendored
3
test/external/external_benchmark_schedule.py
vendored
@@ -1,7 +1,7 @@
|
||||
from typing import List
|
||||
from extra.models.resnet import ResNet50
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
|
||||
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen, _CURRENT_KERNEL
|
||||
from tinygrad.ops import UOps
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.codegen.lowerer import ast_to_uop
|
||||
@@ -43,6 +43,7 @@ if __name__ == "__main__":
|
||||
rewritten_uops = []
|
||||
for i,(k,u) in enumerate(zip(kernels, uops)):
|
||||
with Timing(f"rewrite {i:2d} {k.name}{' '*(50-ansilen(k.name))}", enabled=getenv("VERBOSE", 0)):
|
||||
if getenv("VIZ"): _CURRENT_KERNEL.set(k.name)
|
||||
rewritten_uops.append(full_graph_rewrite(u, k.opts))
|
||||
uops = rewritten_uops
|
||||
if getenv("LINEARIZE", 1):
|
||||
|
||||
Binary file not shown.
@@ -17,7 +17,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List
|
||||
if not os.path.isfile(fp):
|
||||
shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py", allow_caching=False), fp)
|
||||
# create the reference graph
|
||||
ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
|
||||
ref_graph, ref_in_degree, _ = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
|
||||
# compare
|
||||
diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])
|
||||
|
||||
@@ -26,7 +26,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
|
||||
for _,in_degree in s:
|
||||
for lsi in in_degree:
|
||||
for buf in lsi.outputs:
|
||||
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
|
||||
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), tuple(lsi.metadata)))
|
||||
changed = 0
|
||||
seen_diffs: Set[bytes] = set()
|
||||
for buf,si in si_for_buf.items():
|
||||
|
||||
20
test/external/process_replay/helpers.py
vendored
20
test/external/process_replay/helpers.py
vendored
@@ -1,5 +1,7 @@
|
||||
import difflib, logging
|
||||
from tinygrad.helpers import colored, getenv
|
||||
from dataclasses import dataclass
|
||||
import difflib, logging, traceback, subprocess
|
||||
from typing import Dict, Optional
|
||||
from tinygrad.helpers import ContextVar, colored, getenv
|
||||
|
||||
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
||||
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
@@ -10,3 +12,17 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
|
||||
import ocdiff
|
||||
diff = ocdiff.console_diff(str(s0), str(s1))
|
||||
logging.info(diff)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProcessReplayContext:
|
||||
ctx_vars: Dict[str, int]
|
||||
loc: str = ""
|
||||
head_sha: str = ""
|
||||
run_id: Optional[int] = None
|
||||
def get_process_replay_ctx() -> ProcessReplayContext:
|
||||
stack = filter(lambda x: "tinygrad" in x.filename and not any(n in x.filename for n in ["engine/schedule.py", "engine/realize.py", \
|
||||
"codegen/kernel.py", "unittest"]), traceback.extract_stack()[:-1])
|
||||
loc = "\n".join(traceback.format_list(stack))
|
||||
try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
|
||||
except Exception: head_sha = ""
|
||||
return ProcessReplayContext({k:v.value for k,v in ContextVar._cache.items()}, loc, head_sha, getenv("GITHUB_RUN_ID") or None)
|
||||
|
||||
24
test/external/process_replay/process_replay.py
vendored
24
test/external/process_replay/process_replay.py
vendored
@@ -12,8 +12,7 @@ from test.external.process_replay.helpers import print_diff
|
||||
PAGE_SIZE = 100
|
||||
REF = os.getenv("GITHUB_REF_NAME", "")
|
||||
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
|
||||
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
|
||||
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
|
||||
TABLE_NAME = f"process_replay_{VERSION}"
|
||||
os.environ["RUN_PROCESS_REPLAY"] = "0"
|
||||
early_stop = multiprocessing.Event()
|
||||
logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
@@ -21,6 +20,7 @@ logging.basicConfig(level=logging.INFO, format="%(message)s")
|
||||
# user config
|
||||
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
|
||||
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
|
||||
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", 1)
|
||||
if REF == "master": SKIP_PROCESS_REPLAY = True
|
||||
|
||||
# *** differs
|
||||
@@ -43,7 +43,7 @@ def diff_kernel(offset:int) -> bool:
|
||||
if early_stop.is_set(): return True
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
|
||||
cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
|
||||
changed = 0
|
||||
for row in cur.fetchall():
|
||||
# try unpickle
|
||||
@@ -54,7 +54,7 @@ def diff_kernel(offset:int) -> bool:
|
||||
continue
|
||||
# try linearize
|
||||
try:
|
||||
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
|
||||
with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
|
||||
k = Kernel(ast, opts=opts)
|
||||
for opt in applied_opts: k.apply_opt(opt)
|
||||
# NOTE: replay with the captured renderer, not the one in master
|
||||
@@ -72,6 +72,7 @@ def diff_kernel(offset:int) -> bool:
|
||||
logging.info("PROCESS REPLAY DETECTED CHANGE")
|
||||
logging.info(ast)
|
||||
logging.info(applied_opts)
|
||||
logging.info(ctx.loc)
|
||||
print_diff(good_src, compare_src)
|
||||
if ASSERT_DIFF: return True
|
||||
if changed > MAX_DIFF_PCT:
|
||||
@@ -112,9 +113,9 @@ def process_replay_schedule() -> None:
|
||||
def process_replay_kernel() -> None:
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
try: row_count = cur.execute(f"select count(*) from 'kernel_{TABLE_NAME}'").fetchone()[0]
|
||||
except sqlite3.OperationalError:
|
||||
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
|
||||
logging.warning(f"kernel_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
|
||||
return None
|
||||
conn.commit()
|
||||
cur.close()
|
||||
@@ -127,11 +128,12 @@ if __name__ == "__main__":
|
||||
logging.info("skipping process replay.")
|
||||
exit(0)
|
||||
|
||||
logging.info("***** schedule diff")
|
||||
try: process_replay_schedule()
|
||||
except Exception as e:
|
||||
if ASSERT_DIFF: raise e
|
||||
logging.error(f"schedule diff err {e}")
|
||||
if COMPARE_SCHEDULE:
|
||||
logging.info("***** schedule diff")
|
||||
try: process_replay_schedule()
|
||||
except Exception as e:
|
||||
if ASSERT_DIFF: raise e
|
||||
logging.error(f"schedule diff err {e}")
|
||||
|
||||
logging.info("***** kernel diff")
|
||||
try: process_replay_kernel()
|
||||
|
||||
6
test/external/process_replay/reset.py
vendored
6
test/external/process_replay/reset.py
vendored
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
from tinygrad.helpers import db_connection, VERSION, getenv, os
|
||||
from tinygrad.helpers import db_connection, VERSION, os
|
||||
cur = db_connection()
|
||||
cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}")
|
||||
cur.execute(f"drop table if exists schedule_diff_{VERSION}")
|
||||
cur.execute(f"drop table if exists kernel_process_replay_{VERSION}")
|
||||
cur.execute(f"drop table if exists schedule_process_replay_{VERSION}")
|
||||
if os.path.exists(fp:=__file__.replace("reset", "master_schedule")):
|
||||
os.system(f"rm -rf {fp}")
|
||||
|
||||
22
test/external/process_replay/restore.py
vendored
22
test/external/process_replay/restore.py
vendored
@@ -1,22 +0,0 @@
|
||||
# restore a specific benchmark process replay
|
||||
import pickle, os
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import db_connection, VERSION, tqdm
|
||||
|
||||
cur = db_connection()
|
||||
RUN_ID = os.environ["GITHUB_RUN_ID"]
|
||||
ATTEMPT = os.environ["GITHUB_RUN_ATTEMPT"]
|
||||
TABLE_NAME = f"process_replay_{RUN_ID}_{ATTEMPT}_{VERSION}"
|
||||
PAGE_SIZE = 100
|
||||
row_cnt = cur.execute(f"select count(*) from {TABLE_NAME}").fetchone()[0]
|
||||
for offset in tqdm(range(0, row_cnt, PAGE_SIZE)):
|
||||
rows = cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall()
|
||||
for row in rows:
|
||||
ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
|
||||
try: Device[opts.device].compiler.compile(compare_src)
|
||||
except Exception:
|
||||
print("FAILED TO COMPILE")
|
||||
print(ast)
|
||||
print(applied_opts)
|
||||
print(compare_src)
|
||||
continue
|
||||
@@ -1,4 +1,6 @@
|
||||
import unittest
|
||||
import contextlib, sqlite3
|
||||
from test.external.process_replay.helpers import ProcessReplayContext
|
||||
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
|
||||
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
@@ -8,16 +10,17 @@ from tinygrad.renderer.cstyle import ClangRenderer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def helper_append_replay(ast:UOp, name:str, src:str) -> int:
|
||||
diskcache_put(TABLE_NAME.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {}))
|
||||
name = f"kernel_{TABLE_NAME}"
|
||||
diskcache_put(name.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({})))
|
||||
conn = db_connection()
|
||||
row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
|
||||
row_count = conn.execute(f"select count(*) from '{name}'").fetchone()[0]
|
||||
return row_count
|
||||
|
||||
class TestProcessReplay(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'")
|
||||
with contextlib.suppress(sqlite3.OperationalError): cur.execute(f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'")
|
||||
conn.commit()
|
||||
cur.close()
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
|
||||
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
|
||||
if st_src is None:
|
||||
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
|
||||
return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype))
|
||||
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
|
||||
|
||||
T = TypeVar("T")
|
||||
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
|
||||
|
||||
@@ -876,7 +876,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0]
|
||||
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
|
||||
# LOAD -> RANGE -> LOAD -> ASSIGN
|
||||
assert lin.uops[ranges[0]-2].op is UOps.LOAD
|
||||
assert len([x for x in lin.uops[:ranges[0]] if x.op is UOps.LOAD]) == 1
|
||||
|
||||
def test_range_outer_op_before_phi_nested_range(self):
|
||||
a = Tensor.randn(2, ).realize()
|
||||
@@ -1715,7 +1715,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
k = Kernel(si.ast)
|
||||
k.hand_coded_optimizations()
|
||||
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
|
||||
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
|
||||
if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end)
|
||||
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
|
||||
assert len(upcasts) == 3 # 3 transformation matrices
|
||||
assert len(wino_schedule) <= 4 # 4 kernels
|
||||
|
||||
@@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert ast.op is UOps.STORE
|
||||
assert ast.src[2].arg is BinaryOps.ADD
|
||||
assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
|
||||
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1
|
||||
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
|
||||
t = 2 * t
|
||||
for si in t.schedule():
|
||||
ast = si.ast.src[0]
|
||||
assert ast.op is UOps.STORE
|
||||
assert ast.src[2].arg is BinaryOps.MUL
|
||||
assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2
|
||||
assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
|
||||
assert ast.src[2].src[1].op is UOps.LOAD
|
||||
t = t + t.full_like(3)
|
||||
for si in t.schedule():
|
||||
@@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
assert ast.op is UOps.STORE
|
||||
assert ast.src[2].arg is BinaryOps.ADD
|
||||
assert ast.src[2].src[0].op is UOps.LOAD
|
||||
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3
|
||||
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
|
||||
|
||||
def test_shard_memory(self):
|
||||
devices = (d0, d1, d2, d3)
|
||||
|
||||
@@ -889,6 +889,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
def test_sum_simple(self):
|
||||
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
|
||||
# NOTE: simple test for locals
|
||||
# FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full
|
||||
def test_sum_full(self):
|
||||
helper_test_op([(16384)], lambda x: x.sum())
|
||||
def test_sum_relu(self):
|
||||
|
||||
@@ -2,7 +2,12 @@ import unittest, pickle
|
||||
import numpy as np
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import Tensor, TinyJit, Variable
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.dtype import PtrDType, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
def test_pickle_realized_tensor(self):
|
||||
@@ -66,6 +71,32 @@ class TestPickle(unittest.TestCase):
|
||||
sched_pk = pickle.loads(pk)
|
||||
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
|
||||
|
||||
def test_pickle_define_var(self):
|
||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
||||
x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
|
||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
|
||||
x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
|
||||
x2,)),
|
||||
UOp.define_var("i", dtypes.int, 1, 10),
|
||||
x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),
|
||||
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
|
||||
x12,
|
||||
UOp(UOps.CONST, dtypes.int, arg=3, src=()),
|
||||
x14,)),)),)),)),)),)),))
|
||||
p = Kernel(ast).to_program(name_override="test")
|
||||
ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test")
|
||||
self.assertEqual(ps.src, p.src)
|
||||
|
||||
class TestPickleJIT(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
@@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase):
|
||||
N = 128
|
||||
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
|
||||
assert x.dtype == dtypes.bfloat16
|
||||
# TODO: fix this property for bfloat16 random
|
||||
# x = x.numpy()
|
||||
# ones = np.take(x, np.where(x == 1))
|
||||
# zeros = np.take(x, np.where(x == 0))
|
||||
# self.assertTrue(ones.size == 0)
|
||||
# self.assertTrue(zeros.size > 0)
|
||||
if THREEFRY.value:
|
||||
nx = x.numpy()
|
||||
assert nx[nx == 1].size == 0
|
||||
assert nx[nx == 0].size > 0
|
||||
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
|
||||
|
||||
def test_randn(self):
|
||||
|
||||
@@ -162,10 +162,10 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
a = UOp.define_var('a', dtypes.int, 0, 1)
|
||||
b = UOp.define_var('b', dtypes.int, 0, 1)
|
||||
c = UOp.define_var('c', dtypes.int, 0, 1)
|
||||
d = UOp.define_var('d', dtypes.int, 0, 1)
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
@@ -186,7 +186,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
def test_where_same_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
v = UOp.define_var('tmp', dtypes.int, 0, 1)
|
||||
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
@@ -280,7 +280,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
@@ -289,7 +289,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
@@ -356,7 +356,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
v = UOp.define_var("tmp", dtypes.int, 0, 1)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
@@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
|
||||
def test_fold_gated_load_local(self):
|
||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
|
||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1))
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||
@@ -610,8 +610,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate = UOp.define_var("g1", dtypes.bool, False, True)
|
||||
gate2 = UOp.define_var("g2", dtypes.bool, False, True)
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
@@ -626,7 +626,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate = UOp.define_var("g1", dtypes.bool, False, True)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
@@ -637,8 +637,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate = UOp.define_var("g1", dtypes.bool, False, True)
|
||||
gate2 = UOp.define_var("g2", dtypes.bool, False, True)
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
@@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander +
|
||||
class TestIFUOps(unittest.TestCase):
|
||||
def test_create_ifs(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||
gate = valid&(lidx.ne(2))
|
||||
@@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase):
|
||||
|
||||
def test_expand_ifs_one_gate(self):
|
||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
|
||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16))
|
||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||
gate = valid&(lidx.ne(2))
|
||||
|
||||
@@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_basic(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16))
|
||||
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
|
||||
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
|
||||
@@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||
def test_local_indirect(self):
|
||||
uops = []
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
|
||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16))
|
||||
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
||||
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
|
||||
|
||||
@@ -19,7 +19,9 @@ def render(image_shape, valid:UOp, idx:UOp) -> str:
|
||||
return fxn.split("float4 val0 = ")[1].split(";")[0]
|
||||
|
||||
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
|
||||
def Variable(expr, nmin, nmax): return UOp(UOps.DEFINE_VAR, dtypes.int, (), (expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
|
||||
def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
|
||||
def Range(n, nmax):
|
||||
return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
|
||||
|
||||
class TestHelpers(unittest.TestCase):
|
||||
def test_is_increasing(self):
|
||||
@@ -97,46 +99,69 @@ class TestValidSimplification(unittest.TestCase):
|
||||
# empty
|
||||
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx))
|
||||
|
||||
@unittest.expectedFailure # TODO: FIXME
|
||||
def test_openpilot_conv1(self):
|
||||
# first conv in openpilot
|
||||
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
|
||||
idx1 = Special("idx1", 32)
|
||||
idx2 = Special("idx2", 64)
|
||||
ridx0 = Variable("ridx0", 0, 5)
|
||||
ridx1 = Variable("ridx1", 0, 2)
|
||||
ridx2 = Variable("ridx2", 0, 2)
|
||||
# ridx0 = Variable("ridx0", 0, 5)
|
||||
# ridx1 = Variable("ridx1", 0, 2)
|
||||
# ridx2 = Variable("ridx2", 0, 2)
|
||||
ridx0 = Range(0, 6)
|
||||
ridx1 = Range(1, 3)
|
||||
ridx2 = Range(2, 3)
|
||||
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
|
||||
|
||||
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
|
||||
valid = (((idx2*2)+(ridx1)).lt(1).ne(True))&(((idx1*8)+(ridx2)).lt(1).ne(True))
|
||||
shape = (128, 1536, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
# (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp,
|
||||
# (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
|
||||
"read_imagef(data0, smp, (int2)(((idx1*48)+(ridx2*6)+ridx0+(-6)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
@unittest.expectedFailure # TODO: FIXME
|
||||
def test_openpilot_conv2(self):
|
||||
# conv in test/external/external_test_valid_remove.py
|
||||
idx1 = Special("idx1", 32)
|
||||
idx2 = Special("idx2", 64)
|
||||
ridx0 = Variable("ridx0", 0, 2)
|
||||
ridx1 = Variable("ridx1", 0, 2)
|
||||
ridx2 = Variable("ridx2", 0, 2)
|
||||
# ridx0 = Variable("ridx0", 0, 2)
|
||||
# ridx1 = Variable("ridx1", 0, 2)
|
||||
# ridx2 = Variable("ridx2", 0, 2)
|
||||
ridx0 = Range(0, 3)
|
||||
ridx1 = Range(1, 3)
|
||||
ridx2 = Range(2, 3)
|
||||
|
||||
alu1 = ((idx2*2)+ridx1)
|
||||
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
|
||||
|
||||
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
|
||||
valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True))
|
||||
shape = (128, 768, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))")
|
||||
"read_imagef(data0, smp, (int2)(((idx1*24)+(ridx2*3)+ridx0+(-3)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
def test_openpilot_conv3(self):
|
||||
# in openpilot 0.9.7
|
||||
idx0 = Special("idx0", 64)
|
||||
idx1 = Special("idx1", 2)
|
||||
idx2 = Special("idx2", 4)
|
||||
ridx0 = Range(0, 7)
|
||||
ridx1 = Range(1, 7)
|
||||
|
||||
alu2 = ((idx2*2)+ridx0)
|
||||
alu4 = ((idx1*8)+ridx1)
|
||||
alu6 = ((idx1*512)+(ridx1*64)+idx0)
|
||||
|
||||
valid = alu2.lt(11)&(alu4.lt(3).ne(True))
|
||||
shape = (8, 1024, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)/8)+1)/2)+(-4))))
|
||||
|
||||
# TODO: simplify idx
|
||||
# alu0 = ((idx2*2)+ridx0)
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
"(((alu0<11)&((((idx1*8)+ridx1)<3)!=1))?read_imagef(data0, smp, (int2)((((idx1*512)+(ridx1*64)+idx0+832)%1024),(alu0+(-4)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
@@ -188,5 +213,20 @@ class TestValidSimplification(unittest.TestCase):
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
def test_simplify5(self):
|
||||
# openpilot 0.9.7, chunk replacement to simplify
|
||||
shape = (10, 384, 4)
|
||||
idx0 = Special("idx0", 16)
|
||||
idx1 = Special("idx1", 24)
|
||||
alu0 = idx0*4
|
||||
alu1 = (idx1*256)+alu0
|
||||
alu2 = idx1//3
|
||||
alu3 = ((alu1+1)%768)
|
||||
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
|
||||
valid = alu3.lt(640)
|
||||
|
||||
self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
"((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -11,6 +11,20 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
self.assertEqual(matcher.rewrite(c1), c1)
|
||||
self.assertEqual(matcher.rewrite(c2), None)
|
||||
|
||||
def test_match_sz_0(self):
|
||||
match_cnt = 0
|
||||
def fxn(x):
|
||||
nonlocal match_cnt
|
||||
match_cnt += 1
|
||||
assert len(x.src) == 0
|
||||
return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
# second rewrite shouldn't match anything
|
||||
c1 = matcher.rewrite(c1)
|
||||
c1 = matcher.rewrite(c1)
|
||||
self.assertEqual(match_cnt, 1)
|
||||
|
||||
def test_uop(self):
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
|
||||
@@ -118,7 +118,9 @@ class TestRealDoesntSimplify(unittest.TestCase):
|
||||
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1))
|
||||
|
||||
class TestRealStrides(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_1(self):
|
||||
# TODO: find the correct rewrite rule to fix this
|
||||
self.st = ShapeTracker((
|
||||
View.create((2048,), (1,), 0, ((0, 512),)),
|
||||
View.create((16, 32, 4), (128, 4, 1), 0, None)))
|
||||
@@ -489,6 +491,14 @@ class TestComplexShapeTracker(unittest.TestCase):
|
||||
print(self.st.views)
|
||||
assert self.st.contiguous
|
||||
|
||||
class TestShapeTrackerEquality(unittest.TestCase):
|
||||
def test_simple_equals(self):
|
||||
self.assertEqual(ShapeTracker.from_shape((10,10)), ShapeTracker.from_shape((10,10)))
|
||||
def test_other_equals(self):
|
||||
st1 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True)))
|
||||
st2 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True)))
|
||||
self.assertEqual(st1, st2)
|
||||
|
||||
class TestSingleShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.st = CheckingShapeTracker((7,4))
|
||||
|
||||
@@ -27,9 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
def Variable(expr, nmin, nmax):
|
||||
vmin = UOp.const(dtypes.int, nmin)
|
||||
vmax = UOp.const(dtypes.int, nmax) if isinstance(nmax, int) else nmax
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, vmin, vmax))
|
||||
return UOp.define_var(expr, dtypes.int, nmin, nmax if isinstance(nmax, int) else nmax.arg)
|
||||
class Node:
|
||||
@staticmethod
|
||||
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
@@ -450,6 +448,19 @@ class TestSymbolic(unittest.TestCase):
|
||||
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
|
||||
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))")
|
||||
|
||||
def test_simplex_lt(self):
|
||||
a = Variable("a", 0, 3)
|
||||
b = Variable("b", 0, 3)
|
||||
c = Variable("c", 0, 3)
|
||||
d = Variable("d", -3, 3)
|
||||
self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)")
|
||||
self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
|
||||
self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
|
||||
self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*(-3))+(b*4))<1)!=1)") # negative coeff, should not be simplified
|
||||
self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified
|
||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.shape.view import View
|
||||
|
||||
class InvalidASTException(Exception): pass
|
||||
def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||
sink = UOp(UOps.SINK, None, stores)
|
||||
sink = UOp(UOps.SINK, dtypes.void, stores)
|
||||
if DEBUG >= 3:
|
||||
for op in stores: print(op)
|
||||
try: verify_ast(sink)
|
||||
@@ -50,7 +50,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
||||
st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||
|
||||
def test_shrink_ok(self):
|
||||
@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
uop_sts = verify_ast(a.schedule()[-1].ast)
|
||||
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
|
||||
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
|
||||
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
|
||||
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -4,12 +4,13 @@ from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
|
||||
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify
|
||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, \
|
||||
graph_rewrite, PatternMatcher
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import ImageDType, PtrDType
|
||||
from tinygrad.helpers import _CURRENT_KERNEL, all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, AMX, round_up, all_int, \
|
||||
get_contraction, to_function_name, diskcache_put, ContextVar
|
||||
get_contraction, to_function_name, diskcache_put
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
@@ -640,7 +641,7 @@ class Kernel:
|
||||
# for locals, we use the ShapeTracker that's in the srcs
|
||||
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
|
||||
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
|
||||
if op.op is UOps.CONST: return op.replace(src=(st_uop,))
|
||||
if op.op is UOps.VALID: return op.replace(src=(st_uop,))
|
||||
if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
|
||||
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
|
||||
if op.op is UOps.REDUCE_AXIS:
|
||||
@@ -702,7 +703,7 @@ class Kernel:
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
||||
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
||||
else:
|
||||
@@ -711,7 +712,14 @@ class Kernel:
|
||||
# MUL/SUM instead of WMMA
|
||||
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
|
||||
else:
|
||||
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
|
||||
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
|
||||
wmma_upcast_axes = wmma_arg[-2]
|
||||
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
|
||||
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
|
||||
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
|
||||
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
|
||||
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
|
||||
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
|
||||
if self.group_for_reduces:
|
||||
@@ -725,7 +733,7 @@ class Kernel:
|
||||
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
||||
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
|
||||
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
|
||||
if op is self.reduceops[-1]: return grouped_reduce
|
||||
@@ -735,7 +743,8 @@ class Kernel:
|
||||
elif op.op is UOps.SINK:
|
||||
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
|
||||
return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
|
||||
return fixup_ast(self.ast)
|
||||
# NOTE: rewrite with an empty PatternMatcher to dedup UOps
|
||||
return graph_rewrite(fixup_ast(self.ast), PatternMatcher([]))
|
||||
|
||||
# **** this is the lowerer ****
|
||||
|
||||
@@ -763,8 +772,8 @@ class Kernel:
|
||||
src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
|
||||
|
||||
if getenv("RUN_PROCESS_REPLAY"):
|
||||
table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
|
||||
diskcache_put(table_name, str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
|
||||
from test.external.process_replay.helpers import get_process_replay_ctx
|
||||
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx()))
|
||||
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
@@ -778,30 +787,28 @@ class Kernel:
|
||||
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
|
||||
if not uop.has_st or uop in sts: return
|
||||
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
|
||||
# restore globals from the two stage reduce
|
||||
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
|
||||
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
|
||||
sts[uop] = sts[local_reduce]
|
||||
return
|
||||
for x in src: _assert_valid_uop(x, st, sts)
|
||||
for x in uop.src: _assert_valid_uop(x, st, sts)
|
||||
# only reduceuop is allowed to change shape, limited to turning n to 1
|
||||
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
|
||||
elif op is UOps.SWIZZLE: st = arg
|
||||
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
|
||||
# movementops are pushed to SHAPETRACKER and SWIZZLE
|
||||
elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
|
||||
# everything else inherits shape
|
||||
else:
|
||||
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
# movementops are pushed to the edges with SHAPETRACKER
|
||||
# elementwise inherits shape
|
||||
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
|
||||
for x in src:
|
||||
if x.has_st and sts[x].shape != st.shape:
|
||||
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
|
||||
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
|
||||
assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, UOps.ASSIGN, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
|
||||
if not all_same(shapes:=[x.shape for x in src_sts]):
|
||||
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
|
||||
raise AssertionError(f"found implicit expand {sizes}")
|
||||
sts[uop] = st
|
||||
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
|
||||
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
|
||||
assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
|
||||
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
|
||||
sts: Dict[UOp, ShapeTracker] = {}
|
||||
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
# the job of the lowerer is to do indexing
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from typing import List, Tuple, cast, Optional, Dict
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, cast, Optional
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
|
||||
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
|
||||
|
||||
# ***** indexing *****
|
||||
|
||||
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
|
||||
# TODO: symbolic shape
|
||||
if not all_int(dims): return dims
|
||||
@@ -34,99 +38,91 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
|
||||
idx //= dims[c]
|
||||
return ret[::-1] if reverse else ret
|
||||
|
||||
class IndependentLowerer:
|
||||
def lower(self, ast:UOp, opts:Renderer) -> UOp:
|
||||
self.output_count = len(ast.src)
|
||||
@dataclass(frozen=True)
|
||||
class IndexContext:
|
||||
idxs: List[UOp]
|
||||
ridxs: List[UOp]
|
||||
|
||||
ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
|
||||
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
|
||||
full_shape = ast.full_shape
|
||||
first_upcasted = len(full_shape)-ki.upcasted
|
||||
first_output_st: ShapeTracker = ast.src[0].st_arg
|
||||
# if there's no reduce, this is first_upcasted
|
||||
first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
|
||||
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
|
||||
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
|
||||
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
|
||||
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
|
||||
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
def get_index(ast:UOp, opts:Renderer) -> IndexContext:
|
||||
ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
|
||||
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
|
||||
full_shape = ast.full_shape
|
||||
first_upcasted = len(full_shape)-ki.upcasted
|
||||
first_output_st: ShapeTracker = ast.src[0].st_arg
|
||||
# if there's no reduce, this is first_upcasted
|
||||
first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
|
||||
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
|
||||
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
|
||||
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
|
||||
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
|
||||
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
|
||||
global_dims = first_reduce-ki.local_dims
|
||||
|
||||
if opts.has_local:
|
||||
if ki.dont_use_locals:
|
||||
assert ki.local_dims == 0, "can't use locals if there's no local dims"
|
||||
self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
|
||||
else:
|
||||
# define indexes for GPU-like execution
|
||||
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
if opts.has_local:
|
||||
if ki.dont_use_locals:
|
||||
assert ki.local_dims == 0, "can't use locals if there's no local dims"
|
||||
idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
|
||||
else:
|
||||
# all loops are RANGES
|
||||
self.idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
|
||||
for i,g in enumerate(full_shape[:first_reduce])]
|
||||
# define indexes for GPU-like execution
|
||||
idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
|
||||
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
|
||||
else:
|
||||
# all loops are RANGES
|
||||
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
|
||||
for i,g in enumerate(full_shape[:first_reduce])]
|
||||
|
||||
# reduce loops
|
||||
self.idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
|
||||
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
|
||||
# reduce loops
|
||||
idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
|
||||
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
|
||||
|
||||
# upcast loops
|
||||
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
|
||||
assert isinstance(g, int), "needs to be int to upcast/unroll"
|
||||
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
|
||||
# upcast loops
|
||||
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
|
||||
assert isinstance(g, int), "needs to be int to upcast/unroll"
|
||||
idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
|
||||
|
||||
# late indexes (group for reduce)
|
||||
self.ridxs = self.idxs[:]
|
||||
for a in range(first_reduce, first_reduce+group_for_reduces):
|
||||
self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
|
||||
# late indexes (group for reduce)
|
||||
ridxs = idxs[:]
|
||||
for a in range(first_reduce, first_reduce+group_for_reduces):
|
||||
ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
|
||||
|
||||
self.uop_cache: Dict[UOp, UOp] = {}
|
||||
return self.to_uop(ast)
|
||||
return IndexContext(idxs, ridxs)
|
||||
|
||||
def to_uop(self, x:UOp) -> UOp:
|
||||
if uop:=self.uop_cache.get(x, None): return uop
|
||||
ret = self._to_uop(x)
|
||||
self.uop_cache[x] = ret
|
||||
return ret
|
||||
# ***** lowering (given index) *****
|
||||
|
||||
def _to_uop(self, x:UOp) -> UOp:
|
||||
if x.op in BUFFER_UOPS:
|
||||
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
|
||||
buf = x.src[0]
|
||||
if x.op is UOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, dtypes.void, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
|
||||
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
|
||||
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
|
||||
for oidx, ridx in zip(self.idxs, self.ridxs):
|
||||
if oidx != ridx: valid = valid * oidx.eq(0)
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
return UOp(UOps.STORE, dtypes.void, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
|
||||
def lower_reduce_axis(ctx: IndexContext, x: UOp):
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
|
||||
alu_op: BinaryOps = x.arg[0]
|
||||
ret = x.src[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
|
||||
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
|
||||
|
||||
in_uops = tuple(self.to_uop(y) for y in x.src)
|
||||
if x.op is UOps.WMMA:
|
||||
upcast_axes = x.arg[-2]
|
||||
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
|
||||
ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
|
||||
UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
|
||||
UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
|
||||
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
|
||||
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
|
||||
if x.op is UOps.REDUCE_AXIS:
|
||||
# NOTE: always using ridxs is fine here
|
||||
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
|
||||
alu_op: BinaryOps = x.arg[0]
|
||||
ret = in_uops[0]
|
||||
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
|
||||
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
|
||||
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
|
||||
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
|
||||
return x if x.src == in_uops else UOp(x.op, x.dtype, in_uops, x.arg)
|
||||
def lower_load_store(ctx: IndexContext, x: UOp):
|
||||
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
buf = x.src[0]
|
||||
if x.op is UOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
|
||||
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
|
||||
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
|
||||
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
|
||||
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
|
||||
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
|
||||
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
|
||||
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
|
||||
if oidx != ridx: valid = valid * oidx.eq(0)
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
|
||||
|
||||
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
|
||||
pm_lowerer = PatternMatcher([
|
||||
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
|
||||
(UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
|
||||
# rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.SHAPETRACKER)), allow_any_len=True, name="x"), lower_load_store),
|
||||
])
|
||||
|
||||
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
|
||||
|
||||
@@ -143,7 +143,6 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
|
||||
|
||||
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
|
||||
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
|
||||
|
||||
def fold_unrolled_divs(divs:UOp):
|
||||
@@ -163,16 +162,31 @@ def fold_unrolled_divs(divs:UOp):
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
|
||||
|
||||
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
|
||||
# (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
|
||||
# returns x0 + x1 + ... in such case, or None if not
|
||||
changed, ret = False, []
|
||||
for u in _get_chain(X, BinaryOps.ADD):
|
||||
# assumed the const is the last src of MUL
|
||||
if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
|
||||
changed = True
|
||||
u = u.src[0]
|
||||
if not (is_irreducible(u) and u.vmin >= 0): return None
|
||||
ret.append(u)
|
||||
return functools.reduce(operator.add, ret) if changed else None
|
||||
|
||||
def is_increasing(f:UOp):
|
||||
# is f a monotonically increasing function regards its input
|
||||
if f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]: return True
|
||||
if f.op is UOps.CONST or is_irreducible(f): return True
|
||||
if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
|
||||
if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
|
||||
return False # False if not sure
|
||||
|
||||
def replace_uop(uop:UOp, old:UOp, new:UOp):
|
||||
# replace all `old` in `uop` to `new`
|
||||
return new if uop is old else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg)
|
||||
return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg)
|
||||
|
||||
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
|
||||
# if it's X <= c, returns X, True, c
|
||||
@@ -196,30 +210,50 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
expr, is_upper, c = parse_valid(stmt)
|
||||
bounds[expr][int(is_upper)] = c
|
||||
|
||||
# simplify idx given that valid is True
|
||||
for uop,v in bounds.items():
|
||||
# some expr has lower bound > upper bound -> valid is an empty set
|
||||
if v[0] is not None and v[1] is not None and v[0] > v[1]:
|
||||
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
|
||||
bound = uop.const_like(uop.vmin if v[0] is None else v[0]), uop.const_like(uop.vmax if v[1] is None else v[1])
|
||||
new = UOp(UOps.DEFINE_VAR, uop.dtype, (), ("fake", bound[0], bound[1]))
|
||||
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
|
||||
if newidx.key != idx.key: idx = newidx
|
||||
|
||||
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
|
||||
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
|
||||
newidxs: List[List[UOp]] = [[], []]
|
||||
for variable in _get_chain(uop, BinaryOps.ADD):
|
||||
new = UOp(UOps.DEFINE_VAR, variable.dtype, (), ("fake", 1, variable.vmax))
|
||||
newidx = replace_uop(graph_rewrite(replace_uop(idx, variable, new), constant_folder), new, variable)
|
||||
newidxs[0].append(newidx.src[0])
|
||||
newidxs[1].append(newidx.src[1])
|
||||
|
||||
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
|
||||
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
|
||||
|
||||
else:
|
||||
new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1])
|
||||
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
|
||||
if newidx.key != idx.key: idx = newidx
|
||||
|
||||
# can drop valid if idx is out of bound when valid is False
|
||||
drop_stmt = []
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
X, is_upper, c = parse_valid(stmt)
|
||||
if is_upper:
|
||||
# X <= c, check if it's out of bound when X = c+1
|
||||
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
|
||||
if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c+1)), constant_folder).vmin >= b:
|
||||
drop_stmt.append(stmt)
|
||||
break
|
||||
else:
|
||||
# X >= c, check if it's negative when X = c-1
|
||||
for i in idx.src:
|
||||
if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c-1)), constant_folder).vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
break
|
||||
X, is_upper_bound, c = parse_valid(stmt)
|
||||
|
||||
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
|
||||
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
|
||||
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
|
||||
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
|
||||
testidx = graph_rewrite(testidx, constant_folder)
|
||||
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
|
||||
drop_stmt.append(stmt)
|
||||
continue
|
||||
|
||||
# if X <= c, check if it's out of bound when X = c+1
|
||||
# if X >= c, check if it's out of bound when X = c-1
|
||||
test_value = c + 1 if is_upper_bound else c - 1
|
||||
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
|
||||
if is_increasing(i):
|
||||
rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder)
|
||||
if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt)
|
||||
|
||||
if drop_stmt or idx.key != start_idx.key:
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
|
||||
@@ -312,6 +346,8 @@ constant_folder = PatternMatcher([
|
||||
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
|
||||
# self ASSIGN is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# ASSIGN to global is just self
|
||||
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
|
||||
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
|
||||
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
|
||||
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
|
||||
@@ -419,6 +455,9 @@ constant_folder = PatternMatcher([
|
||||
# generic lt folding
|
||||
(UPat.var("x").lt(UPat.cvar("c", vec=False)),
|
||||
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
|
||||
# canonicalize a simplex with positive coefficients > 0
|
||||
# not x < 1 -> X > 0
|
||||
(UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None),
|
||||
# ** div **
|
||||
# # div folding
|
||||
(UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c:
|
||||
@@ -466,6 +505,9 @@ constant_folder = PatternMatcher([
|
||||
# ** move add consts to end (NOTE: this is still happening before constant folding) **
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
|
||||
# ** move mul consts to end (NOTE: this is still happening before constant folding) **
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
|
||||
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
|
||||
])
|
||||
|
||||
# *** uop expander ***
|
||||
@@ -706,8 +748,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
|
||||
range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE}
|
||||
|
||||
queue:List[Tuple[int, UOp]] = []
|
||||
def push(u:UOp):
|
||||
# assign priorities
|
||||
def get_priority(u:UOp):
|
||||
priority = 0
|
||||
# prefer ranges that depend on the least number of independent ranges
|
||||
if u.op is UOps.RANGE and u.arg[1]:
|
||||
@@ -717,7 +759,20 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
# prefer uops that are loop children
|
||||
else:
|
||||
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
|
||||
heapq.heappush(queue, (priority, u))
|
||||
return priority
|
||||
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
|
||||
|
||||
# prevent priority inversion
|
||||
@functools.lru_cache(None)
|
||||
def fix_priority(u:UOp, lowest_priority):
|
||||
if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}:
|
||||
priorities[u] = min(priorities[u], lowest_priority)
|
||||
if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here)
|
||||
for x in u.src: fix_priority(x, priorities[u])
|
||||
fix_priority(sink, 0)
|
||||
|
||||
queue:List[Tuple[int, UOp]] = []
|
||||
def push(u:UOp): heapq.heappush(queue, (priorities[u], u))
|
||||
|
||||
for u in children:
|
||||
if in_degree[u] == 0: push(u)
|
||||
@@ -726,7 +781,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
_uops: List[UOp] = []
|
||||
while queue:
|
||||
p,x = heapq.heappop(queue)
|
||||
if DEBUG >= 7: print(f"{p:5d}",x)
|
||||
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
|
||||
if x in scope_children: scope_end[x] = x
|
||||
if x.op is UOps.DEFINE_ACC:
|
||||
idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
|
||||
@@ -745,12 +800,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
|
||||
|
||||
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
|
||||
if not skip_check:
|
||||
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.REDUCE_AXIS, UOps.SHAPETRACKER}])
|
||||
try:
|
||||
type_verify(_uops)
|
||||
assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"
|
||||
assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
|
||||
assert not any(x.dtype == dtypes.pyint for x in _uops), "can't return UOp with pyint"
|
||||
# TODO: this should be enabled, and the valid clause should be removed
|
||||
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
|
||||
# NOTE: for PTX you have to propogate through some the calculations to determine if it is a store to DEFINE_LOCAL
|
||||
|
||||
@@ -108,7 +108,8 @@ class Buffer:
|
||||
(">" if self.options is None else f" {self.options=}>")
|
||||
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
|
||||
# zero copy with as_buffer (disabled by default due to use after free)
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
|
||||
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
|
||||
return self.allocator.as_buffer(self._buf)
|
||||
assert not force_zero_copy, "force zero copy was passed, but copy is required"
|
||||
return self.copyout(memoryview(bytearray(self.nbytes)))
|
||||
def copyin(self, mv:memoryview):
|
||||
|
||||
@@ -25,17 +25,20 @@ class DType:
|
||||
class ImageDType(DType):
|
||||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
|
||||
base: DType
|
||||
local: bool = False # images are never local
|
||||
def scalar(self): return self.base
|
||||
def vec(self, sz:int): return self.base.vec(sz)
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
||||
class PtrDType(DType):
|
||||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __init__(self, dt:DType, local=False):
|
||||
self.base, self.local = dt, local
|
||||
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
|
||||
def __hash__(self): return super().__hash__()
|
||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||
def __ne__(self, dt): return not (self == dt)
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()})"
|
||||
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod
|
||||
|
||||
@@ -181,7 +181,7 @@ class ExecItem:
|
||||
lds_est = sym_infer(self.prg.lds_estimate, var_vals)
|
||||
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
|
||||
ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
|
||||
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
|
||||
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
|
||||
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
|
||||
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
|
||||
self.prg.first_run = False
|
||||
|
||||
@@ -110,7 +110,7 @@ reduceop_fusor = PatternMatcher([
|
||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
])
|
||||
|
||||
@@ -139,7 +139,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
val, var_val = val.unbind()
|
||||
var_vals[val] = var_val
|
||||
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
|
||||
return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
|
||||
return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
|
||||
# otherwise, it's a load and we add it to the inputs
|
||||
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
|
||||
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
|
||||
@@ -157,9 +157,10 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
|
||||
# elementwise ops pass shapetracker
|
||||
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
|
||||
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
|
||||
if buf.op is MetaOps.CONTIGUOUS:
|
||||
assert buf in outputs, f"{buf.op} must be writable"
|
||||
return in_uops[0]
|
||||
if buf.op is MetaOps.ASSIGN: return cache.setdefault((buf, st), UOp(UOps.ASSIGN, dtype, (in_uops[1].src[0], in_uops[0])))
|
||||
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops))
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
|
||||
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
|
||||
|
||||
112
tinygrad/ops.py
112
tinygrad/ops.py
@@ -103,19 +103,23 @@ class UOps(FastEnum):
|
||||
VALID = auto()
|
||||
SPECIAL = auto()
|
||||
NOOP = auto()
|
||||
GEP = auto()
|
||||
|
||||
# math ops
|
||||
CAST = auto()
|
||||
BITCAST = auto()
|
||||
VECTORIZE = auto()
|
||||
ALU = auto()
|
||||
REDUCE = auto()
|
||||
REDUCE_AXIS = auto()
|
||||
|
||||
# helper ops
|
||||
GEP = auto()
|
||||
VECTORIZE = auto()
|
||||
CAST = auto()
|
||||
BITCAST = auto()
|
||||
|
||||
# loads before math
|
||||
LOAD = auto()
|
||||
|
||||
# math ops
|
||||
ALU = auto()
|
||||
WMMA = auto()
|
||||
|
||||
# memory/assignment ops
|
||||
LOAD = auto()
|
||||
# assignment ops
|
||||
STORE = auto()
|
||||
ASSIGN = auto()
|
||||
|
||||
@@ -128,7 +132,7 @@ class UOps(FastEnum):
|
||||
ENDRANGE = auto()
|
||||
ENDIF = auto()
|
||||
|
||||
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
|
||||
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
|
||||
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
|
||||
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
|
||||
|
||||
@@ -144,7 +148,7 @@ class UOp(MathTrait):
|
||||
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
|
||||
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}
|
||||
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
def st(self) -> Optional[ShapeTracker]:
|
||||
if not self.has_st: return None
|
||||
@@ -170,16 +174,14 @@ class UOp(MathTrait):
|
||||
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
|
||||
# *** uop syntactic sugar
|
||||
@property
|
||||
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
|
||||
@property
|
||||
def st_arg(self) -> ShapeTracker:
|
||||
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
|
||||
ret = self.src[self.st_loc]
|
||||
ret = self.src[0 if self.op is UOps.VALID else 1]
|
||||
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
|
||||
return ret.arg
|
||||
def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
|
||||
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
|
||||
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UOp.const(self.dtype, b)
|
||||
def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
if count == 1: return self
|
||||
@@ -215,25 +217,22 @@ class UOp(MathTrait):
|
||||
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
|
||||
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
|
||||
@staticmethod
|
||||
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType):
|
||||
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, UOp.const(dtype, min_val), UOp.const(dtype, max_val)))
|
||||
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
|
||||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
|
||||
def reduce(self, op, *rng): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
if self.op is UOps.SHAPETRACKER: return self.arg.shape
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
|
||||
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr)
|
||||
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
@@ -259,8 +258,7 @@ class UOp(MathTrait):
|
||||
@functools.cached_property
|
||||
def _min_max(self) -> Tuple[ConstType, ConstType]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.arg:
|
||||
return self.arg[1].arg, self.arg[2].arg if self.arg[2].op is UOps.CONST else dtypes.max(self.dtype)
|
||||
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
|
||||
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
if self.op is UOps.EXPAND: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
@@ -331,7 +329,7 @@ def exec_alu(op:Op, dtype:DType, operands):
|
||||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)
|
||||
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
|
||||
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
||||
|
||||
@@ -380,8 +378,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
||||
|
||||
def get_location() -> Tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# find the real frame in the file that has the UPat
|
||||
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
|
||||
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
|
||||
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
@functools.lru_cache(None)
|
||||
@@ -406,7 +404,7 @@ class UPat(MathTrait):
|
||||
# repeat if it's a UPat
|
||||
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
|
||||
|
||||
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
|
||||
self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
|
||||
self.location = location or get_location()
|
||||
|
||||
if custom_early_reject is not None: self.early_reject = custom_early_reject
|
||||
@@ -459,7 +457,7 @@ class UPat(MathTrait):
|
||||
(self.dtype is not None and uop.dtype not in self.dtype) or \
|
||||
(self.arg is not None and self.arg != uop.arg) or \
|
||||
(self.op is not None and uop.op not in self.op) or \
|
||||
(self.allowed_len != 0 and len(uop.src) != self.allowed_len): return []
|
||||
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
|
||||
if self.src is None: return [store]
|
||||
res: List[Dict[str, UOp]] = []
|
||||
for vp in self.src:
|
||||
@@ -488,11 +486,11 @@ class PatternMatcher:
|
||||
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
|
||||
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
if not early_reject.issubset(ler): continue
|
||||
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
|
||||
return None
|
||||
|
||||
# *** tracking pattern matcher ***
|
||||
@@ -512,7 +510,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
for p,_ in self.patterns:
|
||||
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
|
||||
|
||||
def rewrite(self, uop:UOp) -> Optional[UOp]:
|
||||
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
|
||||
ret = None
|
||||
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
|
||||
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
|
||||
@@ -521,7 +519,7 @@ class TrackedPatternMatcher(PatternMatcher):
|
||||
match_stats[p][2] += time.perf_counter()-st
|
||||
continue
|
||||
match_stats[p][1] += 1
|
||||
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None:
|
||||
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None:
|
||||
match_stats[p][0] += 1
|
||||
match_stats[p][2] += (et:=time.perf_counter()-st)
|
||||
match_stats[p][3] += et
|
||||
@@ -536,25 +534,26 @@ if TRACK_MATCH_STATS:
|
||||
import atexit, pickle
|
||||
@atexit.register
|
||||
def print_match_stats():
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
with open("/tmp/rewrites.pkl", "wb") as f:
|
||||
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
|
||||
pickle.dump(contexts, f)
|
||||
if getenv("VIZ"):
|
||||
import viz.serve
|
||||
viz.serve.main()
|
||||
return viz.serve.main()
|
||||
ret = [0,0,0.0,0.0]
|
||||
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
|
||||
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
|
||||
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
|
||||
ret = [x+y for x,y in zip(ret, v)]
|
||||
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
|
||||
|
||||
# *** simple graph rewrite engine ***
|
||||
|
||||
class RewriteContext:
|
||||
def __init__(self, pm):
|
||||
def __init__(self, pm, ctx):
|
||||
self.pm: PatternMatcher = pm
|
||||
self.ctx = ctx
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
self.replace: Dict[UOp, UOp] = {}
|
||||
def rewrite(self, n:UOp) -> UOp:
|
||||
@@ -563,33 +562,36 @@ class RewriteContext:
|
||||
if found := self.nodes.get(replace_source): self.replace[n] = found
|
||||
else:
|
||||
x = UOp(*replace_source) if new_src != n.src else n
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
|
||||
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x
|
||||
return found
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
|
||||
if TRACK_MATCH_STATS >= 2:
|
||||
contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
|
||||
return RewriteContext(pm).rewrite(sink)
|
||||
return RewriteContext(pm, ctx).rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions returns True or False (or None to not match)
|
||||
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
|
||||
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
|
||||
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
|
||||
(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
|
||||
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
|
||||
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], UOp) and isinstance(x.arg[2], UOp)),
|
||||
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
|
||||
|
||||
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
|
||||
(UPat(UOps.SPECIAL, src=()), lambda: True),
|
||||
|
||||
# no pyint allowed here!
|
||||
(UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),
|
||||
|
||||
# TODO: confirm the args of both of these are shapetrackers
|
||||
(UPat(UOps.SHAPETRACKER, src=()), lambda: True),
|
||||
(UPat(UOps.SWIZZLE, src=(UPat(),)), lambda: True),
|
||||
|
||||
(UPat(UOps.CONST, name="x"),
|
||||
lambda x: x.dtype == x.dtype.scalar() and (isinstance(x.arg, Variable) and x.src) or (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
(UPat(UOps.VALID, dtypes.bool, (UPat(UOps.SHAPETRACKER),)), lambda: True),
|
||||
(UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
|
||||
|
||||
# early LOAD has a <buf, shapetracker, store?>
|
||||
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER))), lambda: True),
|
||||
@@ -616,13 +618,13 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
|
||||
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
|
||||
(UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
|
||||
|
||||
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
|
||||
(UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True),
|
||||
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
|
||||
|
||||
# early WMMA has 2 args, <x, w>
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
|
||||
# late WMMA has 3 args, <x, w, acc>
|
||||
# all WMMA has 3 args, <x, w, acc>
|
||||
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
|
||||
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
|
||||
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
|
||||
|
||||
# if has a <gate, barrier>
|
||||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
|
||||
@@ -636,7 +638,7 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
|
||||
|
||||
# NOTE: for testing, we let sinks be anything
|
||||
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
|
||||
(UPat(UOps.SINK), lambda: True),
|
||||
(UPat(UOps.SINK, dtypes.void), lambda: True),
|
||||
|
||||
# PTX LOAD/STORE
|
||||
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
|
||||
|
||||
@@ -34,7 +34,7 @@ class Program:
|
||||
if not self._ran_post_init and self.uops is not None:
|
||||
# single pass through the uops
|
||||
for u in self.uops:
|
||||
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg))
|
||||
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
|
||||
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
|
||||
if u.op is UOps.SPECIAL:
|
||||
|
||||
@@ -56,14 +56,14 @@ class CStyleLanguage(Renderer):
|
||||
return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
|
||||
|
||||
# returns a str expression of the loaded value with the output type
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
|
||||
def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
|
||||
return f"read_imagef({buf_name}, smp, {idx})"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
|
||||
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
|
||||
if output_dtype.count > 1:
|
||||
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
|
||||
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
|
||||
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
|
||||
@@ -78,14 +78,14 @@ class CStyleLanguage(Renderer):
|
||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
|
||||
# returns a str statement that does the store
|
||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
|
||||
def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str:
|
||||
if isinstance(buf_dtype, ImageDType):
|
||||
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
|
||||
return f"write_imagef({buf_name}, {idx}, {var_name});"
|
||||
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
|
||||
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
|
||||
if var_dtype.count > 1:
|
||||
prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix
|
||||
return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
|
||||
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
|
||||
|
||||
@@ -124,8 +124,9 @@ class CStyleLanguage(Renderer):
|
||||
kk("}")
|
||||
elif uop is UOps.STORE:
|
||||
# mark DEFINE_GLOBAL buf as writable
|
||||
assert isinstance(src[0].dtype, (ImageDType, PtrDType))
|
||||
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
|
||||
else:
|
||||
if uop is UOps.RANGE:
|
||||
@@ -150,7 +151,7 @@ class CStyleLanguage(Renderer):
|
||||
bufs[u] = (args[0], (dtype,False))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
|
||||
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
@@ -226,6 +227,11 @@ class ClangRenderer(CStyleLanguage):
|
||||
class OpenCLRenderer(CStyleLanguage):
|
||||
device = "GPU"
|
||||
|
||||
code_for_op = {**CStyleLanguage().code_for_op,
|
||||
#UnaryOps.SQRT: lambda x,dtype: f"native_sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"native_recip({x})",
|
||||
#UnaryOps.EXP2: lambda x,dtype: f"native_exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"native_log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"native_sin({x})"} if getenv("NATIVE_MATH") else CStyleLanguage().code_for_op
|
||||
|
||||
# language options
|
||||
kernel_prefix = "__kernel "
|
||||
buffer_prefix = "__global "
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import os, ctypes, functools, mmap, struct, array, decimal, math
|
||||
from types import SimpleNamespace
|
||||
from typing import Tuple, List, Dict, Any
|
||||
from typing import Tuple, List, Any, cast
|
||||
from tinygrad.device import BufferOptions, HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState, hcq_command
|
||||
from tinygrad.runtime.autogen import kgsl, adreno, libc
|
||||
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
|
||||
@@ -9,6 +9,8 @@ from tinygrad.renderer.cstyle import QCOMRenderer
|
||||
from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, DEBUG, fromimport
|
||||
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
|
||||
|
||||
BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
|
||||
|
||||
def _qreg_exec(reg, __val=0, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
__val |= (getattr(adreno, f'{reg[4:]}_{k.upper()}') if v else 0) if type(v) is bool else (v << getattr(adreno, f'{reg[4:]}_{k.upper()}__SHIFT'))
|
||||
@@ -123,39 +125,38 @@ class QCOMComputeQueue(HWComputeQueue):
|
||||
qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total))
|
||||
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_SHADER, num_unit=prg.kernargs_alloc_size // 4),
|
||||
state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4),
|
||||
*data64_le(args_state.ptr))
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128),
|
||||
*data64_le(prg.lib_gpu.va_addr))
|
||||
|
||||
self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc,
|
||||
qreg.a6xx_hlsq_cs_cntl(constlen=prg.kernargs_alloc_size // 4, enabled=True))
|
||||
self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, qreg.a6xx_hlsq_cs_cntl(constlen=1024 // 4, enabled=True))
|
||||
|
||||
self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_hw_stack_offset(prg.hw_stack_offset))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, qreg.a6xx_sp_cs_instrlen(prg.image_size // 4))
|
||||
|
||||
if hasattr(args_state, 'samplers_ptr'):
|
||||
if args_state.prg.samp_cnt > 0:
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=args_state.samplers_cnt),
|
||||
*data64_le(args_state.samplers_ptr.va_addr))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.samplers_ptr.va_addr))
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
|
||||
*data64_le(args_state.ptr + args_state.prg.samp_off))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
|
||||
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.device._border_color_base()))
|
||||
|
||||
if hasattr(args_state, 'descriptors_ptr'):
|
||||
if args_state.prg.tex_cnt > 0:
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=args_state.descriptors_cnt),
|
||||
*data64_le(args_state.descriptors_ptr.va_addr))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.descriptors_ptr.va_addr))
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt),
|
||||
*data64_le(args_state.ptr + args_state.prg.tex_off))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
|
||||
|
||||
if hasattr(args_state, 'ibos_ptr'):
|
||||
if args_state.prg.ibo_cnt > 0:
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_SHADER, num_unit=args_state.ibos_cnt),
|
||||
*data64_le(args_state.ibos_ptr.va_addr))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ibos_ptr.va_addr))
|
||||
state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt),
|
||||
*data64_le(args_state.ptr + args_state.prg.ibo_off))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ptr + args_state.prg.ibo_off))
|
||||
|
||||
self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
|
||||
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.samplers_cnt, ntex=args_state.descriptors_cnt, nibo=args_state.ibos_cnt))
|
||||
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
|
||||
self.cmd(adreno.CP_RUN_OPENCL, 0)
|
||||
self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
|
||||
|
||||
@@ -175,46 +176,27 @@ class QCOMComputeQueue(HWComputeQueue):
|
||||
class QCOMArgsState(HCQArgsState):
|
||||
def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
|
||||
super().__init__(ptr, prg, bufs, vals=vals)
|
||||
self.ibos_cnt, self.descriptors_cnt, self.samplers_cnt = 0, 0, 0
|
||||
ctypes.memset(ptr, 0, 1024)
|
||||
ctypes.memset(self.ptr, 0, 1024)
|
||||
|
||||
if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}')
|
||||
|
||||
self.buf_info, self.args_info = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):]
|
||||
|
||||
if len(bufs) + len(vals) != len(prg.buffs_info): raise RuntimeError(f'incorrect args size given={len(bufs)} != want={len(prg.buffs_info)}')
|
||||
self.boffs, self.aoffs = prg.buffs_info[:len(bufs)], prg.buffs_info[len(bufs):]
|
||||
for i, v in enumerate(vals): self.update_var(i, v)
|
||||
for cnst_val, cnst_off, cnst_sz in prg.consts_info:
|
||||
ctypes.memmove(self.ptr + cnst_off, (ctypes.c_int8 * cnst_sz).from_buffer_copy(cnst_val.to_bytes(cnst_sz, byteorder='little')), cnst_sz)
|
||||
|
||||
samplers: List[Any] = []
|
||||
descriptors: List[Any] = []
|
||||
ibos: List[Any] = []
|
||||
self.i2descr: Dict[int, int] = {}
|
||||
self.i2ibo: Dict[int, int] = {}
|
||||
for i, b in enumerate(bufs):
|
||||
if not hasattr(b, 'samplers') and not hasattr(b, 'descriptor') and not hasattr(b, 'ibo'): self.update_buffer(i, b)
|
||||
elif self.boffs[i][1]: ibos, self.i2ibo = [*ibos, *getattr(b, 'ibo')], {**self.i2ibo, i: len(ibos)}
|
||||
else:
|
||||
samplers, descriptors = [*samplers, *getattr(b, 'samplers')], [*descriptors, *getattr(b, 'descriptor')]
|
||||
self.i2descr[i] = len(descriptors) - 1
|
||||
|
||||
def alloc_tex_gpu(data, chunk_size) -> Tuple[HCQBuffer, int]:
|
||||
tex_gpu = self.prg.device.allocator.alloc(len(data) * 4, BufferOptions(nolru=True, cpu_access=True))
|
||||
to_mv(tex_gpu.va_addr, len(data) * 4).cast('I')[:] = array.array('I', data)
|
||||
return tex_gpu, len(data) // chunk_size
|
||||
|
||||
if len(samplers): self.samplers_ptr, self.samplers_cnt = alloc_tex_gpu(samplers, 4)
|
||||
if len(descriptors): self.descriptors_ptr, self.descriptors_cnt = alloc_tex_gpu(descriptors, 16)
|
||||
if len(ibos): self.ibos_ptr, self.ibos_cnt = alloc_tex_gpu(ibos, 16)
|
||||
|
||||
def __del__(self):
|
||||
for ptr in ('samplers_ptr', 'descriptors_ptr', 'ibos_ptr'):
|
||||
if hasattr(self, ptr): self.prg.device.allocator.free((x:=getattr(self, ptr)), x.size, BufferOptions(nolru=True, cpu_access=True))
|
||||
if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
|
||||
for i, b in enumerate(cast(List[QCOMBuffer], bufs)):
|
||||
if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc)
|
||||
elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo)
|
||||
else: self.update_buffer(i, b)
|
||||
for i, v in enumerate(vals): self.update_var(i, v)
|
||||
|
||||
def update_buffer(self, index:int, buf:HCQBuffer):
|
||||
if (descr:=self.i2descr.get(index, None)) is not None: to_mv(self.descriptors_ptr.va_addr + 16 * descr + 4 * 4, 8).cast('Q')[0] = buf.va_addr
|
||||
elif (ibo:=self.i2ibo.get(index, None)) is not None: to_mv(self.ibos_ptr.va_addr + 16 * ibo + 4 * 4, 8).cast('Q')[0] = buf.va_addr
|
||||
else: to_mv(self.ptr + self.boffs[index][0], 8).cast('Q')[0] = buf.va_addr
|
||||
if self.buf_info[index].type is not BUFTYPE_BUF: to_mv(self.ptr+self.buf_info[index].offset+0x10, 8).cast('Q')[0] = buf.va_addr
|
||||
else: to_mv(self.ptr + self.buf_info[index].offset, 8).cast('Q')[0] = buf.va_addr
|
||||
|
||||
def update_var(self, index:int, val:int): to_mv(self.ptr + self.aoffs[index][0], 8).cast('Q')[0] = val
|
||||
def update_var(self, index:int, val:int): to_mv(self.ptr + self.args_info[index].offset, 8).cast('Q')[0] = val
|
||||
|
||||
class QCOMProgram(HCQProgram):
|
||||
def __init__(self, device: QCOMDevice, name: str, lib: bytes):
|
||||
@@ -232,7 +214,7 @@ class QCOMProgram(HCQProgram):
|
||||
self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128)
|
||||
device._ensure_stack_size(self.hw_stack_offset * 4)
|
||||
|
||||
super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=1024)
|
||||
super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10)
|
||||
|
||||
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
|
||||
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
|
||||
@@ -253,19 +235,34 @@ class QCOMProgram(HCQProgram):
|
||||
self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8)
|
||||
|
||||
# Fill up constants and buffers info
|
||||
self.buffs_info, self.consts_info = [], []
|
||||
self.buf_info, self.consts_info = [], []
|
||||
|
||||
samplers_count = _read_lib(image_desc_off + 0xdc)
|
||||
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samplers_count
|
||||
while (bdoff + 16 <= len(self.lib)):
|
||||
length, _, _, offset_words = struct.unpack("I" * 4, self.lib[bdoff:bdoff+16])
|
||||
# Collect sampler info.
|
||||
self.samp_cnt = _read_lib(image_desc_off + 0xdc)
|
||||
assert self.samp_cnt <= 1, "Up to one sampler supported"
|
||||
if self.samp_cnt:
|
||||
self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode),
|
||||
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0]
|
||||
|
||||
# Collect kernel arguments (buffers) info.
|
||||
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt
|
||||
while bdoff + 32 <= len(self.lib):
|
||||
length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32])
|
||||
if length == 0: break
|
||||
self.buffs_info.append((offset_words * 4, struct.unpack("I", self.lib[bdoff+0x3c:bdoff+0x40])[0] == 0x0))
|
||||
self.buf_info.append(SimpleNamespace(offset=offset_words * 4, type=typ))
|
||||
bdoff += length
|
||||
|
||||
# Setting correct offsets to textures/ibos.
|
||||
self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info)
|
||||
self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt
|
||||
cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off
|
||||
for x in self.buf_info:
|
||||
if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40
|
||||
elif x.type is BUFTYPE_TEX: x.offset, cur_tex_off = cur_tex_off, cur_tex_off + 0x40
|
||||
|
||||
if _read_lib(0xb0) != 0: # check if we have constants.
|
||||
cdoff = _read_lib(0xac)
|
||||
while (cdoff + 40 <= image_offset):
|
||||
while cdoff + 40 <= image_offset:
|
||||
cnst, offset_words, _, is32 = struct.unpack("I", self.lib[cdoff:cdoff+4])[0], *struct.unpack("III", self.lib[cdoff+16:cdoff+28])
|
||||
self.consts_info.append((cnst, offset_words * (sz_bytes:=(2 << is32)), sz_bytes))
|
||||
cdoff += 40
|
||||
@@ -277,6 +274,10 @@ class QCOMProgram(HCQProgram):
|
||||
def __del__(self):
|
||||
if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))
|
||||
|
||||
class QCOMBuffer(HCQBuffer):
|
||||
def __init__(self, va_addr:int, size:int, desc=None, ibo=None, pitch=None, real_stride=None):
|
||||
self.va_addr, self.size, self.desc, self.ibo, self.pitch, self.real_stride = va_addr, size, desc, ibo, pitch, real_stride
|
||||
|
||||
class QCOMAllocator(HCQAllocator):
|
||||
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
|
||||
if options.image is not None:
|
||||
@@ -286,33 +287,37 @@ class QCOMAllocator(HCQAllocator):
|
||||
|
||||
granularity = 128 if options.image.itemsize == 4 else 256
|
||||
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
|
||||
pitch = round_up(imgw * 4 * options.image.itemsize, 1 << pitchalign) + pitch_add
|
||||
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
|
||||
|
||||
texture = self.device._gpu_alloc(pitch * round_up(imgh, 16), kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
|
||||
|
||||
# Extend HCQBuffer with texture-related info.
|
||||
texture.samplers, texture.descriptor, texture.ibo = [0] * 4, [0] * 16, [0] * 16
|
||||
|
||||
# Compiled sampler (always the same in tinygrad).
|
||||
texture.samplers[0] = qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode)
|
||||
texture.samplers[1] = qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True)
|
||||
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
|
||||
|
||||
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
|
||||
texture.descriptor[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
|
||||
texture.descriptor[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
|
||||
texture.descriptor[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6)
|
||||
texture.descriptor[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
|
||||
texture.ibo = [texture.descriptor[0] & (~0xffff), *texture.descriptor[1:len(texture.descriptor)]]
|
||||
texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
|
||||
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
|
||||
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
|
||||
texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
|
||||
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
|
||||
|
||||
return texture
|
||||
|
||||
return self.device._gpu_alloc(size, map_to_cpu=True)
|
||||
|
||||
def copyin(self, dest:HCQBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
|
||||
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
|
||||
while src_off < src_size:
|
||||
ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size)
|
||||
src_off, dest_off = src_off+src_stride, dest_off+dest_stride
|
||||
|
||||
def copyin(self, dest:HCQBuffer, src:memoryview):
|
||||
if hasattr(qd:=cast(QCOMBuffer, dest), 'pitch'): self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
|
||||
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)
|
||||
|
||||
def copyout(self, dest:memoryview, src:HCQBuffer):
|
||||
self.device.synchronize()
|
||||
ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
|
||||
if hasattr(qs:=cast(QCOMBuffer, src), 'pitch'): self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
|
||||
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
|
||||
|
||||
def as_buffer(self, src:HCQBuffer) -> memoryview:
|
||||
self.device.synchronize()
|
||||
|
||||
134
viz/index.html
134
viz/index.html
@@ -62,13 +62,11 @@
|
||||
stroke-width: 1.5px;
|
||||
}
|
||||
.graph {
|
||||
grid-column: span 10;
|
||||
width: 70%;
|
||||
position: relative;
|
||||
}
|
||||
.main-container {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(14, 1fr);
|
||||
gap: 8px;
|
||||
display: flex;
|
||||
padding: 12px;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
@@ -79,12 +77,16 @@
|
||||
background-color: #111111;
|
||||
border-radius: 8px;
|
||||
padding: 8px;
|
||||
position: relative;
|
||||
}
|
||||
.container > * + * {
|
||||
margin-top: 12px;
|
||||
}
|
||||
.main-container > * + * {
|
||||
margin-left: 8px;
|
||||
}
|
||||
.kernel-list {
|
||||
grid-column: span 1;
|
||||
width: 10%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow-y: auto;
|
||||
@@ -96,9 +98,18 @@
|
||||
margin-top: 4px;
|
||||
}
|
||||
.metadata {
|
||||
grid-column: span 3;
|
||||
width: 20%;
|
||||
overflow-y: auto;
|
||||
}
|
||||
#resize-handle {
|
||||
position: absolute;
|
||||
left: 0;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
width: 20px;
|
||||
cursor: w-resize;
|
||||
background-color: transparent;
|
||||
}
|
||||
.rewrite-list {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
@@ -137,12 +148,17 @@
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
.status {
|
||||
color: #EC5D5E;
|
||||
font-weight: bold;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="main-container">
|
||||
<div class="container kernel-list"></div>
|
||||
<div class="container graph">
|
||||
<div class="status"></div>
|
||||
<svg>
|
||||
<g id="render"></g>
|
||||
</svg>
|
||||
@@ -150,12 +166,16 @@
|
||||
<div class="container metadata"></div>
|
||||
</div>
|
||||
<script>
|
||||
function renderGraph(graph) {
|
||||
const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
function renderGraph(graph, additions) {
|
||||
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: #242424" : "display: none;"});
|
||||
for ([k,u] of Object.entries(graph)) {
|
||||
g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
|
||||
for (src of u[2]) {
|
||||
g.setEdge(src, k)
|
||||
g.setEdge(src, k, {curve: d3.curveBasis})
|
||||
}
|
||||
if (additions.includes(parseInt(k))) {
|
||||
g.setParent(k, "addition");
|
||||
}
|
||||
}
|
||||
const svg = d3.select("svg");
|
||||
@@ -170,26 +190,38 @@
|
||||
const render = new dagreD3.render();
|
||||
render(inner, g);
|
||||
}
|
||||
async function checkStatus() {
|
||||
active = true;
|
||||
try {
|
||||
active = (await fetch("/")).ok;
|
||||
} catch {
|
||||
active = false;
|
||||
}
|
||||
document.querySelector('.status').textContent = active ? `` : `Connection to localhost:8000 failed, is VIZ=1 running?`;
|
||||
}
|
||||
var ret = {};
|
||||
var cache = {};
|
||||
var kernels = null;
|
||||
var currentUOp = 0;
|
||||
var currentKernel = 0;
|
||||
var currentRewrite = 0;
|
||||
var expandKernel = false;
|
||||
async function main() {
|
||||
checkStatus();
|
||||
// ** kernel list
|
||||
if (kernels == null) {
|
||||
kernels = await (await fetch("/kernels")).json();
|
||||
currentKernel = 0;
|
||||
}
|
||||
const kernelList = document.querySelector(".container.kernel-list");
|
||||
kernelList.innerHTML = "";
|
||||
Object.entries(kernels).forEach(([k,uopRewrites], i) => {
|
||||
kernelUl = Object.assign(document.createElement("ul"), { innerHTML: `<p>${k}</p>`, key: `kernel-${i}`, className: i === currentKernel ? "active" : "" });
|
||||
uopRewrites.forEach((u, j) => {
|
||||
kernels.forEach((k, i) => {
|
||||
kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "" });
|
||||
const p = Object.assign(document.createElement("p"), {id: `kernel-${k.name}`, innerText: k.name})
|
||||
kernelUl.appendChild(p)
|
||||
k.ctxs.forEach((u, j) => {
|
||||
const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
|
||||
if (i !== currentKernel) {
|
||||
rwUl.style.display = "none";
|
||||
}
|
||||
rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
|
||||
rwUl.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
currentUOp = j;
|
||||
@@ -199,11 +231,16 @@
|
||||
}
|
||||
kernelUl.appendChild(rwUl)
|
||||
})
|
||||
kernelUl.onclick = (e) => {
|
||||
if (i === currentKernel) return;
|
||||
p.onclick = (e) => {
|
||||
if (i === currentKernel) {
|
||||
expandKernel = !expandKernel;
|
||||
main();
|
||||
return;
|
||||
}
|
||||
currentKernel = i;
|
||||
currentUOp = 0;
|
||||
currentRewrite = 0;
|
||||
expandKernel = true;
|
||||
main();
|
||||
}
|
||||
kernelList.appendChild(kernelUl);
|
||||
@@ -217,17 +254,46 @@
|
||||
ret = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
|
||||
cache[cacheKey] = ret;
|
||||
}
|
||||
renderGraph(ret[0].graphs[currentRewrite][0]);
|
||||
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
|
||||
const metadata = document.querySelector(".container.metadata");
|
||||
metadata.innerHTML = "";
|
||||
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
|
||||
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
|
||||
const resizeHandle = document.getElementById("resize-handle");
|
||||
|
||||
let startX;
|
||||
let containerWidth;
|
||||
let metadataWidth;
|
||||
resizeHandle.addEventListener("mousedown", (e) => {
|
||||
e.preventDefault();
|
||||
metadata.style.userSelect = "none";
|
||||
startX = e.clientX;
|
||||
containerWidth = document.querySelector(".main-container").getBoundingClientRect().width;
|
||||
metadataWidth = metadata.getBoundingClientRect().width;
|
||||
document.documentElement.addEventListener("mousemove", resize, false);
|
||||
document.documentElement.addEventListener("mouseup", stopResize, false);
|
||||
});
|
||||
function resize(e) {
|
||||
const change = e.clientX - startX;
|
||||
const newWidth = ((metadataWidth-change) / containerWidth) * 100;
|
||||
if (newWidth >= 20 && newWidth <= 50) {
|
||||
metadata.style.width = `${newWidth}%`;
|
||||
document.querySelector(".graph").style.width = `${100-newWidth-10}%`;
|
||||
}
|
||||
}
|
||||
function stopResize(e) {
|
||||
document.documentElement.removeEventListener("mousemove", resize, false);
|
||||
document.documentElement.removeEventListener("mouseup", stopResize, false);
|
||||
metadata.style.userSelect = "initial";
|
||||
}
|
||||
|
||||
ret[0].extra[currentRewrite].forEach((e, i) => {
|
||||
if (e.length == 0) return;
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${e}</code>`, className: "code-block" });
|
||||
metadata.appendChild(pre);
|
||||
})
|
||||
if (ret[0].code !== "") {
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${ret[0].code}</code>`, className: "code-block" });
|
||||
if (kernels[currentKernel].code !== "") {
|
||||
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${kernels[currentKernel].code}</code>`, className: "code-block" });
|
||||
metadata.appendChild(pre);
|
||||
}
|
||||
// ** rewrite list
|
||||
@@ -263,7 +329,30 @@
|
||||
}
|
||||
}
|
||||
document.addEventListener("keydown", async function(event) {
|
||||
// up and down change the UOp from the list
|
||||
// up and down change the UOp or kernel from the list
|
||||
if (!expandKernel) {
|
||||
if (event.key == "ArrowUp") {
|
||||
event.preventDefault()
|
||||
currentUOp = 0;
|
||||
currentRewrite = 0;
|
||||
currentKernel = Math.max(0, currentKernel-1)
|
||||
return main()
|
||||
}
|
||||
if (event.key == "ArrowDown") {
|
||||
event.preventDefault()
|
||||
currentUOp = 0;
|
||||
currentRewrite = 0;
|
||||
currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
|
||||
return main()
|
||||
}
|
||||
}
|
||||
if (event.key == "Enter") {
|
||||
event.preventDefault()
|
||||
currentUOp = 0;
|
||||
currentRewrite = 0;
|
||||
expandKernel = !expandKernel;
|
||||
main();
|
||||
}
|
||||
if (event.key == "ArrowUp") {
|
||||
event.preventDefault()
|
||||
currentRewrite = 0;
|
||||
@@ -273,7 +362,7 @@
|
||||
if (event.key == "ArrowDown") {
|
||||
event.preventDefault()
|
||||
currentRewrite = 0;
|
||||
const totalUOps = Array.from(Object.values(kernels))[currentKernel].length-1;
|
||||
const totalUOps = kernels[currentKernel].ctxs.length-1;
|
||||
currentUOp = Math.min(totalUOps, currentUOp+1)
|
||||
main()
|
||||
}
|
||||
@@ -290,6 +379,7 @@
|
||||
main()
|
||||
}
|
||||
})
|
||||
setInterval(checkStatus, 5000);
|
||||
main()
|
||||
</script>
|
||||
</body>
|
||||
|
||||
96
viz/serve.py
96
viz/serve.py
@@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
from typing import Dict, List, Tuple
|
||||
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, asdict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad import Device
|
||||
@@ -11,6 +12,37 @@ from tinygrad.engine.graph import uops_colors, word_wrap
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import full_ast_rewrite
|
||||
|
||||
# **** /graph - detailed UOp + rewrites
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str
|
||||
graphs: List[UOp] # snapshot of the entire AST after each rewrite
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
additions: List[List[int]]
|
||||
@staticmethod
|
||||
def from_ctx(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
uops: List[UOp] = [ctx.sink]
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
additions: List[List[int]] = [[]]
|
||||
seen_replaces: Dict[bytes, UOp] = {}
|
||||
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
|
||||
if pattern.location[0].split("/")[-1] == "ops.py": continue
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
seen_replaces[first.key] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
# update ret data
|
||||
additions.append([id(x) for x in rewritten.sparents])
|
||||
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, uops, diffs, extra, additions)
|
||||
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
|
||||
|
||||
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
assert isinstance(x, UOp)
|
||||
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
|
||||
@@ -22,51 +54,33 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
|
||||
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff"))
|
||||
return graph
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UOpRet:
|
||||
loc: str
|
||||
graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
|
||||
extra: List[List[str]] # these become code blocks in the UI
|
||||
|
||||
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
|
||||
if (found:=replaces.get(base.key)): return found
|
||||
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
|
||||
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
|
||||
return ret
|
||||
|
||||
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
|
||||
uops: List[UOp] = [ctx.sink]
|
||||
graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)]
|
||||
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
|
||||
extra: List[List[str]] = [[str(ctx.sink)]]
|
||||
seen_replaces: Dict[bytes, UOp] = {}
|
||||
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
|
||||
if pattern.location[0].split("/")[-1] == "ops.py": continue
|
||||
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
|
||||
seen_replaces[first.key] = rewritten
|
||||
new_sink = replace_uop(uops[-1], {**seen_replaces})
|
||||
# sanity check
|
||||
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
|
||||
# update ret data
|
||||
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
|
||||
graphs.append((new_sink, uops[-1], rewritten, first))
|
||||
uops.append(new_sink)
|
||||
extra.append([str(new_sink)])
|
||||
return UOpRet(ctx.loc, graphs, diffs, extra)
|
||||
# **** /kernels - Overview of the kernel
|
||||
|
||||
def get_ctx_groups(contexts:List[TrackedRewriteContext]) -> Dict[str, Tuple[List[TrackedRewriteContext], str]]:
|
||||
ctx_groups: Dict[str, Tuple[List[TrackedRewriteContext], str]] = {}
|
||||
@dataclass(frozen=True)
|
||||
class KernelRet:
|
||||
name: str
|
||||
code: str
|
||||
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
|
||||
def to_json(self) -> Dict:
|
||||
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
|
||||
|
||||
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
|
||||
ret: Dict[str, KernelRet] = {}
|
||||
kernel_name = ""
|
||||
code = ""
|
||||
for ctx in contexts:
|
||||
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
|
||||
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
|
||||
elif ctx.kernel_name is not None: kernel_name = ctx.kernel_name
|
||||
if ctx_groups.get(k:=to_function_name(kernel_name)) is None: ctx_groups[k] = ([], code)
|
||||
# TODO: make ansi play nice with css
|
||||
ctx_groups[to_function_name(kernel_name)][0].append(ctx)
|
||||
return ctx_groups
|
||||
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
|
||||
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
|
||||
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
|
||||
return list(ret.values())
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
@@ -87,20 +101,18 @@ class Handler(BaseHTTPRequestHandler):
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
ctx_groups = get_ctx_groups(contexts)
|
||||
ret = json.dumps({k:[x.loc for x in v[0]] for k,v in ctx_groups.items()}).encode()
|
||||
kernels = load_kernels(contexts)
|
||||
ret = json.dumps([x.to_json() for x in kernels]).encode()
|
||||
elif url.path == "/graph":
|
||||
query = parse_qs(url.query)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
|
||||
ctx_groups = get_ctx_groups(contexts)
|
||||
group, code = ctx_groups[list(ctx_groups.keys())[int(query["kernel_idx"][0])]]
|
||||
g = create_graph(group[int(query["uop_idx"][0])])
|
||||
rest = [x.loc for x in group]
|
||||
ret = json.dumps(({"loc": g.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in g.graphs],
|
||||
"diffs": g.diffs, "extra": g.extra, "code": code}, rest)).encode()
|
||||
kernels = load_kernels(contexts)
|
||||
k = kernels[int(query["kernel_idx"][0])]
|
||||
g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])])
|
||||
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode()
|
||||
else:
|
||||
self.send_response(404)
|
||||
ret = b""
|
||||
|
||||
@@ -6,10 +6,10 @@ from tinygrad import Tensor
|
||||
from tinygrad.engine.realize import lower_schedule
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
|
||||
from tinygrad.dtype import dtypes, PtrDType
|
||||
from tinygrad.helpers import CI, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
|
||||
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from viz.serve import create_graph, get_ctx_groups
|
||||
from viz.serve import UOpRet, load_kernels
|
||||
|
||||
class TestViz(unittest.TestCase):
|
||||
def tearDown(self) -> None:
|
||||
@@ -19,12 +19,11 @@ class TestViz(unittest.TestCase):
|
||||
def assert_valid_ctx(self, contexts):
|
||||
assert len(contexts) != 0
|
||||
for i,ctx in enumerate(contexts):
|
||||
try: ret = create_graph(ctx)
|
||||
try: ret = UOpRet.from_ctx(ctx)
|
||||
except Exception as e:
|
||||
print(colored(f"failed to create graph for ctx {i}", "red"))
|
||||
raise e
|
||||
rewrites = [x[0] for x in ret.graphs]
|
||||
for j,(x,y) in enumerate(zip(rewrites, rewrites[1:])):
|
||||
for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])):
|
||||
if x.key == y.key:
|
||||
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
|
||||
|
||||
@@ -45,10 +44,10 @@ class TestViz(unittest.TestCase):
|
||||
schedule2 = Tensor.randn(4, 4).contiguous().schedule()
|
||||
list(lower_schedule(schedule1))
|
||||
list(lower_schedule(schedule2))
|
||||
ret = get_ctx_groups(contexts)
|
||||
assert len(ret.keys()) == 2
|
||||
assert all(len([x for x in ctxs if "schedule" in x.loc]) != 0 for ctxs,_ in ret.values())
|
||||
assert all(len([x for x in ctxs if "uopgraph" in x.loc]) != 0 for ctxs,_ in ret.values())
|
||||
ret = load_kernels(contexts)
|
||||
assert len(ret) == 2
|
||||
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
|
||||
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
|
||||
|
||||
def test_gemm_diff(self):
|
||||
x = Tensor.empty(64, 64).realize()
|
||||
@@ -67,8 +66,8 @@ class TestViz(unittest.TestCase):
|
||||
])
|
||||
ret = graph_rewrite(sink, pm)
|
||||
if DEBUG >= 4: print_diff(sink, ret)
|
||||
g = create_graph(contexts[0])
|
||||
assert g.graphs[-1][0].key == ret.key
|
||||
g = UOpRet.from_ctx(contexts[0])
|
||||
assert g.graphs[-1].key == ret.key
|
||||
self.assert_valid_ctx(contexts)
|
||||
|
||||
def test_devectorize_viz(self):
|
||||
@@ -115,5 +114,28 @@ class TestViz(unittest.TestCase):
|
||||
simple_pm.rewrite(UOp.const(dtypes.int, 2))
|
||||
self.assertEqual(len(contexts), 0)
|
||||
|
||||
def test_dedup_ast(self):
|
||||
contexts.clear()
|
||||
a = Tensor.randn(4, 4)+2
|
||||
b = Tensor.randn(4, 4)+2
|
||||
Tensor.schedule(a, b)
|
||||
kernels = load_kernels(contexts)
|
||||
self.assertEqual(len(kernels), 1)
|
||||
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 1)
|
||||
|
||||
def test_no_dedup_different_opts(self):
|
||||
contexts.clear()
|
||||
a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
|
||||
s = a.schedule()
|
||||
with Context(NOOPT=1): list(lower_schedule(s.copy()))
|
||||
with Context(NOOPT=0): list(lower_schedule(s.copy()))
|
||||
kernels = load_kernels(contexts)
|
||||
self.assertEqual(len(kernels), 2)
|
||||
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 1)
|
||||
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
|
||||
self.assertEqual(len(schedule_ctxs), 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user