Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-09-23 12:14:59 -07:00
43 changed files with 791 additions and 463 deletions

View File

@@ -45,6 +45,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Run Stable Diffusion
run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run Stable Diffusion with fp16
@@ -148,6 +150,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Run model inference benchmark
run: NV=1 NOCLANG=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
@@ -252,6 +256,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Fuzz Padded Tensor Core GEMM (NV)
run: NV=1 M_START=12 M_STOP=20 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Fuzz Padded Tensor Core GEMM (PTX)
@@ -318,6 +324,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Show off tinybox
run: /opt/rocm/bin/rocm-bandwidth-test
# TODO: unstable on AMD
@@ -420,6 +428,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Train MNIST
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
@@ -471,6 +481,8 @@ jobs:
run: |
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: openpilot compile 0.9.4
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt
- name: openpilot compile 0.9.7
@@ -485,6 +497,8 @@ jobs:
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_4.txt
- name: benchmark openpilot w IMAGE=2 0.9.7
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt
- name: openpilot compile3 0.9.7
run: PYTHONPATH="." QCOM=1 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
- uses: actions/upload-artifact@v4

View File

@@ -197,7 +197,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=37 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)

View File

@@ -6,6 +6,7 @@ sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1"
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

View File

@@ -3,6 +3,7 @@ import numpy as np
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "NATIVE_MATH" not in os.environ: os.environ["NATIVE_MATH"] = "1"
from tinygrad import fetch, Tensor, TinyJit, Device, Context, GlobalCounters
from tinygrad.helpers import OSX, DEBUG, Timing

View File

@@ -12,7 +12,7 @@ from extra.models.unet import UNetModel, Upsample, Downsample, timestep_embeddin
from examples.stable_diffusion import ResnetBlock, Mid
import numpy as np
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union
from typing import Dict, List, Callable, Optional, Any, Set, Tuple, Union, Type
import argparse, tempfile
from abc import ABC, abstractmethod
from pathlib import Path
@@ -282,19 +282,29 @@ class SDXL:
return self.first_stage_model.decode(1.0 / 0.13025 * x)
class VanillaCFG:
class Guider(ABC):
def __init__(self, scale:float):
self.scale = scale
def prepare_inputs(self, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor,Dict]:
@abstractmethod
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
pass
class VanillaCFG(Guider):
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
c_out = {}
for k in c:
assert k in ["vector", "crossattn", "concat"]
c_out[k] = Tensor.cat(uc[k], c[k], dim=0)
return Tensor.cat(x, x), Tensor.cat(s, s), c_out
def __call__(self, x:Tensor) -> Tensor:
x_u, x_c = x.chunk(2)
x_u, x_c = denoiser(Tensor.cat(x, x), Tensor.cat(s, s), c_out).chunk(2)
x_pred = x_u + self.scale*(x_c - x_u)
return x_pred
class SplitVanillaCFG(Guider):
def __call__(self, denoiser, x:Tensor, s:Tensor, c:Dict, uc:Dict) -> Tensor:
x_u = denoiser(x, s, uc)
x_c = denoiser(x, s, c)
x_pred = x_u + self.scale*(x_c - x_u)
return x_pred
@@ -302,13 +312,12 @@ class VanillaCFG:
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L21
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/sampling.py#L287
class DPMPP2MSampler:
def __init__(self, cfg_scale:float):
def __init__(self, cfg_scale:float, guider_cls:Type[Guider]=VanillaCFG):
self.discretization = LegacyDDPMDiscretization()
self.guider = VanillaCFG(cfg_scale)
self.guider = guider_cls(cfg_scale)
def sampler_step(self, old_denoised:Optional[Tensor], prev_sigma:Optional[Tensor], sigma:Tensor, next_sigma:Tensor, denoiser, x:Tensor, c:Dict, uc:Dict) -> Tuple[Tensor,Tensor]:
denoised = denoiser(*self.guider.prepare_inputs(x, sigma, c, uc))
denoised = self.guider(denoised)
denoised = self.guider(denoiser, x, sigma, c, uc)
t, t_next = sigma.log().neg(), next_sigma.log().neg()
h = t_next - t
@@ -329,7 +338,7 @@ class DPMPP2MSampler:
return x, denoised
def __call__(self, denoiser, x:Tensor, c:Dict, uc:Dict, num_steps:int, timing=False) -> Tensor:
sigmas = self.discretization(num_steps)
sigmas = self.discretization(num_steps).to(x.device)
x *= Tensor.sqrt(1.0 + sigmas[0] ** 2.0)
num_sigmas = len(sigmas)

Binary file not shown.

View File

@@ -7,7 +7,7 @@ import math
# https://github.com/Stability-AI/generative-models/blob/fbdc58cab9f4ee2be7a5e1f2e2787ecd9311942f/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half) / half).exp()
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)

View File

@@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto):
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
if len(inp.raw_data) > 0:
data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
return Tensor(data, requires_grad=False).reshape(tuple(inp.dims))
return Tensor(data.reshape(tuple(inp.dims)), requires_grad=False)
return Tensor(None, requires_grad=False)
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:

View File

@@ -1,7 +1,7 @@
from typing import List
from extra.models.resnet import ResNet50
from tinygrad import Tensor, Device
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen, _CURRENT_KERNEL
from tinygrad.ops import UOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.lowerer import ast_to_uop
@@ -43,6 +43,7 @@ if __name__ == "__main__":
rewritten_uops = []
for i,(k,u) in enumerate(zip(kernels, uops)):
with Timing(f"rewrite {i:2d} {k.name}{' '*(50-ansilen(k.name))}", enabled=getenv("VERBOSE", 0)):
if getenv("VIZ"): _CURRENT_KERNEL.set(k.name)
rewritten_uops.append(full_graph_rewrite(u, k.opts))
uops = rewritten_uops
if getenv("LINEARIZE", 1):

View File

@@ -17,7 +17,7 @@ def process_replay(outs:List[LazyBuffer], graph:DefaultDict[LBScheduleItem, List
if not os.path.isfile(fp):
shutil.copyfile(fetch(f"https://raw.githubusercontent.com/tinygrad/tinygrad/{ref_schedule}/tinygrad/engine/schedule.py", allow_caching=False), fp)
# create the reference graph
ref_graph, ref_in_degree = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
ref_graph, ref_in_degree, _ = importlib.import_module("test.external.process_replay.master_schedule")._graph_schedule(outs)
# compare
diff_schedule([(ref_graph, ref_in_degree), (graph, in_degree)])
@@ -26,7 +26,7 @@ def diff_schedule(s:List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]]
for _,in_degree in s:
for lsi in in_degree:
for buf in lsi.outputs:
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
si_for_buf[buf].append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), tuple(lsi.metadata)))
changed = 0
seen_diffs: Set[bytes] = set()
for buf,si in si_for_buf.items():

View File

@@ -1,5 +1,7 @@
import difflib, logging
from tinygrad.helpers import colored, getenv
from dataclasses import dataclass
import difflib, logging, traceback, subprocess
from typing import Dict, Optional
from tinygrad.helpers import ContextVar, colored, getenv
def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
if not logging.getLogger().hasHandlers(): logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -10,3 +12,17 @@ def print_diff(s0, s1, unified=getenv("UNIFIED_DIFF",1)):
import ocdiff
diff = ocdiff.console_diff(str(s0), str(s1))
logging.info(diff)
@dataclass(frozen=True)
class ProcessReplayContext:
ctx_vars: Dict[str, int]
loc: str = ""
head_sha: str = ""
run_id: Optional[int] = None
def get_process_replay_ctx() -> ProcessReplayContext:
stack = filter(lambda x: "tinygrad" in x.filename and not any(n in x.filename for n in ["engine/schedule.py", "engine/realize.py", \
"codegen/kernel.py", "unittest"]), traceback.extract_stack()[:-1])
loc = "\n".join(traceback.format_list(stack))
try: head_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode()
except Exception: head_sha = ""
return ProcessReplayContext({k:v.value for k,v in ContextVar._cache.items()}, loc, head_sha, getenv("GITHUB_RUN_ID") or None)

View File

@@ -12,8 +12,7 @@ from test.external.process_replay.helpers import print_diff
PAGE_SIZE = 100
REF = os.getenv("GITHUB_REF_NAME", "")
MAX_DIFF_PCT = getenv("PROCESS_REPLAY_MAX_DIFF_PCT", 20)
RUN_ID = os.getenv("GITHUB_RUN_ID", "HEAD")
TABLE_NAME = f"process_replay_{RUN_ID}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}"
TABLE_NAME = f"process_replay_{VERSION}"
os.environ["RUN_PROCESS_REPLAY"] = "0"
early_stop = multiprocessing.Event()
logging.basicConfig(level=logging.INFO, format="%(message)s")
@@ -21,6 +20,7 @@ logging.basicConfig(level=logging.INFO, format="%(message)s")
# user config
ASSERT_DIFF = getenv("ASSERT_PROCESS_REPLAY", int((k:="[run_process_replay]") in os.getenv("COMMIT_MESSAGE", k) or k in os.getenv("PR_TITLE", k)))
SKIP_PROCESS_REPLAY = (k:="[skip_process_replay]") in os.getenv("COMMIT_MESSAGE", "") or k in os.getenv("PR_TITLE", "")
COMPARE_SCHEDULE = getenv("COMPARE_SCHEDULE", 1)
if REF == "master": SKIP_PROCESS_REPLAY = True
# *** differs
@@ -43,7 +43,7 @@ def diff_kernel(offset:int) -> bool:
if early_stop.is_set(): return True
conn = db_connection()
cur = conn.cursor()
cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
cur.execute(f"SELECT val FROM 'kernel_{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset))
changed = 0
for row in cur.fetchall():
# try unpickle
@@ -54,7 +54,7 @@ def diff_kernel(offset:int) -> bool:
continue
# try linearize
try:
with Context(**{k:v for k,v in ctx.items() if k in ContextVar._cache and k != "DEBUG"}):
with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
# NOTE: replay with the captured renderer, not the one in master
@@ -72,6 +72,7 @@ def diff_kernel(offset:int) -> bool:
logging.info("PROCESS REPLAY DETECTED CHANGE")
logging.info(ast)
logging.info(applied_opts)
logging.info(ctx.loc)
print_diff(good_src, compare_src)
if ASSERT_DIFF: return True
if changed > MAX_DIFF_PCT:
@@ -112,9 +113,9 @@ def process_replay_schedule() -> None:
def process_replay_kernel() -> None:
conn = db_connection()
cur = conn.cursor()
try: row_count = cur.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
try: row_count = cur.execute(f"select count(*) from 'kernel_{TABLE_NAME}'").fetchone()[0]
except sqlite3.OperationalError:
logging.warning(f"{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
logging.warning(f"kernel_{TABLE_NAME} isn't accessible in master, did DB_VERSION change?")
return None
conn.commit()
cur.close()
@@ -127,11 +128,12 @@ if __name__ == "__main__":
logging.info("skipping process replay.")
exit(0)
logging.info("***** schedule diff")
try: process_replay_schedule()
except Exception as e:
if ASSERT_DIFF: raise e
logging.error(f"schedule diff err {e}")
if COMPARE_SCHEDULE:
logging.info("***** schedule diff")
try: process_replay_schedule()
except Exception as e:
if ASSERT_DIFF: raise e
logging.error(f"schedule diff err {e}")
logging.info("***** kernel diff")
try: process_replay_kernel()

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
from tinygrad.helpers import db_connection, VERSION, getenv, os
from tinygrad.helpers import db_connection, VERSION, os
cur = db_connection()
cur.execute(f"drop table if exists process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}_{VERSION}")
cur.execute(f"drop table if exists schedule_diff_{VERSION}")
cur.execute(f"drop table if exists kernel_process_replay_{VERSION}")
cur.execute(f"drop table if exists schedule_process_replay_{VERSION}")
if os.path.exists(fp:=__file__.replace("reset", "master_schedule")):
os.system(f"rm -rf {fp}")

View File

@@ -1,22 +0,0 @@
# restore a specific benchmark process replay
import pickle, os
from tinygrad.device import Device
from tinygrad.helpers import db_connection, VERSION, tqdm
cur = db_connection()
RUN_ID = os.environ["GITHUB_RUN_ID"]
ATTEMPT = os.environ["GITHUB_RUN_ATTEMPT"]
TABLE_NAME = f"process_replay_{RUN_ID}_{ATTEMPT}_{VERSION}"
PAGE_SIZE = 100
row_cnt = cur.execute(f"select count(*) from {TABLE_NAME}").fetchone()[0]
for offset in tqdm(range(0, row_cnt, PAGE_SIZE)):
rows = cur.execute(f"SELECT val FROM '{TABLE_NAME}' LIMIT ? OFFSET ?", (PAGE_SIZE, offset)).fetchall()
for row in rows:
ast, opts, applied_opts, name, compare_src, ctx = pickle.loads(row[0])
try: Device[opts.device].compiler.compile(compare_src)
except Exception:
print("FAILED TO COMPILE")
print(ast)
print(applied_opts)
print(compare_src)
continue

View File

@@ -1,4 +1,6 @@
import unittest
import contextlib, sqlite3
from test.external.process_replay.helpers import ProcessReplayContext
from test.external.process_replay.process_replay import TABLE_NAME, diff_kernel
from tinygrad.codegen.kernel import Kernel
@@ -8,16 +10,17 @@ from tinygrad.renderer.cstyle import ClangRenderer
from tinygrad.tensor import Tensor
def helper_append_replay(ast:UOp, name:str, src:str) -> int:
diskcache_put(TABLE_NAME.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, {}))
name = f"kernel_{TABLE_NAME}"
diskcache_put(name.replace(f"_{VERSION}", ""), "test_1", (ast, ClangRenderer(), [], to_function_name(name), src, ProcessReplayContext({})))
conn = db_connection()
row_count = conn.execute(f"select count(*) from '{TABLE_NAME}'").fetchone()[0]
row_count = conn.execute(f"select count(*) from '{name}'").fetchone()[0]
return row_count
class TestProcessReplay(unittest.TestCase):
def tearDown(self):
conn = db_connection()
cur = conn.cursor()
cur.execute(f"DELETE FROM '{TABLE_NAME}' WHERE key LIKE 'test_%'")
with contextlib.suppress(sqlite3.OperationalError): cur.execute(f"DELETE FROM 'kernel_{TABLE_NAME}' WHERE key LIKE 'test_%'")
conn.commit()
cur.close()

View File

@@ -65,7 +65,7 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None, st_src:Optional[Tuple[UOp]]=None) -> UOp:
if st_src is None:
st_src = (st.to_uop() if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)
return UOp(UOps.CONST, dtype, st_src, dtypes.as_const(val, dtype))
return UOp(UOps.VALID, dtypes.bool, st_src).where(UOp.const(dtype, val), UOp.const(dtype, 0))
T = TypeVar("T")
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:

View File

@@ -876,7 +876,7 @@ class TestLinearizer(unittest.TestCase):
lin = helper_linearizer_opt(out, wanna_output=[(a.numpy()+b.numpy()[0]).sum()+b.numpy()])[0]
ranges = [i for i,u in enumerate(lin.uops) if u.op is UOps.RANGE]
# LOAD -> RANGE -> LOAD -> ASSIGN
assert lin.uops[ranges[0]-2].op is UOps.LOAD
assert len([x for x in lin.uops[:ranges[0]] if x.op is UOps.LOAD]) == 1
def test_range_outer_op_before_phi_nested_range(self):
a = Tensor.randn(2, ).realize()
@@ -1715,7 +1715,7 @@ class TestHandCodedOpts(unittest.TestCase):
k = Kernel(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
if len(k.bufs) < 22: continue # not a tile transform kernel (there's a permute kernel at the end)
upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
assert len(upcasts) == 3 # 3 transformation matrices
assert len(wino_schedule) <= 4 # 4 kernels

View File

@@ -572,13 +572,13 @@ class TestMultiTensor(unittest.TestCase):
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD and ast.src[2].src[0]
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 1
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 1
t = 2 * t
for si in t.schedule():
ast = si.ast.src[0]
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.MUL
assert ast.src[2].src[0].op is UOps.CONST and ast.src[2].src[0].arg == 2
assert ast.src[2].src[0].src[1].op is UOps.CONST and ast.src[2].src[0].src[1].arg == 2
assert ast.src[2].src[1].op is UOps.LOAD
t = t + t.full_like(3)
for si in t.schedule():
@@ -586,7 +586,7 @@ class TestMultiTensor(unittest.TestCase):
assert ast.op is UOps.STORE
assert ast.src[2].arg is BinaryOps.ADD
assert ast.src[2].src[0].op is UOps.LOAD
assert ast.src[2].src[1].op is UOps.CONST and ast.src[2].src[1].arg == 3
assert ast.src[2].src[1].src[1].op is UOps.CONST and ast.src[2].src[1].src[1].arg == 3
def test_shard_memory(self):
devices = (d0, d1, d2, d3)

View File

@@ -889,6 +889,8 @@ class TestOps(unittest.TestCase):
def test_sum_simple(self):
helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]])
# NOTE: simple test for locals
# FORWARD_ONLY=1 DEBUG=4 python3 test/test_ops.py TestOps.test_sum_full
def test_sum_full(self):
helper_test_op([(16384)], lambda x: x.sum())
def test_sum_relu(self):

View File

@@ -2,7 +2,12 @@ import unittest, pickle
import numpy as np
from test.helpers import assert_equiv_uops
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import PtrDType, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.ops import BinaryOps, TernaryOps, UOp, UOps, UnaryOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
class TestPickle(unittest.TestCase):
def test_pickle_realized_tensor(self):
@@ -66,6 +71,32 @@ class TestPickle(unittest.TestCase):
sched_pk = pickle.loads(pk)
assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)
def test_pickle_define_var(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
x2:=UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(Variable('i', 1, 10), 3), strides=(3, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=(
UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12:=UOp(UOps.VALID, dtypes.bool, arg=None, src=(
x2,)),
UOp.define_var("i", dtypes.int, 1, 10),
x14:=UOp(UOps.CONST, dtypes.int, arg=0, src=()),)),
UOp(UOps.ALU, dtypes.int, arg=TernaryOps.WHERE, src=(
x12,
UOp(UOps.CONST, dtypes.int, arg=3, src=()),
x14,)),)),)),)),)),)),))
p = Kernel(ast).to_program(name_override="test")
ps = Kernel(pickle.loads(pickle.dumps(ast))).to_program(name_override="test")
self.assertEqual(ps.src, p.src)
class TestPickleJIT(unittest.TestCase):
@classmethod
def setUpClass(cls):

View File

@@ -90,12 +90,10 @@ class TestRandomness(unittest.TestCase):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
assert x.dtype == dtypes.bfloat16
# TODO: fix this property for bfloat16 random
# x = x.numpy()
# ones = np.take(x, np.where(x == 1))
# zeros = np.take(x, np.where(x == 0))
# self.assertTrue(ones.size == 0)
# self.assertTrue(zeros.size > 0)
if THREEFRY.value:
nx = x.numpy()
assert nx[nx == 1].size == 0
assert nx[nx == 0].size > 0
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
def test_randn(self):

View File

@@ -162,10 +162,10 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
a = UOp.define_var('a', dtypes.int, 0, 1)
b = UOp.define_var('b', dtypes.int, 0, 1)
c = UOp.define_var('c', dtypes.int, 0, 1)
d = UOp.define_var('d', dtypes.int, 0, 1)
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, constant_folder)
@@ -186,7 +186,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
v = UOp.define_var('tmp', dtypes.int, 0, 1)
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -280,7 +280,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@@ -289,7 +289,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@@ -356,7 +356,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
v = UOp.define_var("tmp", dtypes.int, 0, 1)
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
@@ -385,7 +385,7 @@ class TestUOpGraph(unittest.TestCase):
def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int), (), ("temp", 1))
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
@@ -610,8 +610,8 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate = UOp.define_var("g1", dtypes.bool, False, True)
gate2 = UOp.define_var("g2", dtypes.bool, False, True)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink)
@@ -626,7 +626,7 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate = UOp.define_var("g1", dtypes.bool, False, True)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
@@ -637,8 +637,8 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate = UOp.define_var("g1", dtypes.bool, False, True)
gate2 = UOp.define_var("g2", dtypes.bool, False, True)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink)
@@ -650,7 +650,7 @@ def gate_rewrite(sink): return graph_rewrite(sink, constant_folder + expander +
class TestIFUOps(unittest.TestCase):
def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 4))
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2))
@@ -669,7 +669,7 @@ class TestIFUOps(unittest.TestCase):
def test_expand_ifs_one_gate(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float), (), ("smem", 16))
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2))

View File

@@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), ('smem', 16))
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16))
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
@@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self):
uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32), (), ('smem', 16))
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16))
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))

View File

@@ -19,7 +19,9 @@ def render(image_shape, valid:UOp, idx:UOp) -> str:
return fxn.split("float4 val0 = ")[1].split(";")[0]
def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
def Variable(expr, nmin, nmax): return UOp(UOps.DEFINE_VAR, dtypes.int, (), (expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
def Range(n, nmax):
return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
class TestHelpers(unittest.TestCase):
def test_is_increasing(self):
@@ -97,46 +99,69 @@ class TestValidSimplification(unittest.TestCase):
# empty
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx))
@unittest.expectedFailure # TODO: FIXME
def test_openpilot_conv1(self):
# first conv in openpilot
# kernel in tinygrad ae5d1407ee844a97a52ad3756835d38e7e2b9e1b https://gist.github.com/chenyuxyz/39c2d4e9a076b46731c67d345ff066b6
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 5)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# ridx0 = Variable("ridx0", 0, 5)
# ridx1 = Variable("ridx1", 0, 2)
# ridx2 = Variable("ridx2", 0, 2)
ridx0 = Range(0, 6)
ridx1 = Range(1, 3)
ridx2 = Range(2, 3)
alu1 = ((idx2*2)+ridx1)
alu4 = ((idx1*48)+(ridx2*6)+ridx0)
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
valid = (((idx2*2)+(ridx1)).lt(1).ne(True))&(((idx1*8)+(ridx2)).lt(1).ne(True))
shape = (128, 1536, 4)
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
# (((((idx2*(-2))+(ridx1*(-1)))<0)&(((idx1*(-8))+(ridx2*(-1)))<0))?read_imagef(data0, smp,
# (int2)((((idx1*48)+(ridx2*6)+ridx0+1530)%1536),((idx2*2)+ridx1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))):(float4)(0.0f,0.0f,0.0f,0.0f))
self.assertEqual(render(shape, valid, idx),
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-6)),((idx2*2)+ridx1+(-1))))")
"read_imagef(data0, smp, (int2)(((idx1*48)+(ridx2*6)+ridx0+(-6)),((idx2*2)+ridx1+(-1))))")
@unittest.expectedFailure # TODO: FIXME
def test_openpilot_conv2(self):
# conv in test/external/external_test_valid_remove.py
idx1 = Special("idx1", 32)
idx2 = Special("idx2", 64)
ridx0 = Variable("ridx0", 0, 2)
ridx1 = Variable("ridx1", 0, 2)
ridx2 = Variable("ridx2", 0, 2)
# ridx0 = Variable("ridx0", 0, 2)
# ridx1 = Variable("ridx1", 0, 2)
# ridx2 = Variable("ridx2", 0, 2)
ridx0 = Range(0, 3)
ridx1 = Range(1, 3)
ridx2 = Range(2, 3)
alu1 = ((idx2*2)+ridx1)
alu3 = ((idx1*24)+(ridx2*3)+ridx0)
valid = (((idx2*(-2))+(ridx1*(-1))).lt(0))&(((idx1*(-8))+(ridx2*(-1))).lt(0))
valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True))
shape = (128, 768, 4)
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
self.assertEqual(render(shape, valid, idx),
"read_imagef(data1, smp, (int2)((ridx0+(idx1*48)+(ridx2*6)+(-3)),((idx2*2)+ridx1+(-1))))")
"read_imagef(data0, smp, (int2)(((idx1*24)+(ridx2*3)+ridx0+(-3)),((idx2*2)+ridx1+(-1))))")
def test_openpilot_conv3(self):
  # in openpilot 0.9.7
  # checks that the rendered image load keeps the valid guard but reuses alu0 for the shared subexpression
  idx0 = Special("idx0", 64)
  idx1 = Special("idx1", 2)
  idx2 = Special("idx2", 4)
  ridx0 = Range(0, 7)
  ridx1 = Range(1, 7)
  alu2 = ((idx2*2)+ridx0)
  alu4 = ((idx1*8)+ridx1)
  alu6 = ((idx1*512)+(ridx1*64)+idx0)
  valid = alu2.lt(11)&(alu4.lt(3).ne(True))
  shape = (8, 1024, 4)
  idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)/8)+1)/2)+(-4))))
  # TODO: simplify idx
  # alu0 = ((idx2*2)+ridx0)
  self.assertEqual(render(shape, valid, idx),
    "(((alu0<11)&((((idx1*8)+ridx1)<3)!=1))?read_imagef(data0, smp, (int2)((((idx1*512)+(ridx1*64)+idx0+832)%1024),(alu0+(-4)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
def test_simplify1(self):
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
@@ -188,5 +213,20 @@ class TestValidSimplification(unittest.TestCase):
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))),
"((idx0<256)?read_imagef(data0, smp, (int2)((((((idx0//8)%32)//4)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
def test_simplify5(self):
  # openpilot 0.9.7, chunk replacement to simplify
  # the valid `alu3 < 640` lets the %768 / //640 chunks in idx collapse to (idx1//3)
  shape = (10, 384, 4)
  idx0 = Special("idx0", 16)
  idx1 = Special("idx1", 24)
  alu0 = idx0*4
  alu1 = (idx1*256)+alu0
  alu2 = idx1//3
  alu3 = ((alu1+1)%768)
  idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
  valid = alu3.lt(640)
  self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
    "((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
if __name__ == '__main__':
unittest.main()

View File

@@ -11,6 +11,20 @@ class TestPatternMatcher(unittest.TestCase):
self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None)
def test_match_sz_0(self):
  # a UPat with src=() must only match a CONST with zero sources;
  # fxn returns a CONST that HAS a source, so the second rewrite must not match
  match_cnt = 0
  def fxn(x):
    nonlocal match_cnt
    match_cnt += 1
    assert len(x.src) == 0
    # return a CONST with a src so it can no longer match the src=() pattern
    return UOp(UOps.CONST, src=(UOp(UOps.CONST),))
  matcher = PatternMatcher([(UPat(UOps.CONST, src=(), name="x"), fxn)])
  c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
  # second rewrite shouldn't match anything
  c1 = matcher.rewrite(c1)
  c1 = matcher.rewrite(c1)
  self.assertEqual(match_cnt, 1)
def test_uop(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x"), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)

View File

@@ -118,7 +118,9 @@ class TestRealDoesntSimplify(unittest.TestCase):
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1))
class TestRealStrides(unittest.TestCase):
@unittest.expectedFailure
def test_1(self):
# TODO: find the correct rewrite rule to fix this
self.st = ShapeTracker((
View.create((2048,), (1,), 0, ((0, 512),)),
View.create((16, 32, 4), (128, 4, 1), 0, None)))
@@ -489,6 +491,14 @@ class TestComplexShapeTracker(unittest.TestCase):
print(self.st.views)
assert self.st.contiguous
class TestShapeTrackerEquality(unittest.TestCase):
  """ShapeTracker value-equality: equal views must compare equal."""
  def test_simple_equals(self):
    # two trackers built from the same shape are equal
    self.assertEqual(ShapeTracker.from_shape((10,10)), ShapeTracker.from_shape((10,10)))
  def test_other_equals(self):
    # BUGFIX: `views` must be a tuple of View -- the original passed (View(...)),
    # which is just a parenthesized View, not a 1-tuple; add the trailing comma
    st1 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),))
    st2 = ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),))
    self.assertEqual(st1, st2)
class TestSingleShapeTracker(unittest.TestCase):
def setUp(self):
self.st = CheckingShapeTracker((7,4))

View File

@@ -27,9 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
# helper: wrap a python int as a constant int UOp
def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax):
  """Create an int DEFINE_VAR UOp named `expr` bounded by [nmin, nmax].

  `nmax` may be a plain int, or a const UOp whose `.arg` holds the int bound.
  """
  # NOTE(review): the merge left both the old hand-built DEFINE_VAR body and the new
  # UOp.define_var call in place (the old return made the new line unreachable);
  # keep only the define_var form, matching the Variable helper used elsewhere in the diff
  return UOp.define_var(expr, dtypes.int, nmin, nmax if isinstance(nmax, int) else nmax.arg)
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@@ -450,6 +448,19 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((idx//4).lt(3), 0, 1, "(idx<12)")
self.helper_test_variable((idx//-4).lt(-3), 0, 1, "((idx//(-4))<(-3))")
def test_simplex_lt(self):
  # simplex simplification: (a0*x0 + a1*x1 + ...) >= 1 reduces to (x0 + x1 + ...) >= 1
  # only when every coefficient ai > 0 and every variable xi >= 0
  a = Variable("a", 0, 3)
  b = Variable("b", 0, 3)
  c = Variable("c", 0, 3)
  d = Variable("d", -3, 3)
  self.helper_test_variable((a).lt(1).ne(True), 0, 1, "((a<1)!=1)")
  self.helper_test_variable((a+b).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
  self.helper_test_variable((a*3+b*4).lt(1).ne(True), 0, 1, "(((a+b)<1)!=1)")
  self.helper_test_variable((a*(-3)+b*4).lt(1).ne(True), 0, 1, "((((a*(-3))+(b*4))<1)!=1)") # negative coeff, should not be simplified
  self.helper_test_variable((a*3+d*4).lt(1).ne(True), 0, 1, "((((a*3)+(d*4))<1)!=1)") # var can be negative, should not be simplified
  self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
  self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):

View File

@@ -13,7 +13,7 @@ from tinygrad.shape.view import View
class InvalidASTException(Exception): pass
def helper_test_verify_ast(*stores:UOp) -> Kernel:
sink = UOp(UOps.SINK, None, stores)
sink = UOp(UOps.SINK, dtypes.void, stores)
if DEBUG >= 3:
for op in stores: print(op)
try: verify_ast(sink)
@@ -50,7 +50,7 @@ class TestVerifyAST(unittest.TestCase):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_shrink_ok(self):
@@ -79,7 +79,7 @@ class TestVerifyAST(unittest.TestCase):
uop_sts = verify_ast(a.schedule()[-1].ast)
store_st = [st for u,st in uop_sts.items() if u.op is UOps.STORE][0]
self.assertEqual(store_st, ShapeTracker.from_shape((4, 4)))
const_st = [st for u,st in uop_sts.items() if u.op is UOps.CONST][0]
const_st = [st for u,st in uop_sts.items() if u.op is UOps.VALID][0]
self.assertEqual(const_st, ShapeTracker.from_shape((1, 1)).expand((4, 4)))
if __name__ == '__main__':

View File

@@ -4,12 +4,13 @@ from dataclasses import dataclass
from collections import defaultdict
from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify, \
graph_rewrite, PatternMatcher
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.helpers import _CURRENT_KERNEL, all_same, colored, ansilen, dedup, getenv, prod, DEBUG, TC_OPT, USE_TC, AMX, round_up, all_int, \
get_contraction, to_function_name, diskcache_put, ContextVar
get_contraction, to_function_name, diskcache_put
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import strides_for_shape
@@ -640,7 +641,7 @@ class Kernel:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
if op.op is UOps.CONST: return op.replace(src=(st_uop,))
if op.op is UOps.VALID: return op.replace(src=(st_uop,))
if op.op is UOps.STORE: return op.replace(src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
return op.replace(src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is UOps.REDUCE_AXIS:
@@ -702,7 +703,7 @@ class Kernel:
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else:
@@ -711,7 +712,14 @@ class Kernel:
# MUL/SUM instead of WMMA
ret = UOp(UOps.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1]))
else:
ret = UOp(UOps.WMMA, tc.dtype_out, (fixup_ast(rsrc.src[0], fix_st1), fixup_ast(rsrc.src[1], fix_st2)), wmma_arg)
# real WMMA, use CONTRACT/EXPAND to get the vectorization right
wmma_upcast_axes = wmma_arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in wmma_upcast_axes]
wmma = UOp(UOps.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=rsrc.src[0].dtype.vec(wmma_sz[0]), src=(fixup_ast(rsrc.src[0], fix_st1),), arg=wmma_upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=rsrc.src[1].dtype.vec(wmma_sz[1]), src=(fixup_ast(rsrc.src[1], fix_st2),), arg=wmma_upcast_axes[1]),
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
ret = UOp(UOps.EXPAND, tc.dtype_out, (wmma,), arg=wmma_upcast_axes[2])
new_reduce_axes = tuple(i for i in axis if i-self.first_upcast not in reduce_axes)
return op.replace(src=(ret,), arg=(alu_op, new_reduce_axes)) if new_reduce_axes else ret
if self.group_for_reduces:
@@ -725,7 +733,7 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce
@@ -735,7 +743,8 @@ class Kernel:
elif op.op is UOps.SINK:
arg = KernelInfo(self.local_dims, self.upcasted, self.dont_use_locals)
return op.replace(src=tuple(fixup_ast(x, apply_to_st) for x in op.src), arg=arg)
return fixup_ast(self.ast)
# NOTE: rewrite with an empty PatternMatcher to dedup UOps
return graph_rewrite(fixup_ast(self.ast), PatternMatcher([]))
# **** this is the lowerer ****
@@ -763,8 +772,8 @@ class Kernel:
src = self.opts.render(name:=to_function_name(ansiname:=(name_override if name_override is not None else self.name)), self.uops)
if getenv("RUN_PROCESS_REPLAY"):
table_name = f"process_replay_{getenv('GITHUB_RUN_ID', 'HEAD')}_{getenv('GITHUB_RUN_ATTEMPT')}"
diskcache_put(table_name, str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, {k:v.value for k,v in ContextVar._cache.items()}))
from test.external.process_replay.helpers import get_process_replay_ctx
diskcache_put("kernel_process_replay", str(id(self)), (self.ast, self.opts, self.applied_opts, name, src, get_process_replay_ctx()))
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
@@ -778,30 +787,28 @@ class Kernel:
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if not uop.has_st or uop in sts: return
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
# restore globals from the two stage reduce
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
sts[uop] = sts[local_reduce]
return
for x in src: _assert_valid_uop(x, st, sts)
for x in uop.src: _assert_valid_uop(x, st, sts)
# only reduceuop is allowed to change shape, limited to turning n to 1
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
elif op is UOps.SWIZZLE: st = arg
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
# movementops are pushed to SHAPETRACKER and SWIZZLE
elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
# everything else inherits shape
else:
assert op in {UOps.SHAPETRACKER, UOps.SWIZZLE, UOps.ALU, UOps.CAST, UOps.BITCAST, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = arg if op is UOps.SHAPETRACKER else sts[src[uop.st_loc if op in BUFFER_UOPS else 0]]
for x in src:
if x.has_st and sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op} {sts[x].shape} != {st.shape}")
raise AssertionError(f"found implicit expand {x.op} {sts[x].shape} != {op} {st.shape} {prod(sts[x].shape)} != {prod(st.shape)}")
assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, UOps.ASSIGN, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
if not all_same(shapes:=[x.shape for x in src_sts]):
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
raise AssertionError(f"found implicit expand {sizes}")
sts[uop] = st
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
sts: Dict[UOp, ShapeTracker] = {}
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]

View File

@@ -1,13 +1,17 @@
# the job of the lowerer is to do indexing
from __future__ import annotations
import functools
from typing import List, Tuple, cast, Optional, Dict
from dataclasses import dataclass
from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
from tinygrad.shape.symbolic import sint
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
# ***** indexing *****
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
@@ -34,99 +38,91 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
idx //= dims[c]
return ret[::-1] if reverse else ret
class IndependentLowerer:
def lower(self, ast:UOp, opts:Renderer) -> UOp:
self.output_count = len(ast.src)
@dataclass(frozen=True)
class IndexContext:
  # immutable context threaded through the lowerer's pattern-match rewrites
  idxs: List[UOp]   # one index UOp per shape dimension (global/local dims, RANGE loops, EXPAND upcasts)
  ridxs: List[UOp]  # same as idxs, but grouped-reduce dims replaced by late RANGE loops (see get_index)
ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
first_output_st: ShapeTracker = ast.src[0].st_arg
# if there's no reduce, this is first_upcasted
first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
global_dims = first_reduce-ki.local_dims
def get_index(ast:UOp, opts:Renderer) -> IndexContext:
ki = ast.arg if isinstance(ast.arg, KernelInfo) else KernelInfo()
# NOTE: assumes the shape is <global dims> <local dims> <group_for_reduces> <reduces> <upcasts/unrolls>
full_shape = ast.full_shape
first_upcasted = len(full_shape)-ki.upcasted
first_output_st: ShapeTracker = ast.src[0].st_arg
# if there's no reduce, this is first_upcasted
first_reduce = [x!=y for x,y in zip(first_output_st.shape[:first_upcasted]+(0,), full_shape[:first_upcasted]+(1,))].index(True)
local_loads = [x for x in ast.parents if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL]
# NOTE: sum up the reduced axes looking across all local loads, yields the number of grouped reduces
group_for_reduces = sum([any(j!=y for j in x) for x,y in zip(
[[l.st_arg.shape[i] for l in local_loads] for i in range(first_reduce,first_upcasted)],
first_output_st.shape[first_reduce:first_upcasted])]) if local_loads else 0
global_dims = first_reduce-ki.local_dims
if opts.has_local:
if ki.dont_use_locals:
assert ki.local_dims == 0, "can't use locals if there's no local dims"
self.idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
else:
# define indexes for GPU-like execution
self.idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
if opts.has_local:
if ki.dont_use_locals:
assert ki.local_dims == 0, "can't use locals if there's no local dims"
idxs = get_grouped_dims("idx", full_shape[:global_dims], opts.global_max, reverse=True)
else:
# all loops are RANGES
self.idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]
# define indexes for GPU-like execution
idxs = get_grouped_dims("gidx", full_shape[:global_dims], opts.global_max, reverse=True) + \
get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
else:
# all loops are RANGES
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
for i,g in enumerate(full_shape[:first_reduce])]
# reduce loops
self.idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# reduce loops
idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
self.idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
# upcast loops
for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
assert isinstance(g, int), "needs to be int to upcast/unroll"
idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
# late indexes (group for reduce)
self.ridxs = self.idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
# late indexes (group for reduce)
ridxs = idxs[:]
for a in range(first_reduce, first_reduce+group_for_reduces):
ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
self.uop_cache: Dict[UOp, UOp] = {}
return self.to_uop(ast)
return IndexContext(idxs, ridxs)
def to_uop(self, x:UOp) -> UOp:
if uop:=self.uop_cache.get(x, None): return uop
ret = self._to_uop(x)
self.uop_cache[x] = ret
return ret
# ***** lowering (given index) *****
def _to_uop(self, x:UOp) -> UOp:
if x.op in BUFFER_UOPS:
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
if x.op is UOps.CONST: return valid.where(x.const_like(x.arg), x.const_like(0))
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, dtypes.void, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(self.idxs, self.ridxs):
if oidx != ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
return UOp(UOps.STORE, dtypes.void, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
def lower_reduce_axis(ctx: IndexContext, x: UOp):
  # lower a REDUCE_AXIS: expanded (upcast) axes are CONTRACTed and folded with the alu op at
  # compile time; real loop axes become a runtime REDUCE over their RANGE uops
  # NOTE: always using ridxs is fine here
  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
  alu_op: BinaryOps = x.arg[0]
  ret = x.src[0]
  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
    # gather the expanded values into one vector, then fold its lanes with the alu op
    ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
  # if no loop axes remain, the reduce was fully resolved by the contraction above
  return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
in_uops = tuple(self.to_uop(y) for y in x.src)
if x.op is UOps.WMMA:
upcast_axes = x.arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
if x.op is UOps.REDUCE_AXIS:
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
alu_op: BinaryOps = x.arg[0]
ret = in_uops[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
return x if x.src == in_uops else UOp(x.op, x.dtype, in_uops, x.arg)
def lower_load_store(ctx: IndexContext, x: UOp):
  # lower a LOAD/STORE whose second src is a SHAPETRACKER into an indexed LOAD/STORE
  # local loads use ridxs so each thread reads with the late (grouped) reduce ranges
  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
  # TODO: check has_valid in UPat, not here
  has_valid = valid.op is not UOps.CONST or valid.arg is not True
  buf = x.src[0]
  if x.op is UOps.LOAD:
    # loads from local memory must wait on the store that filled it
    barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
    return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
  store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
      x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
  # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
  if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
  if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
    # guard the store so only the thread at index 0 of each grouped-reduce dim writes
    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
      if oidx != ridx: valid = valid * oidx.eq(0)
    has_valid = valid.op is not UOps.CONST or valid.arg is not True
  return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
# the lowering rules: rewrite shapetracker-based ops into indexed uops using the IndexContext
pm_lowerer = PatternMatcher([
  (UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
  # VALID(SHAPETRACKER) -> the boolean valid expression for the current idxs
  (UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
  # rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed
  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.SHAPETRACKER)), allow_any_len=True, name="x"), lower_load_store),
])
# entry point: compute the index context once, then graph_rewrite the whole ast with it
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))

View File

@@ -143,7 +143,6 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
if (newx:=div_folding(x,c)) is not None and newx.op is UOps.ALU and newx.arg is BinaryOps.IDIV: return newx.src[0].lt(newx.src[1])
return cast(UOp, x.divides(g)).lt(c//g) if ((g:=math.gcd(x.const_factor(), c)) > 1) else None
def fold_unrolled_divs(divs:UOp):
@@ -163,16 +162,31 @@ def fold_unrolled_divs(divs:UOp):
# ***** image load valid simplification *****
# an irreducible UOp is an atomic symbolic term: a defined variable, a launch index, or a loop index
def is_irreducible(u:UOp): return u.op in (UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE)
def canonicalize_simplex(X:UOp) -> Optional[UOp]:
  # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints.
  # returns x0 + x1 + ... in such case, or None if not
  changed, ret = False, []
  for u in _get_chain(X, BinaryOps.ADD):
    # assumed the const is the last src of MUL
    if u.op is UOps.ALU and u.arg is BinaryOps.MUL and u.src[1].op is UOps.CONST and u.src[1].arg > 0:
      changed = True
      u = u.src[0]
    # every stripped term must be a non-negative irreducible, otherwise the rewrite is unsound
    if not (is_irreducible(u) and u.vmin >= 0): return None
    ret.append(u)
  # only return the simplified sum if at least one coefficient was actually stripped
  return functools.reduce(operator.add, ret) if changed else None
def is_increasing(f:UOp):
  """Return True if f is (provably) monotonically increasing in its inputs; False if not sure."""
  # NOTE(review): the merge left both the old membership test
  # `f.op in [UOps.CONST, UOps.DEFINE_VAR, UOps.SPECIAL, UOps.RANGE]` and its replacement
  # `f.op is UOps.CONST or is_irreducible(f)` in place; they are equivalent (is_irreducible
  # covers DEFINE_VAR/SPECIAL/RANGE), so keep only the new form
  if f.op is UOps.CONST or is_irreducible(f): return True
  # a sum of increasing terms is increasing
  if f.op is UOps.ALU and f.arg is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1])
  # scaling/dividing by a non-negative constant preserves monotonicity
  if f.op is UOps.ALU and f.arg in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is UOps.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0])
  return False  # False if not sure
def replace_uop(uop:UOp, old:UOp, new:UOp):
  # replace all `old` in `uop` to `new`
  # NOTE(review): the merge left two return statements (identity-based `uop is old` and the
  # key-based replacement), the second unreachable; keep the key-based comparison, which also
  # matches structurally-equal copies rather than only the identical object
  return new if uop.key == old.key else UOp(uop.op, uop.dtype, tuple(replace_uop(s, old, new) for s in uop.src), uop.arg)
def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]:
# if it's X <= c, returns X, True, c
@@ -196,30 +210,50 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
expr, is_upper, c = parse_valid(stmt)
bounds[expr][int(is_upper)] = c
# simplify idx given that valid is True
for uop,v in bounds.items():
# some expr has lower bound > upper bound -> valid is an empty set
if v[0] is not None and v[1] is not None and v[0] > v[1]:
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
bound = uop.const_like(uop.vmin if v[0] is None else v[0]), uop.const_like(uop.vmax if v[1] is None else v[1])
new = UOp(UOps.DEFINE_VAR, uop.dtype, (), ("fake", bound[0], bound[1]))
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
if newidx.key != idx.key: idx = newidx
if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
# if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
newidxs: List[List[UOp]] = [[], []]
for variable in _get_chain(uop, BinaryOps.ADD):
new = UOp(UOps.DEFINE_VAR, variable.dtype, (), ("fake", 1, variable.vmax))
newidx = replace_uop(graph_rewrite(replace_uop(idx, variable, new), constant_folder), new, variable)
newidxs[0].append(newidx.src[0])
newidxs[1].append(newidx.src[1])
if len(newidxs[0])==1 or (len(newidxs[0]) > 1 and all_same([i.key for i in newidxs[0]])): idx = idx.replace(src=(newidxs[0][0], idx.src[1]))
if len(newidxs[1])==1 or (len(newidxs[1]) > 1 and all_same([i.key for i in newidxs[1]])): idx = idx.replace(src=(idx.src[0], newidxs[1][0]))
else:
new = UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1])
newidx = replace_uop(graph_rewrite(replace_uop(idx, uop, new), constant_folder), new, uop)
if newidx.key != idx.key: idx = newidx
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
for stmt in _get_chain(valid, BinaryOps.AND):
X, is_upper, c = parse_valid(stmt)
if is_upper:
# X <= c, check if it's out of bound when X = c+1
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c+1)), constant_folder).vmin >= b:
drop_stmt.append(stmt)
break
else:
# X >= c, check if it's negative when X = c-1
for i in idx.src:
if is_increasing(i) and graph_rewrite(replace_uop(i, X, X.const_like(c-1)), constant_folder).vmax < 0:
drop_stmt.append(stmt)
break
X, is_upper_bound, c = parse_valid(stmt)
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
if not is_upper_bound and c == 1 and X.op is UOps.ALU and X.arg is BinaryOps.ADD and \
all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(X, BinaryOps.ADD)):
testidx = functools.reduce(lambda nowidx,u: replace_uop(nowidx, u, u.const_like(0)), _get_chain(X, BinaryOps.ADD), idx)
testidx = graph_rewrite(testidx, constant_folder)
if testidx.src[0].vmax < 0 or testidx.src[1].vmax < 0:
drop_stmt.append(stmt)
continue
# if X <= c, check if it's out of bound when X = c+1
# if X >= c, check if it's out of bound when X = c-1
test_value = c + 1 if is_upper_bound else c - 1
for i,b in zip(idx.src, (buf_dtype.shape[1], buf_dtype.shape[0])):
if is_increasing(i):
rw = graph_rewrite(replace_uop(i, X, X.const_like(test_value)), constant_folder)
if rw.vmin >= b or rw.vmax < 0: drop_stmt.append(stmt)
if drop_stmt or idx.key != start_idx.key:
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s not in drop_stmt]) else None
@@ -312,6 +346,8 @@ constant_folder = PatternMatcher([
(UPat(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, name="x"), lambda x: UOp(x.op, x.dtype, x.src, BinaryOps.AND)),
# self ASSIGN is just self
(UPat(UOps.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/GEP: the expander rule allows tuple GEP creation, this is just for removal
(UPat(UOps.VECTORIZE, src=UPat(UOps.GEP, src=(UPat(name="x"),)), name="vec"),
lambda vec,x: x if x.dtype == vec.dtype and tuple(y.arg[0] for y in vec.src) == tuple(range(len(vec.src))) else None),
@@ -419,6 +455,9 @@ constant_folder = PatternMatcher([
# generic lt folding
(UPat.var("x").lt(UPat.cvar("c", vec=False)),
lambda x,c: lt_folding(x, c.arg) if 0 < c.arg and dtypes.is_int(x.dtype) and not dtypes.is_unsigned(x.dtype) else None),
# canonicalize a simplex with positive coefficients > 0
# not x < 1 -> X > 0
(UPat.var("x").lt(1).ne(True), lambda x: newx.lt(1).ne(True) if dtypes.is_int(x.dtype) and (newx:=canonicalize_simplex(x)) is not None else None),
# ** div **
# # div folding
(UPat.var("x") // UPat.cvar("c", vec=False), lambda x,c:
@@ -466,6 +505,9 @@ constant_folder = PatternMatcher([
# ** move add consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x+c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
(UPat(UOps.ALU, arg=BinaryOps.ADD, src=(UPat.var("x"), UPat.cvar("c1"))) + UPat.var("y"), lambda x,c1,y: (x+y)+c1),
# ** move mul consts to end (NOTE: this is still happening before constant folding) **
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.cvar("c1"), UPat.var("x"))), lambda c1,x: x*c1 if x.op not in (UOps.CONST, UOps.VCONST) else None),
(UPat(UOps.ALU, arg=BinaryOps.MUL, src=(UPat.var("x"), UPat.cvar("c1"))) * UPat.var("y"), lambda x,c1,y: (x*y)*c1),
])
# *** uop expander ***
@@ -706,8 +748,8 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
scope_children = {p:get_recursive_children(p, END_FOR_UOP[p.op][0]) for p in reversed(in_degree) if p.op in END_FOR_UOP}
range_phi = {r:[p for p in scope_children[r] if p.op is UOps.ASSIGN] for r in scope_children if r.op is UOps.RANGE}
queue:List[Tuple[int, UOp]] = []
def push(u:UOp):
# assign priorities
def get_priority(u:UOp):
priority = 0
# prefer ranges that depend on the least number of independent ranges
if u.op is UOps.RANGE and u.arg[1]:
@@ -717,7 +759,20 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
# prefer uops that are loop children
else:
priority -= sum([(l.arg[0]+1) + 1000*l.arg[1] for l,ss in scope_children.items() if l.op is UOps.RANGE and u in ss])
heapq.heappush(queue, (priority, u))
return priority
priorities:Dict[UOp, int] = {u:get_priority(u) for u in children}
# prevent priority inversion
@functools.lru_cache(None)
def fix_priority(u:UOp, lowest_priority):
if u.op in {UOps.CAST, UOps.BITCAST, UOps.ALU, UOps.VECTORIZE, UOps.GEP, UOps.SPECIAL, UOps.DEFINE_LOCAL, UOps.LOAD}:
priorities[u] = min(priorities[u], lowest_priority)
if u.op is UOps.LOAD: priorities[u] += 100 # load penalty (here)
for x in u.src: fix_priority(x, priorities[u])
fix_priority(sink, 0)
queue:List[Tuple[int, UOp]] = []
def push(u:UOp): heapq.heappush(queue, (priorities[u], u))
for u in children:
if in_degree[u] == 0: push(u)
@@ -726,7 +781,7 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
_uops: List[UOp] = []
while queue:
p,x = heapq.heappop(queue)
if DEBUG >= 7: print(f"{p:5d}",x)
if DEBUG >= 7: print(f"{p:5d}", x.op, x.dtype, x.arg)
if x in scope_children: scope_end[x] = x
if x.op is UOps.DEFINE_ACC:
idx = min([_uops.index(l) for l in x.src if l.op is UOps.RANGE])
@@ -745,12 +800,9 @@ def linearize_uop(sink:UOp, skip_check:bool=not __debug__) -> List[UOp]:
# sanity checks (NOTE: these can cause things to be skipped in BEAM)
if not skip_check:
bad_ops = dedup([x.op for x in _uops if x.op in {UOps.EXPAND, UOps.CONTRACT, UOps.REDUCE, UOps.REDUCE_AXIS, UOps.SHAPETRACKER}])
try:
type_verify(_uops)
assert _uops[-1].op is UOps.SINK, f"didn't end with SINK, ended with {_uops[-1]}"
assert len(bad_ops) == 0, f"bad UOps left in list: {bad_ops}"
assert not any(x.dtype == dtypes.pyint for x in _uops), "can't return UOp with pyint"
# TODO: this should be enabled, and the valid clause should be removed
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
# NOTE: for PTX you have to propogate through some the calculations to determine if it is a store to DEFINE_LOCAL

View File

@@ -108,7 +108,8 @@ class Buffer:
(">" if self.options is None else f" {self.options=}>")
def as_buffer(self, allow_zero_copy=False, force_zero_copy=False) -> memoryview:
# zero copy with as_buffer (disabled by default due to use after free)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer'): return self.allocator.as_buffer(self._buf)
if (force_zero_copy or allow_zero_copy) and hasattr(self.allocator, 'as_buffer') and (self.options is None or self.options.image is None):
return self.allocator.as_buffer(self._buf)
assert not force_zero_copy, "force zero copy was passed, but copy is required"
return self.copyout(memoryview(bytearray(self.nbytes)))
def copyin(self, mv:memoryview):

View File

@@ -25,17 +25,20 @@ class DType:
class ImageDType(DType):
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
base: DType
local: bool = False # images are never local
def scalar(self): return self.base
def vec(self, sz:int): return self.base.vec(sz)
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType):
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __init__(self, dt:DType, local=False):
self.base, self.local = dt, local
super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt)
def __repr__(self): return f"PtrDType({super().__repr__()})"
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
class dtypes:
@staticmethod

View File

@@ -181,7 +181,7 @@ class ExecItem:
lds_est = sym_infer(self.prg.lds_estimate, var_vals)
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
ptm = (colored(f"{et*1e3:9.2f}ms", "yellow") if et > 0.01 else f"{et*1e6:9.2f}us") if et is not None else ""
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
print(f"{colored(f'*** {self.prg.dname[:7]:7s} {GlobalCounters.kernel_count:4d}', 'magenta' if jit else ('green' if self.prg.first_run else None))} {self.prg.display_name+' '*(40-ansilen(self.prg.display_name))} arg {len(bufs):2d} mem {GlobalCounters.mem_used/1e9:5.2f} GB " + # noqa: E501
(str() if et is None else f"tm {ptm}/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_est/((et or 1e-20)*1e9):9.2f} GFLOPS {mem_est/((et or 1e-20)*1e9):6.1f}|{lds_est/((et or 1e-20)*1e9):<7.1f} GB/s)" + # noqa: E501
f" {[repr(m) if TRACEMETA >= 2 else str(m) for m in self.metadata] if self.metadata else ''}"))
self.prg.first_run = False

View File

@@ -110,7 +110,7 @@ reduceop_fusor = PatternMatcher([
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.SWIZZLE, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push SWIZZLE(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.ASSIGN, UOps.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
@@ -139,7 +139,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
val, var_val = val.unbind()
var_vals[val] = var_val
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
# otherwise, it's a load and we add it to the inputs
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -157,9 +157,10 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, realizes, assign_targets, cache) for x in buf.srcs)
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
if buf.op is MetaOps.CONTIGUOUS:
assert buf in outputs, f"{buf.op} must be writable"
return in_uops[0]
if buf.op is MetaOps.ASSIGN: return cache.setdefault((buf, st), UOp(UOps.ASSIGN, dtype, (in_uops[1].src[0], in_uops[0])))
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops))
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))

View File

@@ -103,19 +103,23 @@ class UOps(FastEnum):
VALID = auto()
SPECIAL = auto()
NOOP = auto()
GEP = auto()
# math ops
CAST = auto()
BITCAST = auto()
VECTORIZE = auto()
ALU = auto()
REDUCE = auto()
REDUCE_AXIS = auto()
# helper ops
GEP = auto()
VECTORIZE = auto()
CAST = auto()
BITCAST = auto()
# loads before math
LOAD = auto()
# math ops
ALU = auto()
WMMA = auto()
# memory/assignment ops
LOAD = auto()
# assignment ops
STORE = auto()
ASSIGN = auto()
@@ -128,7 +132,7 @@ class UOps(FastEnum):
ENDRANGE = auto()
ENDIF = auto()
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.CONST}
BUFFER_UOPS = {UOps.LOAD, UOps.STORE, UOps.VALID}
COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, BinaryOps.XOR, BinaryOps.AND, BinaryOps.OR}
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}
@@ -144,7 +148,7 @@ class UOp(MathTrait):
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}
@functools.cached_property
def st(self) -> Optional[ShapeTracker]:
if not self.has_st: return None
@@ -170,16 +174,14 @@ class UOp(MathTrait):
return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else repr(self.arg) if isinstance(self.arg, Variable) else self.arg
# *** uop syntactic sugar
@property
def st_loc(self) -> int: return 0 if self.op is UOps.CONST else 1
@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
ret = self.src[self.st_loc]
ret = self.src[0 if self.op is UOps.VALID else 1]
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg
def sink(self, *srcs:UOp): return UOp(UOps.SINK, dtypes.void, (self,)+srcs)
def swizzle(self, st:ShapeTracker): return UOp(UOps.SWIZZLE, self.dtype, (self,), st)
def const_like(self, b:ConstType|Variable|Tuple[ConstType]): return UOp.const(self.dtype, b)
def const_like(self, b:ConstType|Variable|Tuple[ConstType, ...]): return UOp.const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -215,25 +217,22 @@ class UOp(MathTrait):
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType):
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, UOp.const(dtype, min_val), UOp.const(dtype, max_val)))
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
def reduce(self, op, *rng): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op is UOps.SHAPETRACKER: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1].arg, x.arg[2].arg) for x in self.vars()]), key=lambda v: v.expr)
return sorted(set.union(*st_vars, [Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.vars()]), key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
@@ -259,8 +258,7 @@ class UOp(MathTrait):
@functools.cached_property
def _min_max(self) -> Tuple[ConstType, ConstType]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.arg:
return self.arg[1].arg, self.arg[2].arg if self.arg[2].op is UOps.CONST else dtypes.max(self.dtype)
if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2]
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
if self.op is UOps.EXPAND: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
@@ -331,7 +329,7 @@ def exec_alu(op:Op, dtype:DType, operands):
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")
@@ -380,8 +378,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
def get_location() -> Tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat
while frm.f_back is not None and any(fp == frm.f_back.f_code.co_filename.split("/")[-1] for fp in {"ops.py", "uopgraph.py", "schedule.py"}):
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and frm.f_back.f_code.co_filename.split("/")[-1] in {"ops.py", "uopgraph.py", "schedule.py", "lowerer.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@functools.lru_cache(None)
@@ -406,7 +404,7 @@ class UPat(MathTrait):
# repeat if it's a UPat
elif isinstance(src, UPat): self.src = [itertools.repeat(src)]
self.allowed_len: int = 0 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
self.allowed_len: int = -1 if allow_any_len or isinstance(src, UPat) or src is None else len(src)
self.location = location or get_location()
if custom_early_reject is not None: self.early_reject = custom_early_reject
@@ -459,7 +457,7 @@ class UPat(MathTrait):
(self.dtype is not None and uop.dtype not in self.dtype) or \
(self.arg is not None and self.arg != uop.arg) or \
(self.op is not None and uop.op not in self.op) or \
(self.allowed_len != 0 and len(uop.src) != self.allowed_len): return []
(self.allowed_len != -1 and len(uop.src) != self.allowed_len): return []
if self.src is None: return [store]
res: List[Dict[str, UOp]] = []
for vp in self.src:
@@ -488,11 +486,11 @@ class PatternMatcher:
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
if not early_reject.issubset(ler): continue
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
return None
# *** tracking pattern matcher ***
@@ -512,7 +510,7 @@ class TrackedPatternMatcher(PatternMatcher):
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
def rewrite(self, uop:UOp) -> Optional[UOp]:
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
@@ -521,7 +519,7 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None:
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None:
match_stats[p][0] += 1
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
@@ -536,25 +534,26 @@ if TRACK_MATCH_STATS:
import atexit, pickle
@atexit.register
def print_match_stats():
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
if TRACK_MATCH_STATS >= 2:
with open("/tmp/rewrites.pkl", "wb") as f:
print(f"rewrote {len(contexts)} graphs and applied {sum(len(x.rewrites) for x in contexts)} rules, saved to /tmp/rewrites.pkl")
pickle.dump(contexts, f)
if getenv("VIZ"):
import viz.serve
viz.serve.main()
return viz.serve.main()
ret = [0,0,0.0,0.0]
for k,v in sorted(list(match_stats.items()), key=lambda x: x[1][2]):
loc_str = f"{k.location[0].split('/')[-1]}:{k.location[1]}"
if v[1] != 0: print(f"{v[0]:6d} / {v[1]:7d} -- {v[3]*1000.:9.2f} / {v[2]*1000.:9.2f} ms -- {loc_str:15s}", k.printable())
ret = [x+y for x,y in zip(ret, v)]
print(f"{ret[0]:6d} / {ret[1]:7d} -- {ret[3]*1000.:9.2f} / {ret[2]*1000.:9.2f} ms -- TOTAL")
# *** simple graph rewrite engine ***
class RewriteContext:
def __init__(self, pm):
def __init__(self, pm, ctx):
self.pm: PatternMatcher = pm
self.ctx = ctx
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
@@ -563,33 +562,36 @@ class RewriteContext:
if found := self.nodes.get(replace_source): self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:
contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
return RewriteContext(pm).rewrite(sink)
return RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****
# this is the matcher for the final rendered UOps
# matcher functions returns True or False (or None to not match)
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
(UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType)) and not x.dtype.local),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType) and x.dtype.local),
(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], UOp) and isinstance(x.arg[2], UOp)),
(UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], int) and isinstance(x.arg[2], int)),
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(UOps.SPECIAL, src=()), lambda: True),
# no pyint allowed here!
(UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),
# TODO: confirm the args of both of these are shapetrackers
(UPat(UOps.SHAPETRACKER, src=()), lambda: True),
(UPat(UOps.SWIZZLE, src=(UPat(),)), lambda: True),
(UPat(UOps.CONST, name="x"),
lambda x: x.dtype == x.dtype.scalar() and (isinstance(x.arg, Variable) and x.src) or (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
(UPat(UOps.VALID, dtypes.bool, (UPat(UOps.SHAPETRACKER),)), lambda: True),
(UPat(UOps.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
# early LOAD has a <buf, shapetracker, store?>
(UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER))), lambda: True),
@@ -616,13 +618,13 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
(UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
(UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
(UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
(UPat(UOps.ASSIGN, src=(UPat((UOps.DEFINE_ACC, UOps.DEFINE_GLOBAL)), UPat())), lambda: True),
(UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
# early WMMA has 2 args, <x, w>
(UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
# late WMMA has 3 args, <x, w, acc>
# all WMMA has 3 args, <x, w, acc>
(UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
(UPat(UOps.CONTRACT, name="x"), lambda x: x.dtype.count == prod(y[1] for y in x.arg)),
(UPat(UOps.EXPAND, name="x"), lambda x: x.src[0].dtype.count == prod(y[1] for y in x.arg)),
# if has a <gate, barrier>
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
@@ -636,7 +638,7 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
# NOTE: for testing, we let sinks be anything
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
(UPat(UOps.SINK), lambda: True),
(UPat(UOps.SINK, dtypes.void), lambda: True),
# PTX LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),

View File

@@ -34,7 +34,7 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg))
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:

View File

@@ -56,14 +56,14 @@ class CStyleLanguage(Renderer):
return (self.render_cast(val, dtype) if dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
# returns a str expression of the loaded value with the output type
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
def render_load(self, output_dtype, buf_name, buf_dtype, idx) -> str:
if isinstance(buf_dtype, ImageDType):
assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
return f"read_imagef({buf_name}, smp, {idx})"
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
return f"vload_half{'' if output_dtype.count == 1 else str(output_dtype.count)}(0, {buf_name}+{idx})"
if output_dtype.count > 1:
return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
return f"*(({self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix}{self.render_dtype(output_dtype)}*)({buf_name}+{idx}))" # noqa: E501
return f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
@@ -78,14 +78,14 @@ class CStyleLanguage(Renderer):
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
# returns a str statement that does the store
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx:str, local=False) -> str:
def render_store(self, buf_name:str, buf_dtype:Union[ImageDType, PtrDType], var_name:str, var_dtype:DType, idx:str) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes.float.vec(4), f"images must be float4, getting {var_dtype}"
return f"write_imagef({buf_name}, {idx}, {var_name});"
if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
return f"vstore_half{'' if var_dtype.count == 1 else str(var_dtype.count)}({var_name}, 0, {buf_name}+{idx});"
if var_dtype.count > 1:
prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
prefix = self.smem_prefix if buf_dtype.local and self.smem_prefix_for_cast else self.buffer_prefix
return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
@@ -124,8 +124,9 @@ class CStyleLanguage(Renderer):
kk("}")
elif uop is UOps.STORE:
# mark DEFINE_GLOBAL buf as writable
assert isinstance(src[0].dtype, (ImageDType, PtrDType))
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
else:
if uop is UOps.RANGE:
@@ -150,7 +151,7 @@ class CStyleLanguage(Renderer):
bufs[u] = (args[0], (dtype,False))
r[u] = args[0]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
@@ -226,6 +227,11 @@ class ClangRenderer(CStyleLanguage):
class OpenCLRenderer(CStyleLanguage):
device = "GPU"
code_for_op = {**CStyleLanguage().code_for_op,
#UnaryOps.SQRT: lambda x,dtype: f"native_sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"native_recip({x})",
#UnaryOps.EXP2: lambda x,dtype: f"native_exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"native_log2({x})",
UnaryOps.SIN: lambda x,dtype: f"native_sin({x})"} if getenv("NATIVE_MATH") else CStyleLanguage().code_for_op
# language options
kernel_prefix = "__kernel "
buffer_prefix = "__global "

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, ctypes, functools, mmap, struct, array, decimal, math
from types import SimpleNamespace
from typing import Tuple, List, Dict, Any
from typing import Tuple, List, Any, cast
from tinygrad.device import BufferOptions, HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState, hcq_command
from tinygrad.runtime.autogen import kgsl, adreno, libc
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
@@ -9,6 +9,8 @@ from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, DEBUG, fromimport
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
def _qreg_exec(reg, __val=0, **kwargs):
for k, v in kwargs.items():
__val |= (getattr(adreno, f'{reg[4:]}_{k.upper()}') if v else 0) if type(v) is bool else (v << getattr(adreno, f'{reg[4:]}_{k.upper()}__SHIFT'))
@@ -123,39 +125,38 @@ class QCOMComputeQueue(HWComputeQueue):
qreg.a6xx_sp_cs_pvt_mem_size(totalpvtmemsize=prg.pvtmem_size_total))
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_SHADER, num_unit=prg.kernargs_alloc_size // 4),
state_block=adreno.SB6_CS_SHADER, num_unit=1024 // 4),
*data64_le(args_state.ptr))
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_SHADER, num_unit=round_up(prg.image_size, 128) // 128),
*data64_le(prg.lib_gpu.va_addr))
self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc,
qreg.a6xx_hlsq_cs_cntl(constlen=prg.kernargs_alloc_size // 4, enabled=True))
self.reg(adreno.REG_A6XX_HLSQ_CONTROL_2_REG, 0xfcfcfcfc, 0xfcfcfcfc, 0xfcfcfcfc, 0xfc, qreg.a6xx_hlsq_cs_cntl(constlen=1024 // 4, enabled=True))
self.reg(adreno.REG_A6XX_SP_CS_PVT_MEM_HW_STACK_OFFSET, qreg.a6xx_sp_cs_pvt_mem_hw_stack_offset(prg.hw_stack_offset))
self.reg(adreno.REG_A6XX_SP_CS_INSTRLEN, qreg.a6xx_sp_cs_instrlen(prg.image_size // 4))
if hasattr(args_state, 'samplers_ptr'):
if args_state.prg.samp_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_SHADER, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_TEX, num_unit=args_state.samplers_cnt),
*data64_le(args_state.samplers_ptr.va_addr))
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.samplers_ptr.va_addr))
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
*data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.device._border_color_base()))
if hasattr(args_state, 'descriptors_ptr'):
if args_state.prg.tex_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_TEX, num_unit=args_state.descriptors_cnt),
*data64_le(args_state.descriptors_ptr.va_addr))
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.descriptors_ptr.va_addr))
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt),
*data64_le(args_state.ptr + args_state.prg.tex_off))
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
if hasattr(args_state, 'ibos_ptr'):
if args_state.prg.ibo_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST6_IBO, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_SHADER, num_unit=args_state.ibos_cnt),
*data64_le(args_state.ibos_ptr.va_addr))
self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ibos_ptr.va_addr))
state_block=adreno.SB6_CS_SHADER, num_unit=args_state.prg.ibo_cnt),
*data64_le(args_state.ptr + args_state.prg.ibo_off))
self.reg(adreno.REG_A6XX_SP_CS_IBO, *data64_le(args_state.ptr + args_state.prg.ibo_off))
self.reg(adreno.REG_A6XX_SP_CS_CONFIG,
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.samplers_cnt, ntex=args_state.descriptors_cnt, nibo=args_state.ibos_cnt))
qreg.a6xx_sp_cs_config(enabled=True, nsamp=args_state.prg.samp_cnt, ntex=args_state.prg.tex_cnt, nibo=args_state.prg.ibo_cnt))
self.cmd(adreno.CP_RUN_OPENCL, 0)
self._cache_flush(write_back=True, invalidate=False, sync=False, memsync=False)
@@ -175,46 +176,27 @@ class QCOMComputeQueue(HWComputeQueue):
class QCOMArgsState(HCQArgsState):
def __init__(self, ptr:int, prg:QCOMProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()):
super().__init__(ptr, prg, bufs, vals=vals)
self.ibos_cnt, self.descriptors_cnt, self.samplers_cnt = 0, 0, 0
ctypes.memset(ptr, 0, 1024)
ctypes.memset(self.ptr, 0, 1024)
if len(bufs) + len(vals) != len(prg.buf_info): raise RuntimeError(f'incorrect args size given={len(bufs)+len(vals)} != want={len(prg.buf_info)}')
self.buf_info, self.args_info = prg.buf_info[:len(bufs)], prg.buf_info[len(bufs):]
if len(bufs) + len(vals) != len(prg.buffs_info): raise RuntimeError(f'incorrect args size given={len(bufs)} != want={len(prg.buffs_info)}')
self.boffs, self.aoffs = prg.buffs_info[:len(bufs)], prg.buffs_info[len(bufs):]
for i, v in enumerate(vals): self.update_var(i, v)
for cnst_val, cnst_off, cnst_sz in prg.consts_info:
ctypes.memmove(self.ptr + cnst_off, (ctypes.c_int8 * cnst_sz).from_buffer_copy(cnst_val.to_bytes(cnst_sz, byteorder='little')), cnst_sz)
samplers: List[Any] = []
descriptors: List[Any] = []
ibos: List[Any] = []
self.i2descr: Dict[int, int] = {}
self.i2ibo: Dict[int, int] = {}
for i, b in enumerate(bufs):
if not hasattr(b, 'samplers') and not hasattr(b, 'descriptor') and not hasattr(b, 'ibo'): self.update_buffer(i, b)
elif self.boffs[i][1]: ibos, self.i2ibo = [*ibos, *getattr(b, 'ibo')], {**self.i2ibo, i: len(ibos)}
else:
samplers, descriptors = [*samplers, *getattr(b, 'samplers')], [*descriptors, *getattr(b, 'descriptor')]
self.i2descr[i] = len(descriptors) - 1
def alloc_tex_gpu(data, chunk_size) -> Tuple[HCQBuffer, int]:
tex_gpu = self.prg.device.allocator.alloc(len(data) * 4, BufferOptions(nolru=True, cpu_access=True))
to_mv(tex_gpu.va_addr, len(data) * 4).cast('I')[:] = array.array('I', data)
return tex_gpu, len(data) // chunk_size
if len(samplers): self.samplers_ptr, self.samplers_cnt = alloc_tex_gpu(samplers, 4)
if len(descriptors): self.descriptors_ptr, self.descriptors_cnt = alloc_tex_gpu(descriptors, 16)
if len(ibos): self.ibos_ptr, self.ibos_cnt = alloc_tex_gpu(ibos, 16)
def __del__(self):
for ptr in ('samplers_ptr', 'descriptors_ptr', 'ibos_ptr'):
if hasattr(self, ptr): self.prg.device.allocator.free((x:=getattr(self, ptr)), x.size, BufferOptions(nolru=True, cpu_access=True))
if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
for i, b in enumerate(cast(List[QCOMBuffer], bufs)):
if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc)
elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo)
else: self.update_buffer(i, b)
for i, v in enumerate(vals): self.update_var(i, v)
def update_buffer(self, index:int, buf:HCQBuffer):
if (descr:=self.i2descr.get(index, None)) is not None: to_mv(self.descriptors_ptr.va_addr + 16 * descr + 4 * 4, 8).cast('Q')[0] = buf.va_addr
elif (ibo:=self.i2ibo.get(index, None)) is not None: to_mv(self.ibos_ptr.va_addr + 16 * ibo + 4 * 4, 8).cast('Q')[0] = buf.va_addr
else: to_mv(self.ptr + self.boffs[index][0], 8).cast('Q')[0] = buf.va_addr
if self.buf_info[index].type is not BUFTYPE_BUF: to_mv(self.ptr+self.buf_info[index].offset+0x10, 8).cast('Q')[0] = buf.va_addr
else: to_mv(self.ptr + self.buf_info[index].offset, 8).cast('Q')[0] = buf.va_addr
def update_var(self, index:int, val:int): to_mv(self.ptr + self.aoffs[index][0], 8).cast('Q')[0] = val
def update_var(self, index:int, val:int): to_mv(self.ptr + self.args_info[index].offset, 8).cast('Q')[0] = val
class QCOMProgram(HCQProgram):
def __init__(self, device: QCOMDevice, name: str, lib: bytes):
@@ -232,7 +214,7 @@ class QCOMProgram(HCQProgram):
self.max_threads = min(1024, ((384 * 32) // (max(1, (self.fregs + round_up(self.hregs, 2) // 2)) * 128)) * 128)
device._ensure_stack_size(self.hw_stack_offset * 4)
super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=1024)
super().__init__(QCOMArgsState, self.device, self.name, kernargs_alloc_size=2048 + (self.tex_cnt + self.ibo_cnt) * 0x40 + self.samp_cnt * 0x10)
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
@@ -253,19 +235,34 @@ class QCOMProgram(HCQProgram):
self.pvtmem, self.shmem = _read_lib(image_desc_off+0xc8), _read_lib(image_desc_off+0xd8)
# Fill up constants and buffers info
self.buffs_info, self.consts_info = [], []
self.buf_info, self.consts_info = [], []
samplers_count = _read_lib(image_desc_off + 0xdc)
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samplers_count
while (bdoff + 16 <= len(self.lib)):
length, _, _, offset_words = struct.unpack("I" * 4, self.lib[bdoff:bdoff+16])
# Collect sampler info.
self.samp_cnt = _read_lib(image_desc_off + 0xdc)
assert self.samp_cnt <= 1, "Up to one sampler supported"
if self.samp_cnt:
self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode),
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0]
# Collect kernel arguments (buffers) info.
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt
while bdoff + 32 <= len(self.lib):
length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32])
if length == 0: break
self.buffs_info.append((offset_words * 4, struct.unpack("I", self.lib[bdoff+0x3c:bdoff+0x40])[0] == 0x0))
self.buf_info.append(SimpleNamespace(offset=offset_words * 4, type=typ))
bdoff += length
# Setting correct offsets to textures/ibos.
self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info)
self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt
cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off
for x in self.buf_info:
if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40
elif x.type is BUFTYPE_TEX: x.offset, cur_tex_off = cur_tex_off, cur_tex_off + 0x40
if _read_lib(0xb0) != 0: # check if we have constants.
cdoff = _read_lib(0xac)
while (cdoff + 40 <= image_offset):
while cdoff + 40 <= image_offset:
cnst, offset_words, _, is32 = struct.unpack("I", self.lib[cdoff:cdoff+4])[0], *struct.unpack("III", self.lib[cdoff+16:cdoff+28])
self.consts_info.append((cnst, offset_words * (sz_bytes:=(2 << is32)), sz_bytes))
cdoff += 40
@@ -277,6 +274,10 @@ class QCOMProgram(HCQProgram):
def __del__(self):
if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))
class QCOMBuffer(HCQBuffer):
def __init__(self, va_addr:int, size:int, desc=None, ibo=None, pitch=None, real_stride=None):
self.va_addr, self.size, self.desc, self.ibo, self.pitch, self.real_stride = va_addr, size, desc, ibo, pitch, real_stride
class QCOMAllocator(HCQAllocator):
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:
if options.image is not None:
@@ -286,33 +287,37 @@ class QCOMAllocator(HCQAllocator):
granularity = 128 if options.image.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up(imgw * 4 * options.image.itemsize, 1 << pitchalign) + pitch_add
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
texture = self.device._gpu_alloc(pitch * round_up(imgh, 16), kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)
# Extend HCQBuffer with texture-related info.
texture.samplers, texture.descriptor, texture.ibo = [0] * 4, [0] * 16, [0] * 16
# Compiled sampler (always the same in tinygrad).
texture.samplers[0] = qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode)
texture.samplers[1] = qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True)
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
texture.descriptor[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.descriptor[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
texture.descriptor[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6)
texture.descriptor[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
texture.ibo = [texture.descriptor[0] & (~0xffff), *texture.descriptor[1:len(texture.descriptor)]]
texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
return texture
return self.device._gpu_alloc(size, map_to_cpu=True)
def copyin(self, dest:HCQBuffer, src:memoryview): ctypes.memmove(dest.va_addr, from_mv(src), src.nbytes)
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
while src_off < src_size:
ctypes.memmove(dest_addr+dest_off, src_addr+src_off, real_size)
src_off, dest_off = src_off+src_stride, dest_off+dest_stride
def copyin(self, dest:HCQBuffer, src:memoryview):
if hasattr(qd:=cast(QCOMBuffer, dest), 'pitch'): self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)
def copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()
ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
if hasattr(qs:=cast(QCOMBuffer, src), 'pitch'): self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
def as_buffer(self, src:HCQBuffer) -> memoryview:
self.device.synchronize()

View File

@@ -62,13 +62,11 @@
stroke-width: 1.5px;
}
.graph {
grid-column: span 10;
width: 70%;
position: relative;
}
.main-container {
display: grid;
grid-template-columns: repeat(14, 1fr);
gap: 8px;
display: flex;
padding: 12px;
width: 100%;
height: 100%;
@@ -79,12 +77,16 @@
background-color: #111111;
border-radius: 8px;
padding: 8px;
position: relative;
}
.container > * + * {
margin-top: 12px;
}
.main-container > * + * {
margin-left: 8px;
}
.kernel-list {
grid-column: span 1;
width: 10%;
display: flex;
flex-direction: column;
overflow-y: auto;
@@ -96,9 +98,18 @@
margin-top: 4px;
}
.metadata {
grid-column: span 3;
width: 20%;
overflow-y: auto;
}
#resize-handle {
position: absolute;
left: 0;
top: 0;
bottom: 0;
width: 20px;
cursor: w-resize;
background-color: transparent;
}
.rewrite-list {
display: flex;
flex-wrap: wrap;
@@ -137,12 +148,17 @@
transform: rotate(360deg);
}
}
.status {
color: #EC5D5E;
font-weight: bold;
}
</style>
</head>
<body>
<div class="main-container">
<div class="container kernel-list"></div>
<div class="container graph">
<div class="status"></div>
<svg>
<g id="render"></g>
</svg>
@@ -150,12 +166,16 @@
<div class="container metadata"></div>
</div>
<script>
function renderGraph(graph) {
const g = new dagreD3.graphlib.Graph().setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
function renderGraph(graph, additions) {
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: #242424" : "display: none;"});
for ([k,u] of Object.entries(graph)) {
g.setNode(k, {label: u[0], style: `fill: ${u[4]}; rx: 8; ry: 8;` });
for (src of u[2]) {
g.setEdge(src, k)
g.setEdge(src, k, {curve: d3.curveBasis})
}
if (additions.includes(parseInt(k))) {
g.setParent(k, "addition");
}
}
const svg = d3.select("svg");
@@ -170,26 +190,38 @@
const render = new dagreD3.render();
render(inner, g);
}
async function checkStatus() {
active = true;
try {
active = (await fetch("/")).ok;
} catch {
active = false;
}
document.querySelector('.status').textContent = active ? `` : `Connection to localhost:8000 failed, is VIZ=1 running?`;
}
var ret = {};
var cache = {};
var kernels = null;
var currentUOp = 0;
var currentKernel = 0;
var currentRewrite = 0;
var expandKernel = false;
async function main() {
checkStatus();
// ** kernel list
if (kernels == null) {
kernels = await (await fetch("/kernels")).json();
currentKernel = 0;
}
const kernelList = document.querySelector(".container.kernel-list");
kernelList.innerHTML = "";
Object.entries(kernels).forEach(([k,uopRewrites], i) => {
kernelUl = Object.assign(document.createElement("ul"), { innerHTML: `<p>${k}</p>`, key: `kernel-${i}`, className: i === currentKernel ? "active" : "" });
uopRewrites.forEach((u, j) => {
kernels.forEach((k, i) => {
kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "" });
const p = Object.assign(document.createElement("p"), {id: `kernel-${k.name}`, innerText: k.name})
kernelUl.appendChild(p)
k.ctxs.forEach((u, j) => {
const rwUl = Object.assign(document.createElement("ul"), { innerText: u, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" })
if (i !== currentKernel) {
rwUl.style.display = "none";
}
rwUl.style.display = i === currentKernel && expandKernel ? "block" : "none";
rwUl.onclick = (e) => {
e.stopPropagation();
currentUOp = j;
@@ -199,11 +231,16 @@
}
kernelUl.appendChild(rwUl)
})
kernelUl.onclick = (e) => {
if (i === currentKernel) return;
p.onclick = (e) => {
if (i === currentKernel) {
expandKernel = !expandKernel;
main();
return;
}
currentKernel = i;
currentUOp = 0;
currentRewrite = 0;
expandKernel = true;
main();
}
kernelList.appendChild(kernelUl);
@@ -217,17 +254,46 @@
ret = await (await fetch(`/graph?kernel_idx=${currentKernel}&uop_idx=${currentUOp}`)).json();
cache[cacheKey] = ret;
}
renderGraph(ret[0].graphs[currentRewrite][0]);
renderGraph(ret[0].graphs[currentRewrite], ret[0].additions[currentRewrite]);
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";
metadata.appendChild(Object.assign(document.createElement("div"), { id: "resize-handle" }));
metadata.appendChild(Object.assign(document.createElement("pre"), { textContent: ret[0].loc }));
const resizeHandle = document.getElementById("resize-handle");
let startX;
let containerWidth;
let metadataWidth;
resizeHandle.addEventListener("mousedown", (e) => {
e.preventDefault();
metadata.style.userSelect = "none";
startX = e.clientX;
containerWidth = document.querySelector(".main-container").getBoundingClientRect().width;
metadataWidth = metadata.getBoundingClientRect().width;
document.documentElement.addEventListener("mousemove", resize, false);
document.documentElement.addEventListener("mouseup", stopResize, false);
});
function resize(e) {
const change = e.clientX - startX;
const newWidth = ((metadataWidth-change) / containerWidth) * 100;
if (newWidth >= 20 && newWidth <= 50) {
metadata.style.width = `${newWidth}%`;
document.querySelector(".graph").style.width = `${100-newWidth-10}%`;
}
}
function stopResize(e) {
document.documentElement.removeEventListener("mousemove", resize, false);
document.documentElement.removeEventListener("mouseup", stopResize, false);
metadata.style.userSelect = "initial";
}
ret[0].extra[currentRewrite].forEach((e, i) => {
if (e.length == 0) return;
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${e}</code>`, className: "code-block" });
metadata.appendChild(pre);
})
if (ret[0].code !== "") {
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${ret[0].code}</code>`, className: "code-block" });
if (kernels[currentKernel].code !== "") {
const pre = Object.assign(document.createElement("pre"), { innerHTML: `<code>${kernels[currentKernel].code}</code>`, className: "code-block" });
metadata.appendChild(pre);
}
// ** rewrite list
@@ -263,7 +329,30 @@
}
}
document.addEventListener("keydown", async function(event) {
// up and down change the UOp from the list
// up and down change the UOp or kernel from the list
if (!expandKernel) {
if (event.key == "ArrowUp") {
event.preventDefault()
currentUOp = 0;
currentRewrite = 0;
currentKernel = Math.max(0, currentKernel-1)
return main()
}
if (event.key == "ArrowDown") {
event.preventDefault()
currentUOp = 0;
currentRewrite = 0;
currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
return main()
}
}
if (event.key == "Enter") {
event.preventDefault()
currentUOp = 0;
currentRewrite = 0;
expandKernel = !expandKernel;
main();
}
if (event.key == "ArrowUp") {
event.preventDefault()
currentRewrite = 0;
@@ -273,7 +362,7 @@
if (event.key == "ArrowDown") {
event.preventDefault()
currentRewrite = 0;
const totalUOps = Array.from(Object.values(kernels))[currentKernel].length-1;
const totalUOps = kernels[currentKernel].ctxs.length-1;
currentUOp = Math.min(totalUOps, currentUOp+1)
main()
}
@@ -290,6 +379,7 @@
main()
}
})
setInterval(checkStatus, 5000);
main()
</script>
</body>

View File

@@ -1,7 +1,8 @@
#!/usr/bin/env python3
from __future__ import annotations
from typing import Dict, List, Tuple
import pickle, os, sys, time, threading, webbrowser, json, difflib, contextlib
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from urllib.parse import parse_qs, urlparse
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad import Device
@@ -11,6 +12,37 @@ from tinygrad.engine.graph import uops_colors, word_wrap
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import full_ast_rewrite
# **** /graph - detailed UOp + rewrites
@dataclass(frozen=True)
class UOpRet:
loc: str
graphs: List[UOp] # snapshot of the entire AST after each rewrite
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
additions: List[List[int]]
@staticmethod
def from_ctx(ctx:TrackedRewriteContext) -> UOpRet:
uops: List[UOp] = [ctx.sink]
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
extra: List[List[str]] = [[str(ctx.sink)]]
additions: List[List[int]] = [[]]
seen_replaces: Dict[bytes, UOp] = {}
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
if pattern.location[0].split("/")[-1] == "ops.py": continue
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first.key] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
# update ret data
additions.append([id(x) for x in rewritten.sparents])
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, uops, diffs, extra, additions)
def to_json(self) -> Dict: return {**asdict(self), "graphs": list(map(uop_to_json, self.graphs))}
def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
assert isinstance(x, UOp)
graph: Dict[int, Tuple[str, str, List[int], str, str]] = {}
@@ -22,51 +54,33 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
@dataclass(frozen=True)
class UOpRet:
loc: str
graphs: List[Tuple[UOp, UOp, UOp, UOp]] # snapshot of the entire AST after each rewrite
diffs: List[Tuple[str, Tuple[str, int], List[str]]] # the diffs for each rewrite
extra: List[List[str]] # these become code blocks in the UI
def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
if (found:=replaces.get(base.key)): return found
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
return ret
def create_graph(ctx:TrackedRewriteContext) -> UOpRet:
uops: List[UOp] = [ctx.sink]
graphs: List[Tuple[UOp, UOp, UOp, UOp]] = [(ctx.sink, ctx.sink, ctx.sink, ctx.sink)]
diffs: List[Tuple[str, Tuple[str, int], List[str]]] = []
extra: List[List[str]] = [[str(ctx.sink)]]
seen_replaces: Dict[bytes, UOp] = {}
for i, (first, rewritten, pattern) in enumerate(ctx.rewrites):
if pattern.location[0].split("/")[-1] == "ops.py": continue
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first.key] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
# update ret data
diffs.append((str(pattern), pattern.location, list(difflib.unified_diff(str(first).splitlines(), str(rewritten).splitlines()))))
graphs.append((new_sink, uops[-1], rewritten, first))
uops.append(new_sink)
extra.append([str(new_sink)])
return UOpRet(ctx.loc, graphs, diffs, extra)
# **** /kernels - Overview of the kernel
def get_ctx_groups(contexts:List[TrackedRewriteContext]) -> Dict[str, Tuple[List[TrackedRewriteContext], str]]:
ctx_groups: Dict[str, Tuple[List[TrackedRewriteContext], str]] = {}
@dataclass(frozen=True)
class KernelRet:
name: str
code: str
ctxs: Dict[Tuple[str, bytes], TrackedRewriteContext]
def to_json(self) -> Dict:
return {"name":self.name, "code":self.code, "ctxs":[x.loc for x in self.ctxs.values()]}
def load_kernels(contexts:List[TrackedRewriteContext]) -> List[KernelRet]:
ret: Dict[str, KernelRet] = {}
kernel_name = ""
code = ""
for ctx in contexts:
if ctx.loc.split("/")[-1].split(":")[0] == "schedule.py":
with Context(TRACK_MATCH_STATS=0): kernel_name, code = (prg:=get_runner(Device.DEFAULT, full_ast_rewrite(ctx.sink)).p).name, prg.src
elif ctx.kernel_name is not None: kernel_name = ctx.kernel_name
if ctx_groups.get(k:=to_function_name(kernel_name)) is None: ctx_groups[k] = ([], code)
# TODO: make ansi play nice with css
ctx_groups[to_function_name(kernel_name)][0].append(ctx)
return ctx_groups
elif ctx.kernel_name is not None: kernel_name, code = ctx.kernel_name, ""
if ret.get(k:=to_function_name(kernel_name)) is None: ret[k] = KernelRet(k, code, {})
ret[k].ctxs[(ctx.loc, ctx.sink.key)] = ctx
return list(ret.values())
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
@@ -87,20 +101,18 @@ class Handler(BaseHTTPRequestHandler):
self.send_header("Content-type", "application/json")
self.end_headers()
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
ctx_groups = get_ctx_groups(contexts)
ret = json.dumps({k:[x.loc for x in v[0]] for k,v in ctx_groups.items()}).encode()
kernels = load_kernels(contexts)
ret = json.dumps([x.to_json() for x in kernels]).encode()
elif url.path == "/graph":
query = parse_qs(url.query)
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
with open("/tmp/rewrites.pkl", "rb") as f: contexts: List[TrackedRewriteContext] = pickle.load(f)
ctx_groups = get_ctx_groups(contexts)
group, code = ctx_groups[list(ctx_groups.keys())[int(query["kernel_idx"][0])]]
g = create_graph(group[int(query["uop_idx"][0])])
rest = [x.loc for x in group]
ret = json.dumps(({"loc": g.loc, "graphs": [[uop_to_json(x) for x in graph] for graph in g.graphs],
"diffs": g.diffs, "extra": g.extra, "code": code}, rest)).encode()
kernels = load_kernels(contexts)
k = kernels[int(query["kernel_idx"][0])]
g = UOpRet.from_ctx(list(k.ctxs.values())[int(query["uop_idx"][0])])
ret = json.dumps((g.to_json(), [x.loc for x in k.ctxs.values()])).encode()
else:
self.send_response(404)
ret = b""

View File

@@ -6,10 +6,10 @@ from tinygrad import Tensor
from tinygrad.engine.realize import lower_schedule
from tinygrad.ops import UOp, UOps, graph_rewrite, PatternMatcher, UPat, contexts, KernelInfo, BinaryOps
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.helpers import CI, all_same, DEBUG, colored, getenv
from tinygrad.helpers import CI, Context, all_same, DEBUG, colored, getenv
from tinygrad.codegen.uopgraph import constant_folder, devectorize, float4_folding
from test.external.process_replay.helpers import print_diff
from viz.serve import create_graph, get_ctx_groups
from viz.serve import UOpRet, load_kernels
class TestViz(unittest.TestCase):
def tearDown(self) -> None:
@@ -19,12 +19,11 @@ class TestViz(unittest.TestCase):
def assert_valid_ctx(self, contexts):
assert len(contexts) != 0
for i,ctx in enumerate(contexts):
try: ret = create_graph(ctx)
try: ret = UOpRet.from_ctx(ctx)
except Exception as e:
print(colored(f"failed to create graph for ctx {i}", "red"))
raise e
rewrites = [x[0] for x in ret.graphs]
for j,(x,y) in enumerate(zip(rewrites, rewrites[1:])):
for j,(x,y) in enumerate(zip(ret.graphs, ret.graphs[1:])):
if x.key == y.key:
raise AssertionError(f"failed to generate the correct diff at rewrite {j} ctx {i}")
@@ -45,10 +44,10 @@ class TestViz(unittest.TestCase):
schedule2 = Tensor.randn(4, 4).contiguous().schedule()
list(lower_schedule(schedule1))
list(lower_schedule(schedule2))
ret = get_ctx_groups(contexts)
assert len(ret.keys()) == 2
assert all(len([x for x in ctxs if "schedule" in x.loc]) != 0 for ctxs,_ in ret.values())
assert all(len([x for x in ctxs if "uopgraph" in x.loc]) != 0 for ctxs,_ in ret.values())
ret = load_kernels(contexts)
assert len(ret) == 2
assert all(len([x for x in y.ctxs.values() if "schedule" in x.loc]) != 0 for y in ret)
assert all(len([x for x in y.ctxs.values() if "uopgraph" in x.loc]) != 0 for y in ret)
def test_gemm_diff(self):
x = Tensor.empty(64, 64).realize()
@@ -67,8 +66,8 @@ class TestViz(unittest.TestCase):
])
ret = graph_rewrite(sink, pm)
if DEBUG >= 4: print_diff(sink, ret)
g = create_graph(contexts[0])
assert g.graphs[-1][0].key == ret.key
g = UOpRet.from_ctx(contexts[0])
assert g.graphs[-1].key == ret.key
self.assert_valid_ctx(contexts)
def test_devectorize_viz(self):
@@ -115,5 +114,28 @@ class TestViz(unittest.TestCase):
simple_pm.rewrite(UOp.const(dtypes.int, 2))
self.assertEqual(len(contexts), 0)
def test_dedup_ast(self):
contexts.clear()
a = Tensor.randn(4, 4)+2
b = Tensor.randn(4, 4)+2
Tensor.schedule(a, b)
kernels = load_kernels(contexts)
self.assertEqual(len(kernels), 1)
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 1)
def test_no_dedup_different_opts(self):
contexts.clear()
a = Tensor.empty(4, 4)+Tensor.empty(4, 4)
s = a.schedule()
with Context(NOOPT=1): list(lower_schedule(s.copy()))
with Context(NOOPT=0): list(lower_schedule(s.copy()))
kernels = load_kernels(contexts)
self.assertEqual(len(kernels), 2)
schedule_ctxs = [x for x in kernels[0].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 1)
schedule_ctxs = [x for x in kernels[1].ctxs.values() if x.loc.split("/")[-1].split(":")[0] == "schedule.py"]
self.assertEqual(len(schedule_ctxs), 0)
if __name__ == "__main__":
unittest.main()