From 9f6d545a16808c7d0924e0c1743fd0b394615c60 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 21 Jan 2025 20:36:27 -0500 Subject: [PATCH 01/44] bert log global_norm in training step [pr] (#8708) * bert log global_norm in training step [pr] and minor cleanups * .item() --- examples/mlperf/model_train.py | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index ef92c98a69..cbd9396f75 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -572,7 +572,8 @@ def train_rnnt(): pass @TinyJit -def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): +def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, + masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): optimizer.zero_grad() lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids) @@ -588,18 +589,15 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te optimizer.step() scheduler.step() - return loss.realize() + return loss.realize(), global_norm.realize() @TinyJit -def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): +def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, + masked_lm_weights:Tensor, next_sentence_labels:Tensor): lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids) - masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) - return { - "masked_lm_accuracy": masked_lm_accuracy.realize(), - "next_sentence_accuracy": seq_relationship_accuracy.realize(), - "masked_lm_loss": masked_lm_loss.realize(), - "next_sentence_loss": next_sentence_loss.realize() - } + masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \ + model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) + return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize() def train_bert(): # NOTE: pip install tensorflow, wandb required @@ -735,7 +733,7 @@ def train_bert(): previous_step = None if ckpt:=getenv("RESUME", ""): load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt)) - start_step = int(scheduler_wd.epoch_counter.numpy().item()) + start_step = int(scheduler_wd.epoch_counter.item()) print(f"resuming from {ckpt} at step {start_step}") if RUNMLPERF: @@ -761,7 +759,7 @@ def train_bert(): BEAM.value = TRAIN_BEAM st = time.perf_counter() GlobalCounters.reset() - loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, + loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \ train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"]) @@ -778,7 +776,7 @@ def train_bert(): dt = time.perf_counter() device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" - loss = loss.numpy().item() + loss = loss.item() cl = time.perf_counter() if BENCHMARK: step_times.append(cl - st) @@ -788,7 +786,7 @@ def train_bert(): f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, " f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS") if WANDB: - wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, + wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS}) @@ -823,12 +821,10 @@ def train_bert(): GlobalCounters.reset() st = time.time() - eval_result: dict[str, Tensor] = eval_step_bert(model, + lm_acc, clsf_acc, lm_loss, clsf_loss = eval_step_bert(model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"], eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"]) - - lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item() - lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item() + lm_acc, clsf_acc, lm_loss, clsf_loss = lm_acc.item(), clsf_acc.item(), lm_loss.item(), clsf_loss.item() eval_lm_losses.append(lm_loss) eval_clsf_losses.append(clsf_loss) @@ -845,7 +841,7 @@ def train_bert(): return if getenv("RESET_STEP", 1): eval_step_bert.reset() - del eval_data, eval_result + del eval_data avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses) avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses) avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs) From 9a9079118e609a461f82c76116ed31f28c90997a Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 21 Jan 2025 22:49:19 -0500 Subject: [PATCH 02/44] envvar BERT_LAYERS [pr] (#8709) default is 24 for large --- examples/mlperf/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index d232edfd86..dd3683a734 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -203,7 +203,7 @@ def get_mlperf_bert_config(): "intermediate_size": 4096, "max_position_embeddings": 512, "num_attention_heads": 16, - "num_hidden_layers": 24, + "num_hidden_layers": getenv("BERT_LAYERS", 24), "type_vocab_size": 2, "vocab_size": 30522 } From e3d1464ba41d58fe2c1b980ca8857af74109c25b Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 22 Jan 2025 05:43:57 -0500 Subject: [PATCH 03/44] move assign preload out of schedule item [pr] (#8710) * move assign preload out of schedule item [pr] * fix that --- test/test_schedule.py | 2 +- tinygrad/engine/memory.py | 2 +- tinygrad/engine/schedule.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index ccf423d0a5..5c9564966f 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1844,7 +1844,7 @@ def run_tensor_ast(r:Tensor): sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink() sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output]) sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right) - si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ()) + si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ()) run_schedule([si]) return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist() diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 99439c54b9..a9359ce259 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -47,4 +47,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule] + return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 71cc12492c..1ff5af2bf9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -60,7 +60,6 @@ class ScheduleItem: ast: UOp bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] - assign_preloads: tuple[UOp, ...] @property def outputs(self) -> tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" @@ -84,6 +83,7 @@ class ScheduleContext: ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) + preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) becomes_map: dict[UOp, UOp] = field(default_factory=dict) # wrap tensor uops around a VIEW(BUFFER, ) @@ -211,14 +211,14 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals)) # deal with ASSIGN - assign_preloads: list[UOp] = [] if len(ctx.assigns) != 0: + assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer] for x in list(sink.toposort)[::-1]: # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph") # PRELOAD tells the toposort this kernel should run before ASSIGN if x.op is Ops.PRELOAD: - assign_preloads.append(x.buf_uop) + assign_preloads[x.buf_uop] = None # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous: # if it has a single view and it's equal when you shrink a contig, it's fine @@ -229,7 +229,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: if CAPTURE_PROCESS_REPLAY: with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), - tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads))) + tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: @@ -529,7 +529,7 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu in_degree: defaultdict[ScheduleItem, int] = defaultdict(int) for si in prescheduled: # realize outputs before a parent is assigned to - parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si) + parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si) for assign in parents_assigns: graph[si].append(assign) in_degree[assign] += 1 From 2dae467b75fa8dc74746adeca2050c09b6afb45c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 22 Jan 2025 05:44:07 -0500 Subject: [PATCH 04/44] scheduler + process_replay import cleanup (#8711) --- test/external/process_replay/process_replay.py | 10 +++++----- tinygrad/engine/schedule.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 1bc597c6b9..1cfef40b9e 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -2,7 +2,7 @@ # compare kernels created by HEAD against master from collections import defaultdict import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings -from typing import Callable, List, Tuple, Union, cast +from typing import Callable, cast from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.engine.schedule import ScheduleContext, schedule_uop from tinygrad.codegen.kernel import Kernel, Opt @@ -33,15 +33,15 @@ class ProcessReplayWarning(Warning): pass def recreate_sched(ast:UOp) -> UOp: # NOTE: process replay isn't meant to actually schedule anything return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast -def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str: +def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str: k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) # NOTE: replay with the captured renderer, not the one in master - return k.opts.render(name, cast(List,k.to_program().uops)) + return k.opts.render(name, cast(list,k.to_program().uops)) # *** diff a "good" recreation against the generated version -def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]: +def diff(offset:int, name:str, fxn:Callable) -> tuple[int, int]|bool: if early_stop.is_set(): return True conn = db_connection() cur = conn.cursor() @@ -95,7 +95,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: cur.close() with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool: inputs = list(range(0, row_count, PAGE_SIZE)) - ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs))) + ret: list[tuple[int, int]|bool] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs))) pool.close() pool.join() pool.terminate() diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1ff5af2bf9..e51e6de63d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,10 +1,10 @@ import sys, atexit, functools, pickle from collections import defaultdict, deque from dataclasses import dataclass, field -from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views -from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map -from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap -from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar +from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers +from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views +from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap +from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY from tinygrad.dtype import DType, ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape From 891436853de8a17c281647e09d5cd69e638589d1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 22 Jan 2025 06:36:30 -0500 Subject: [PATCH 05/44] remove buffer size check in schedule item [pr] (#8712) --- tinygrad/engine/schedule.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e51e6de63d..07cbb3b761 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -228,8 +228,7 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: # capture process replay if CAPTURE_PROCESS_REPLAY: with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) - return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), - tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) + return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: From 93fb50ce77cd07a2d194750ca17ef52a1dd89faa Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:44:31 +0300 Subject: [PATCH 06/44] allreduce: add flags (#8713) --- test/external/external_benchmark_multitensor_allreduce.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py index 6867bcaa8f..92fb5ead3c 100644 --- a/test/external/external_benchmark_multitensor_allreduce.py +++ b/test/external/external_benchmark_multitensor_allreduce.py @@ -44,8 +44,8 @@ def main(): else: sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.") - (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True) - (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False) + (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus) + (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus) print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s") print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s") From 49b914ee691a9f2ecdc6f0f852c4a5f4fe40c03b Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 22 Jan 2025 10:32:19 -0500 Subject: [PATCH 07/44] simpler bert acc [pr] (#8714) logit.log_softmax().argmax(-1) is equivalent to logit.argmax(-1) --- extra/models/bert.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/extra/models/bert.py b/extra/models/bert.py index c1eb33f85c..01a136c8e4 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -63,15 +63,15 @@ class BertForPretraining: def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): valid = masked_lm_ids != 0 - masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1) - masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid + masked_lm_predictions = prediction_logits.argmax(-1) + masked_lm_correct = (masked_lm_predictions == masked_lm_ids) * valid masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) - seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1) - seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels) + seq_relationship_predictions = seq_relationship_logits.argmax(-1) + seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) - return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss + return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info From 907dfa0e82539b9fce3f1f3489a47213b38af1f6 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 22 Jan 2025 13:25:22 -0500 Subject: [PATCH 08/44] image buffer realization spec [pr] (#8420) * image buffer realization spec [pr] * redo the spec * work --- test/test_image_dtype.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 62fcb4a443..a7fd2ab94f 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -134,5 +134,39 @@ class TestImageDType(unittest.TestCase): print(lst) assert not np.any(np.isnan(lst)) +@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU") +class TestImageRealization(unittest.TestCase): + def test_image_dtype_expand(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize() + self.assertEqual(it_expanded.dtype, dtypes.float32) + + def test_image_dtype_expand_and_back(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)) + it2 = it_expanded.sum(3).realize() + self.assertEqual(it2.dtype, dtypes.imagef((9,27,4))) + + def test_image_alu_children(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous() + alu1 = it_expanded+1 + alu2 = it_expanded.sum(3) + it_expanded.realize() + # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image + self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4))) + alu1.realize() + self.assertEqual(alu1.dtype, dtypes.float32) + # alu2 is back in image because it fits the dtype again + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + alu2.realize() + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + if __name__ == '__main__': unittest.main() From af65331b76650435265438bfe418dc8b2fbbf0e8 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 22 Jan 2025 22:00:05 -0500 Subject: [PATCH 09/44] update beam params for bert green [pr] (#8726) increase BEAM_UPCAST_MAX and BEAM_LOCAL_MAX to default and matched red. 3% faster step --- .../benchmarks/bert/implementations/tinybox_green/dev_beam.sh | 2 +- .../benchmarks/bert/implementations/tinybox_green/dev_run.sh | 2 +- .../bert/implementations/tinybox_green/run_and_time.sh | 2 +- .../benchmarks/bert/implementations/tinybox_red/dev_beam.sh | 2 +- .../benchmarks/bert/implementations/tinybox_red/dev_run.sh | 2 +- .../benchmarks/bert/implementations/tinybox_red/run_and_time.sh | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh index 99b99f7e89..41cc7d8050 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh index 70a3b6a6cb..bde86667b9 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh index a213f4d682..18ca38d147 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh @@ -5,7 +5,7 @@ export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_green" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index 08ffd354b3..5bd4f06778 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index c42c1f65b6..530b490ebf 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index d6ff9fd2cc..4d993fc1c2 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -5,7 +5,7 @@ export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_red" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" From 6cb74bb630b3c824b6a8ceeddda7d305dc30a249 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 23 Jan 2025 01:28:07 -0500 Subject: [PATCH 10/44] fix using clone with shrink [pr] (#8724) * fix using clone with shrink [pr] * remove extra arg, add test_clone_with_shrink_realized --- test/test_tensor.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/test/test_tensor.py b/test/test_tensor.py index 0e9c20d022..d73e55e36a 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase): def test_clone(self): a = Tensor.rand(16, 16).realize() - self.assertIsNot(a.lazydata, a.clone().lazydata) - np.testing.assert_allclose(a.numpy(), a.clone().numpy()) + b = a.clone() + np.testing.assert_allclose(a.numpy(), b.numpy()) + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) a = Tensor.rand(16, 16).mul(5.0).add(5.0) - self.assertIsNot(a.lazydata, a.clone().lazydata) - np.testing.assert_allclose(a.numpy(), a.clone().numpy()) + b = a.clone() + np.testing.assert_allclose(a.numpy(), b.numpy()) + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) def test_clone_with_shrink(self): - a = Tensor.empty(16, 16) - self.assertIsNot(a.lazydata, a.clone().lazydata) + a = Tensor.rand(16, 16) + b = a.shrink(((2, 10), None)).clone() + b.realize() + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) - b = a.shrink(((2, 10), None)) - self.assertIsNot(b.lazydata, b.clone().lazydata) + def test_clone_with_shrink_realized(self): + a = Tensor.rand(16, 16).realize() + b = a.shrink(((2, 10), None)).clone() + b.realize() + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) def test_clone_with_grad(self): a = Tensor.rand(16, 16, requires_grad=True) From 07ec99001a97098bdb0cd5ab484137eddf3d16b5 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 23 Jan 2025 04:29:30 -0500 Subject: [PATCH 11/44] keep VIEW in big_sink + copy of buffer view spec [pr] (#8727) * keep views in sink [pr] * tests * things from the gpt2 bug --- test/test_schedule.py | 61 ++++++++++++++++++++++++++++++++++++- tinygrad/engine/schedule.py | 4 +-- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 5c9564966f..e3682cd1c3 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,7 +14,7 @@ from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views -from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same +from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp from tinygrad.codegen.kernel import verify_ast from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule @@ -2269,6 +2269,54 @@ class TestCopyFolding(unittest.TestCase): a = Tensor.empty(4).lazydata check_schedule(a.clone(), 1, filter_sink=False) + # NOTE: moving copy before view might change this + def test_shrink_copy(self): + a = Tensor.arange(4) + view = a.shrink(((0, 2),)) + b = view.clone() + run_schedule(check_schedule(b, 2, filter_sink=False)) + self.assertEqual(b.lazydata.base.buffer.size, 2) + self.assertEqual(b.lazydata.size, 2) + self.assertListEqual(b.tolist(), [0, 1]) + + def test_expanded_copy(self): + a = Tensor.arange(2) + view = a.reshape(2, 1).expand(2, 2) + b = view.clone() + run_schedule(check_schedule(b, 2, filter_sink=False)) + self.assertEqual(b.lazydata.base.buffer.size, 2) + self.assertEqual(b.lazydata.size, 4) + self.assertListEqual(b.tolist(), [[0, 0], [1, 1]]) + + def test_permuted_copy(self): + a = Tensor.arange(4) + b = a.reshape(2, 2).permute(1, 0) + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + def test_permute_on_disk(self): + with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer()) + a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") + b = a.reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + def test_permute_after_shrink(self): + a = Tensor.arange(5) + b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + # NOTE: disk permute must come after COPY + # TODO: this is wrong because of the permute + @unittest.expectedFailure + def test_permute_after_shrink_on_disk(self): + with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer()) + a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}") + b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + class TestTensorUOpSpec(unittest.TestCase): def test_const_must_be_unmasked(self): a = Tensor.ones((4, 4)).pad((2, 2)) @@ -2377,6 +2425,17 @@ class TestContiguous(unittest.TestCase): b = a.expand((4, 4)).contiguous().contiguous() check_schedule(b, 1) + def test_view_does_not_realize(self): + a = Tensor.empty(4) + b = a.expand((4, 4)) + check_schedule(b, 0) + self.assertEqual(b.lazydata.base.buffer.size, 4) + + def test_contiguous_view_realizes(self): + a = Tensor.empty(4) + b = a.expand((4, 4)).contiguous() + check_schedule(b, 1) + self.assertEqual(b.lazydata.base.buffer.size, 16) class TestUOpBecome(unittest.TestCase): # the simplest case, if we create a new BUFFER for this UOp diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 07cbb3b761..d8236d47bf 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -388,10 +388,10 @@ sym = symbolic_simple+PatternMatcher([ # support for using a contiguous permuted view instead of the parent view if one exists (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous), - # remove CONST/BIND/BUFFER/VIEW from SINK + # remove CONST/BIND/BUFFER from SINK (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) - if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), + if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), ]) # ** this decides which ops get realized From e4512baea467eb12a878f307cad9869c6cf96ce1 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:49:37 +0300 Subject: [PATCH 12/44] am: cleanup mm (#8730) * am: cleanup mm * cle * ops * entries --- test/external/external_fuzz_ampt.py | 6 ++-- test/external/external_test_am.py | 10 ++++--- tinygrad/runtime/support/am/amdev.py | 42 +++++++++++++--------------- tinygrad/runtime/support/am/ip.py | 3 ++ 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/test/external/external_fuzz_ampt.py b/test/external/external_fuzz_ampt.py index 72f8b3801a..f3d4c88e86 100644 --- a/test/external/external_fuzz_ampt.py +++ b/test/external/external_fuzz_ampt.py @@ -25,7 +25,7 @@ class AMPTFuzzer: _vaddr = va.va_addr + _offset for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) self.d.vram[pte['paddr']] = pattern # Mark this page assert pte['valid'] == 1 @@ -41,7 +41,7 @@ class AMPTFuzzer: frags_l = list(ctx.next(contig_range)) for f_offset, f_pt, f_pte_idx, f_n_ptes, f_pte_covers in frags_l: for j in range(f_n_ptes): - f_pte = helper_read_entry_components(f_pt.get_entry(f_pte_idx + j)) + f_pte = helper_read_entry_components(f_pt.entries[f_pte_idx + j]) assert f_pte['valid'] == 1 assert f_pte['paddr'] == start_paddr+f_offset+j*f_pte_covers, f"paddr {f_pte['paddr']:#x} not {start_paddr+f_offset+j*f_pte_covers:#x}" @@ -53,7 +53,7 @@ class AMPTFuzzer: def verify_memory(self, pages, pattern: int) -> bool: for _offset, _pt, _pte_idx, _n_ptes, _pte_covers in pages: for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) if self.d.vram[pte['paddr']] != pattern: return False if pte['valid'] == 0: return False diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index 985554623e..b526d7cf5c 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -3,7 +3,9 @@ from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraver from tinygrad.helpers import mv_address class FakeGMC: - def __init__(self): self.vm_base = 0x0 + def __init__(self): + self.vm_base = 0x0 + self.address_space_mask = (1 << 44) - 1 def flush_tlb(self, *args, **kwargs): pass class FakePCIRegion: @@ -72,7 +74,7 @@ class TestAMPageTable(unittest.TestCase): for tup in results: _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == va + _offset + i * _pte_covers, f"Expected paddr {pte['paddr']:#x} to be {va + _offset + i * _pte_covers:#x}" assert pte['valid'] == 1 @@ -81,7 +83,7 @@ class TestAMPageTable(unittest.TestCase): for tup in results: _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == 0 assert pte['valid'] == 0 @@ -113,7 +115,7 @@ class TestAMPageTable(unittest.TestCase): for tup in ctx.next(0x100000): _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == 0xdead0000 + _offset + i * _pte_covers, f"paddr {pte['paddr']:#x} not {0xdead0000 + _offset + i * _pte_covers:#x}" assert pte['valid'] == 1 diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 37b43d9c7b..328a1280fe 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -102,19 +102,16 @@ class AMFirmware: class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 class AMPageTableEntry: - def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv + def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv - def set_table(self, entry_id, pte:AMPageTableEntry, valid=True): - self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) + def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True): + assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}" - def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True): - f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \ - | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if self.lv != am.AMDGPU_VM_PTB else 0) \ + f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \ + | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \ | ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \ | (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0) - self.view[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f - - def get_entry(self, entry_id): return self.view[entry_id] + self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f class AMPageTableTraverseContext: def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False): @@ -126,22 +123,23 @@ class AMPageTableTraverseContext: def level_down(self): pt, pte_idx, _ = self.pt_stack[-1] - if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID: - assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" - child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) - else: + if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0: assert self.create_pts, "Not allowed to create new page table" - pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) + pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True) + entry = pt.entries[pte_idx] + + assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" + child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) return self.pt_stack[-1] def _try_free_pt(self) -> bool: pt, _, _ = self.pt_stack[-1] - if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)): + if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)): self.adev.mm.pfree(pt.paddr) parent_pt, parent_pte_idx, _ = self.pt_stack[-2] - parent_pt.set_page(parent_pte_idx, 0x0, valid=False) + parent_pt.set_entry(parent_pte_idx, 0x0, valid=False) return True return False @@ -156,7 +154,7 @@ class AMPageTableTraverseContext: if self.create_pts: while pte_covers > size: pt, pte_idx, pte_covers = self.level_down() else: - while pt.lv!=am.AMDGPU_VM_PTB and (pt.get_entry(pte_idx)&am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down() + while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down() entries = min(size // pte_covers, 512 - pte_idx) assert entries > 0, "Invalid entries" @@ -181,10 +179,10 @@ class AMMemoryManager: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True) for paddr, psize in paddrs: for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize): - frag = 0 if pte_covers == 0x1000 else 0x9 for pte_off in range(pte_cnt): - assert pt.get_entry(pte_idx + pte_off) & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.get_entry(pte_idx + pte_off):#x}" - pt.set_page(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, frag=frag, valid=True) + assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}" + pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, + uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True) # Invalidate TLB after mappings. self.adev.gmc.flush_tlb(ip='GC', vmid=0) @@ -197,8 +195,8 @@ class AMMemoryManager: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True) for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size): for pte_id in range(pte_idx, pte_idx + pte_cnt): - assert pt.get_entry(pte_id) & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.get_entry(pte_id):#x}" - pt.set_page(pte_id, paddr=0x0, valid=False) + assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}" + pt.set_entry(pte_id, paddr=0x0, valid=False) @staticmethod def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align)) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 2b08798825..f29cf3e2e1 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -25,6 +25,9 @@ class AM_GMC(AM_IP): self.vm_base = self.adev.mm.va_allocator.base self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1 + # GFX11 has 44-bit address space + self.address_space_mask = (1 << 44) - 1 + self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.hub_initted = {"MM": False, "GC": False} From 8e5bd0cd7a2eacac24ab3e8d188bff5039113404 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 23 Jan 2025 10:21:38 -0500 Subject: [PATCH 13/44] fix buffer init and skip test_swizzle_failure_permute [pr] (#8732) * fix buffer init and skip test_swizzle_failure_permute [pr] * replace preload with just load * add --- test/test_schedule.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index e3682cd1c3..d852d2345e 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1940,19 +1940,19 @@ class TestSwizzle(unittest.TestCase): ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) - @unittest.expectedFailure + @unittest.skip("this swizzle can't be decided after the ADD") def test_swizzle_failure_permute(self): sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(Ops.WHERE, dtypes.float, arg=None, src=( x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( @@ -1971,13 +1971,13 @@ class TestSwizzle(unittest.TestCase): UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), x15,)), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), x10,)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)) ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) From 04846b91aadd3d52f2a128a397afe0b963bedebe Mon Sep 17 00:00:00 2001 From: geohotstan <135171913+geohotstan@users.noreply.github.com> Date: Fri, 24 Jan 2025 02:18:54 +0800 Subject: [PATCH 14/44] reorder and categorize onnx_ops (#8731) * new order * remove a todo * constant node is definitely requires_grad false * one new line spacing * property and graph * oops linter --- extra/onnx_ops.py | 643 ++++++++++++++++++++++------------------------ 1 file changed, 312 insertions(+), 331 deletions(-) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 4f745e680b..165de3154b 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -1,41 +1,15 @@ import functools, io, math from typing import cast, Literal from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr -from tinygrad.dtype import ImageDType, dtypes +from tinygrad.dtype import ImageDType, dtypes, DType from tinygrad.helpers import prod, flatten, make_tuple from extra.onnx import dtype_parse, _cached_to_python_const import numpy as np -# **************** Free Ops **************** - +# ***** Property/Graph Ops ***** def Identity(x:Tensor): return x -# TODO: fix buffer_parse -def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype) -def Sub(x:Tensor|int, other:Tensor): return x - other # some test has input as int -def Less(x:Tensor,y:Tensor): return x < y -def LessOrEqual(x:Tensor,y:Tensor): return x <= y -def Greater(x:Tensor,y:Tensor): return x > y -def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y -def Equal(x:Tensor,y:Tensor): return x == y -def BitwiseNot(x:Tensor): return ~x -def BitwiseOr(x:Tensor, y:Tensor): return x | y -def BitwiseAnd(x:Tensor, y:Tensor): return x & y -def BitwiseXor(x:Tensor, y:Tensor): return x ^ y -def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) -def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) -def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) -def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) -def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) -def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) - -# **************** Simple Ops **************** - -# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py -def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype) - -def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, - value_floats:list[float]|None=None, value_int:int|None=None, value_ints:list[int]|None=None, - value_string:str|None=None, value_strings:list[str]|None=None): +def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, + value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): if value is not None: return value if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) @@ -44,21 +18,79 @@ def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float: if value_string is not None or value_strings is not None and sparse_value is not None: raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') +def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) + +def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): + try: import PIL.Image + except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e + img = PIL.Image.open(io.BytesIO(encoded_stream)) + if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] + if pixel_format == "RGB": return Tensor(np.array(img)) + if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + +def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) + return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) + +def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) +def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) +def ConstantOfShape(shape:list[int], value:Tensor|None=None): + if value is None: value = Tensor(0, dtype=dtypes.float32) + return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) + +def Size(data:Tensor): return data.numel() +def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + +# ***** Unary Ops (math) ***** +def Not(x:Tensor): return x.logical_not() +def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): + return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) + +# ***** Unary Ops (activation) ***** +def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) +def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) +Softmax = {1:Softmax_1, 13:Softmax_13} def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) +def FastGelu(x:Tensor, bias:Tensor|None=None): + # this is tanh approximated + return (x + bias).gelu() if bias is not None else x.gelu() # TODO: fix this def PRelu(X:Tensor, slope:Tensor): - slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE + slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope return (X > 0).where(X, X * slope) def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha) def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) -def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) -def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) -Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) -def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002 - return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) +def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() +# ***** Unary Ops (broadcasted) ***** +def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype) +def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int +def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype) +def Less(x:Tensor,y:Tensor): return x < y +def LessOrEqual(x:Tensor,y:Tensor): return x <= y +def Greater(x:Tensor,y:Tensor): return x > y +def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y +def Equal(x:Tensor,y:Tensor): return x == y +def And(x:Tensor,y:Tensor): return (x==y).where(x, False) +def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) +def BitwiseAnd(x:Tensor,y:Tensor): return x & y +def BitwiseOr(x:Tensor,y:Tensor): return x | y +def BitwiseXor(x:Tensor,y:Tensor): return x ^ y +def BitwiseNot(x:Tensor): return ~x + +# ***** Casting Ops ***** +# TODO: saturate +def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) +def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) + +# ***** Reduce Ops ***** +def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) +def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) +def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) +def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) @@ -80,27 +112,27 @@ def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_wit return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() +def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) +def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) -def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) -def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) -def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) - -def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) -def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) -def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) -def Size(data:Tensor): return data.numel() -def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) +# ***** Movement Ops ***** def Reshape(data:Tensor, shape:list[int], allowzero:int=0): return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) +def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) -def And(x:Tensor, y:Tensor): return (x==y).where(x, False) -def Or(x:Tensor, y:Tensor): return (x==y).where(x, True) -def Not(x:Tensor): return x.logical_not() +def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) -def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) +# TODO: add test for when axes is None +def Squeeze(data:Tensor, axes:list[int]|None=None): + return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) +def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) +def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) +def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): axes = axes or list(range(data.ndim)) steps = steps or [1]*data.ndim @@ -113,83 +145,9 @@ def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0) if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] return data.split(split, axis) -# TODO: add test for when axes is None -def Squeeze(data:Tensor, axes:list[int]|None=None): - return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) -def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) - -def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() - -def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) - return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) -def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) - -def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) -def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) - -def ConstantOfShape(shape:list[int], value:Tensor|None=None): - if value is None: value = Tensor(0, dtype=dtypes.float32) - return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) - -# **************** Complex Ops **************** - -def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) - return ret - -def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) - -def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): - axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) - if reverse: X = X.flip(axis) - if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ - .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) - return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) - -# TODO: this is copied from tinygrad/nn/__init__.py -# spatial is from opset 7 and has since been removed -def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, - training_mode:int=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).rsqrt() - - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - invstd = (input_var + epsilon).rsqrt() - return X.batchnorm(scale, B, input_mean, invstd) - -def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): - axis = tuple(range(2, x.ndim)) - mean = x.mean(axis=axis, keepdim=True) - invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() - return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) - -def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): - assert stash_type == 1, "only float32 is supported" - axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) - mean = x.mean(axis=axes, keepdim=True) - return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() - -def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): - return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) - -# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) -def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])])) - -AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] -# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) -def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): - if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] - return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] - +def _onnx_pads_to_tiny_pads(pads): + # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) + return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): value = constant_value or value @@ -198,6 +156,21 @@ def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[ for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) +def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): + shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim + pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + for s, x in zip(shape, axes or range(t.ndim)): + tx = t.shape[x] + if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) + elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) + return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + +# ***** Processing Ops ***** +AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] +def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): + # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) + if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] + return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) @@ -206,25 +179,22 @@ def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad) + return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), + ceil_mode=ceil_mode, count_include_pad=count_include_pad) def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, storage_order:int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode) + ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) # tests expect indices with int64 dtype # TODO: if there are repeated values, this is wrong indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad) - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads)) + kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): + return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, + padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) -# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose -# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, strides:list[int]|int=1): @@ -247,80 +217,28 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shap if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") return ret.pad(_onnx_pads_to_tiny_pads(pads)) -def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): - return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) -def SpaceToDepth(X:Tensor, blocksize:int): - return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) +def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) +def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -# Reimplemented here because you need legacy RNG for passing ONNX tests. -def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): - if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. - mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) - return data * mask * (1/(1.0 - ratio)), mask -# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx -def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) -Dropout = {6:Dropout_6, 7:Dropout_7} +def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) + return ret -def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): - pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) - return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) +def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) -def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): - return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) +def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): + axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) + if reverse: X = X.flip(axis) + if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ + .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) + return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) -def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - return x.nll_loss(target, weight, ignore_index, reduction) - -def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - log_probs = scores.log_softmax(1) - return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs - -def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - -def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - x_sh = list(x.shape) - ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore - return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] -def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated - -def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): - if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] - x_shape, i_shape = x.shape, indices.shape - b = math.prod(x.shape[dim] for dim in range(batch_dims)) - # NOTE: each batched dim of both input and indices are equal - x = x.reshape(b, *x.shape[batch_dims:]) - indices = indices.reshape(b, *indices.shape[batch_dims:]) - b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) - ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] - return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) -def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): - assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] - x = x.contiguous() - for index, u in zip(indices.split(1, 0), updates.split(1, 0)): - i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) - u = u.squeeze(0) - if reduction == "none": x[i] = u - elif reduction == "add": x[i] += u - elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") - return x - -def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) -def GatherElements(x:Tensor, indices:Tensor, axis:int): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.gather(axis, indices) +def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, - axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, - extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): + axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, + extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): def _apply_nearest_mode(index: Tensor, input_dim, mode: str): if mode == "round_prefer_floor": index = (index - 0.5).ceil() elif mode == "round_prefer_ceil": index = (index + 0.5).floor() @@ -377,116 +295,43 @@ def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, si X = X.gather(i, low).lerp(X.gather(i, high), perc) if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented") return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X - -def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim - for s, x in zip(shape, axes or range(t.ndim)): - tx = t.shape[x] - if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) - elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) - return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) - -def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): - # Scalar or Rank 1 tensor containing exactly one element - depth = int(depth[0] if isinstance(depth, list) else depth) - indices = (indices < 0).where(indices+depth, indices) - return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) - -def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): - if axis is None: - inp = inp.flatten() - axis = 0 - if axis < 0: axis += inp.ndim - con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor - return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] - -def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) - return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) - def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated -def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): - if axis < 0: axis += x.ndim - if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) - if block_size == 0: - shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) - return scale.reshape(shape), zero_point.reshape(shape) - return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) +# ***** Neural Network Ops ***** +# TODO: try to factor out common implementations for these ops +# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0 +def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, + training_mode:int=0, spatial=1, is_test=0): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) + current_var = (y*y).mean(axis=(0,2,3)) + current_invstd = current_var.add(epsilon).rsqrt() -def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 - y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous() - -def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): - x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) - return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) - -def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor): - assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops" - y = (y / y_scale + y_zero_point).round() - return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype) - -def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1, - kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1): - x = x.int() - x_zero_point - w = w.int() - w_zero_point - y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides) - y_scale = y_scale / (x_scale * w_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point:Tensor|int) -> Tensor: - a = a.int() - a_zero_point - b = b.int() - b_zero_point - y = Tensor.matmul(a, b, acc_dtype=dtypes.int32) - y_scale = y_scale / (a_scale * b_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, - auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None, - pads: int | list[int] = 0, strides: int | list[int] = 1) -> Tensor: - x_int = x.int() - x_zero_point - w_int = w.int() - w_zero_point - return Conv(x_int, w_int, B, auto_pad, dilations, group, kernel_shape, pads, strides) - -def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: - A_int = A.int() - a_zero_point - B_int = B.int() - b_zero_point - return Tensor.matmul(A_int, B_int, acc_dtype=dtypes.int32) - -# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py -def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): - try: import PIL.Image - except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(encoded_stream)) - if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] - if pixel_format == "RGB": return Tensor(np.array(img)) - if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") - -def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): - N, _, *spatial_dims = size - def generate_grid(steps): - return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) - grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) - base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) - return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) - -# **************** com.microsoft Ops **************** + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var + invstd = (input_var + epsilon).rsqrt() + return X.batchnorm(scale, B, input_mean, invstd) +def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): + axis = tuple(range(2, x.ndim)) + mean = x.mean(axis=axis, keepdim=True) + invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() + return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) +def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): + assert stash_type == 1, "only float32 is supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() +def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) +def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): x = x + skip + bias return x.layernorm(eps=epsilon) * gamma + beta, None, None, x - -def FastGelu(x:Tensor, bias:Tensor|None=None): - # this is tanh approximated - return (x + bias).gelu() if bias is not None else x.gelu() - def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): @@ -513,6 +358,45 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embeddin out = embedding_sum.layernorm(eps=epsilon) * gamma + beta return out, None, embedding_sum +def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(depth[0] if isinstance(depth, list) else depth) + indices = (indices < 0).where(indices+depth, indices) + return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + +def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) +def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + +# Reimplemented here because you need legacy RNG for passing ONNX tests. +def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): + if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. + mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) + return data * mask * (1/(1.0 - ratio)), mask +# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx +def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) +Dropout = {6:Dropout_6, 7:Dropout_7} + +def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): + pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) + return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) + +def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + return x.nll_loss(target, weight, ignore_index, reduction) +def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + log_probs = scores.log_softmax(1) + return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs + +def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): + N, _, *spatial_dims = size + def generate_grid(steps): + return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) + def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None, relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None, mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None, @@ -552,37 +436,132 @@ def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past: out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) return out, present if past is not None else out +# ***** Indexing Ops ***** +def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] + +def Gather(x:Tensor, indices:Tensor, axis:int=0): + if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + x_sh = list(x.shape) + ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] + if indices.ndim > 1: indices = indices.flatten() + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore + return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) + # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] +def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated + +def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): + if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] + x_shape, i_shape = x.shape, indices.shape + b = math.prod(x.shape[dim] for dim in range(batch_dims)) + # NOTE: each batched dim of both input and indices are equal + x = x.reshape(b, *x.shape[batch_dims:]) + indices = indices.reshape(b, *indices.shape[batch_dims:]) + b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) + ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] + return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) +def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] + x = x.contiguous() + for index, u in zip(indices.split(1, 0), updates.split(1, 0)): + i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) + u = u.squeeze(0) + if reduction == "none": x[i] = u + elif reduction == "add": x[i] += u + elif reduction == "mul": x[i] *= u + else: raise NotImplementedError("reduction doesn't support max or min") + return x + +def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) +def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + +def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + if axis < 0: axis += inp.ndim + con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + +# ***** Quantization Ops ***** +def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) + +def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): + if axis < 0: axis += x.ndim + if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) + if block_size == 0: + shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) + return scale.reshape(shape), zero_point.reshape(shape) + return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) + +def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + +def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + +def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): + adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] + return op(*adjusted_inputs, **opts) + +def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in quantized int + out = _op_integer(op, inputs, zero_points, **opts) + assert dtypes.is_int(out.dtype), "quantized op should've done math in int" + out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in float32 + dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] + out = op(*dequantized_inputs, **opts) + assert dtypes.is_float(out.dtype), "op should've done math in float" + out_quantized = (out / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point: Tensor|int, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + +def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point:Tensor|int) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - a = a.int() - a_zero_point - b = b.int() - b_zero_point - c = (a * a_scale + b * b_scale) - return _quantize_linear(c, c_scale, c_zero_point) + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): - assert channels_last in {0, 1} - if channels_last == 1: X = X.permute(0, 2, 3, 1) - X = (X.int() - x_zero_point) * x_scale - y = GlobalAveragePool(X) - return _quantize_linear(y, y_scale, y_zero_point) + assert channels_last == 0, "unsure what this does" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) -# **************** ai.onnx.preview.training Ops **************** +def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + +def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + +# ***** Training Ops ***** # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code - -from tinygrad.nn.optim import Adam as TinyAdam -from tinygrad.nn.optim import SGD - -def onnx_training(input_group_size): - def _decorator(func): - def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): +def _onnx_training(input_group_size): + def __decorator(func): + def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): R = R.detach() groups = len(inputs) // input_group_size ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] return tuple(flatten(zip(*ret))) - return __wrapper - return _decorator + return ___wrapper + return __decorator -@onnx_training(3) +@_onnx_training(3) def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): X, G, H = (i.detach() for i in inputs) grad = norm_coefficient * X + G @@ -592,9 +571,10 @@ def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:flo X.assign(X.detach() - r * up) return [X, H] -@onnx_training(4) +@_onnx_training(4) def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, - norm_coefficient_post:float=0.0): + norm_coefficient_post:float=0.0): + from tinygrad.nn.optim import Adam as TinyAdam X, G, V, H = inputs G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches X.grad = norm_coefficient * X.detach() + G @@ -610,8 +590,9 @@ def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, eps X = (1 - norm_coefficient_post) * X return [X, V, H] -@onnx_training(3) +@_onnx_training(3) def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): + from tinygrad.nn.optim import SGD X, G, V = inputs G, V = G.detach(), V.detach() X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1) @@ -620,6 +601,6 @@ def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, opt.step() return [X, V] -def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **__): +def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): intermediate_tensors[y].backward() return tuple([t.grad for t in inputs]) From 3e987fc85674ad01390b3643f4579146847f67b5 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Thu, 23 Jan 2025 12:46:27 -0800 Subject: [PATCH 15/44] add device print with -m tinygrad.device [pr] (#8729) * add device print with -m tinygrad.device [pr] * fix linter --- tinygrad/device.py | 22 +++++++++++++++++++--- tinygrad/helpers.py | 6 +++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/tinygrad/device.py b/tinygrad/device.py index 2a20992f80..7fea67aa00 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -5,12 +5,13 @@ from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ - cpu_time_execution + cpu_time_execution, colored, Context from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes from tinygrad.renderer import Renderer # **************** Device **************** +ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"] class _Device: def __init__(self) -> None: self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] @@ -25,7 +26,7 @@ class _Device: cpn = multiprocessing.current_process().name assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}" x = ix.split(":")[0].upper() - ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \ + ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \ if (cname.lower() == x.lower() + "device")][0](ix) if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}") self._opened_devices.add(ix) @@ -33,7 +34,7 @@ class _Device: @property def default(self) -> Compiled: return self[self.DEFAULT] def get_available_devices(self) -> Iterator[str]: - for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]: + for device in ALL_DEVICES: with contextlib.suppress(Exception): yield self[device].device @functools.cached_property def DEFAULT(self) -> str: @@ -314,3 +315,18 @@ if PROFILE: from tinygrad.ops import launch_viz launch_viz("PROFILE", fn) + +if __name__ == "__main__": + for device in ALL_DEVICES: + try: + _ = Device[device].device + try: + from tinygrad import Tensor + with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() + if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") + result = colored("PASS", "green") + except Exception as e: + result = f"{colored('FAIL', 'yellow')} {e}" + except Exception as e: + result = f"{colored('FAIL', 'red')} {e}" + print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}") \ No newline at end of file diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 090b9178cc..d1eedb5eb5 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -109,6 +109,7 @@ USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) +CACHELEVEL = ContextVar("CACHELEVEL", 2) @dataclass(frozen=True) class Metadata: @@ -165,7 +166,6 @@ class Profiling(contextlib.ContextDecorator): cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad") CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db"))) -CACHELEVEL = getenv("CACHELEVEL", 2) VERSION = 17 _db_connection = None @@ -186,7 +186,7 @@ def diskcache_clear(): cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"])) def diskcache_get(table:str, key:Union[dict, str, int]) -> Any: - if CACHELEVEL == 0: return None + if CACHELEVEL < 1: return None if isinstance(key, (str,int)): key = {"key": key} conn = db_connection() cur = conn.cursor() @@ -199,7 +199,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any: _db_tables = set() def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False): - if CACHELEVEL == 0: return val + if CACHELEVEL < 1: return val if isinstance(key, (str,int)): key = {"key": key} conn = db_connection() cur = conn.cursor() From eb77488f85652e83dc2181294fbb338b1ea2e398 Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 23 Jan 2025 19:06:05 -0500 Subject: [PATCH 16/44] update llama3 70B to use R1 (#8733) --- examples/llama3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/llama3.py b/examples/llama3.py index e331573d1a..c24ec6ea2a 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -247,11 +247,11 @@ if __name__ == "__main__": fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr") args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr") elif args.size == "70B": - subdir = "Llama-3.1-Nemotron-70B-Instruct-HF" - args.model = fetch("https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir) + subdir = "DeepSeek-R1-Distill-Llama-70B" + args.model = fetch("https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir) fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir) - for i in range(30): - fetch(f"https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model-{i+1:05d}-of-00030.safetensors?download=true", f"model-{i+1:05d}-of-00030.safetensors", subdir=subdir) + for i in range(17): + fetch(f"https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model-{i+1:05d}-of-000017.safetensors?download=true", f"model-{i+1:05d}-of-000017.safetensors", subdir=subdir) assert args.model is not None, "please provide --model option" From e82ba1454b0ddbe9e994c7416dc223bd5334b99a Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 24 Jan 2025 13:28:55 +0900 Subject: [PATCH 17/44] MultiLazyBuffer is UOp [pr] (#8662) * MultiLazyBuffer is UOp [pr] * this is new mlb * this is the idea * progress * multitensor works * more movement ops * this * MultiLazyBuffer is UOp * cleanups * multi axis * fix more tests * work * not that * add multi grad and move shard to ops * mops not views * no double contig * sweet, all mt tests passing * port old logic * remove lbs * fix realized * whitespace * assign tweak * test_assign_kv_cache_multi passes * fix is_realized * fix JIT for multi * just a few more lines i'll pay them back soon i swear please bro just a few more * no split reduceop for multi --- .github/workflows/test.yml | 4 +- examples/hlb_cifar10.py | 3 - test/test_multitensor.py | 81 +++++++++---- tinygrad/engine/jit.py | 5 +- tinygrad/engine/schedule.py | 3 +- tinygrad/gradient.py | 2 +- tinygrad/multi.py | 228 ++++++++++++++++++------------------ tinygrad/nn/state.py | 7 +- tinygrad/ops.py | 67 +++++++++-- tinygrad/tensor.py | 85 ++++++-------- tinygrad/viz/serve.py | 2 +- 11 files changed, 277 insertions(+), 210 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be82930b62..c95325349e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -243,8 +243,8 @@ jobs: run: | PYTHONPATH="." python test/external/fuzz_shapetracker.py PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - - name: Repo line count < 11000 lines - run: MAX_LINE_COUNT=11000 python sz.py + - name: Repo line count < 11100 lines + run: MAX_LINE_COUNT=11100 python sz.py testopencl: strategy: diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 78c59bdb18..d2008ef87d 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit from tinygrad.nn.state import get_state_dict, get_parameters from tinygrad.nn import optim from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod -from tinygrad.multi import MultiLazyBuffer cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618] cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628] @@ -35,8 +34,6 @@ class UnsyncedBatchNorm: self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False) def __call__(self, x:Tensor): - if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices - xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32) batch_mean, batch_invstd = self.calc_stats(xr) ret = xr.batchnorm( diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b34baced75..25f863568d 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,11 +1,11 @@ import unittest, functools, random from typing import List -from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import Ops +from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable +from tinygrad.ops import Ops, UOp from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule -from tinygrad.multi import all_reduce, MultiLazyBuffer +from tinygrad.multi import all_reduce import numpy as np from hypothesis import given, strategies as strat, settings from tinygrad.device import is_dtype_supported @@ -30,7 +30,7 @@ N = 128 def _test_allreduce(t:Tensor): aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize() ts = t.shard(devices_4, 0).realize() - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0)) b.realize() return aa, b @@ -39,7 +39,7 @@ class TestMultiTensor(unittest.TestCase): def test_to(self): X = Tensor.ones(256).contiguous().realize() X.to_(devices_2) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (256,) (X + X).realize() @@ -52,7 +52,7 @@ class TestMultiTensor(unittest.TestCase): def test_shard(self): X = Tensor.ones(256).contiguous().realize() X.shard_(devices_2, 0) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (128,) (X + X).realize() @@ -218,9 +218,9 @@ class TestMultiTensor(unittest.TestCase): shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))]) t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0) with Context(RING=0): - a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) with Context(RING=2): - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) diff = a - b mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy() max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy() @@ -356,8 +356,8 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(m): p.shard_(devices_2).realize() GlobalCounters.reset() shard_output = m(fake_image_sharded).log_softmax().realize() - assert shard_output.lazydata.lbs[0].shape == (1, 1000) - assert shard_output.lazydata.lbs[1].shape == (1, 1000) + assert shard_output.lazydata.src[0].shape == (1, 1000) + assert shard_output.lazydata.src[1].shape == (1, 1000) shard_output_np = shard_output.numpy() np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6) @@ -386,12 +386,35 @@ class TestMultiTensor(unittest.TestCase): GlobalCounters.reset() optimizer.zero_grad() shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1) - assert shard_output.lazydata.axis is None shard_output.backward() shard_grad = m.conv1.weight.grad.numpy() # sometimes there is zeros in these grads... why? np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5) + def test_assign_kv_cache_multi(self): + bsz, max_context = 2, 8 + + class Attn: + @TinyJit + def __call__(self, xk:Tensor, start_pos:UOp): + seqlen = xk.shape[1] + if not hasattr(self, "cache_k"): + self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize() + keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk + self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize() + + attn = Attn() + xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous() + attn(xk, 0) + for i in range(3,6): + # copied from LLaMA + start_pos = Variable("start_pos", 1, max_context).bind(i) + xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous() + attn(xk, start_pos) + + out = attn.cache_k.flatten().numpy() + np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.]) + def test_multi_tensor_jit_param(self): @TinyJit def jf(a, b) -> Tensor: @@ -532,13 +555,13 @@ class TestMultiTensor(unittest.TestCase): t4 = t2.reshape((26, 105,)) for t in [t0, t1, t2, t3, t4]: - assert t.lazydata.axis == 1 np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) + assert t.lazydata.axis == 1 # test shape-one axis t5 = t4.reshape((26, 1, 105)) - assert t5.lazydata.axis == 2 np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) + assert t5.lazydata.axis == 2 # test split and rejoin to the right and reshape to the left t5 = t0.reshape((2, 13, 3, 5, 7)) @@ -553,7 +576,7 @@ class TestMultiTensor(unittest.TestCase): # test no left join with self.assertRaises((AssertionError, ValueError)): - t0.reshape((26*15,7)) + t0.reshape((26*15,7)).schedule() @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): @@ -588,6 +611,7 @@ class TestMultiTensor(unittest.TestCase): with self.assertRaises(AssertionError): # don't allow assigns that change axes t_none.assign(t_zero) + t_none.schedule() def test_init_rand_with_multiple_devices_fail(self): # init rand with multi device is not allowed @@ -635,7 +659,7 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) - assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs)) + assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src)) def test_rand_like_none_shard(self): t = Tensor.empty((16, 16)).shard(devices_2) @@ -718,7 +742,7 @@ class TestMultiTensor(unittest.TestCase): devices = (d0, d1, d2, d3) t = Tensor.zeros(16, 16).contiguous() t.shard_(devices, axis=0).realize() - assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs]) + assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src]) @unittest.skip("this is unreliable on OSX") def test_clone(self): @@ -774,25 +798,31 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): with self.assertRaises(AssertionError): # sharded axis shrink on non-device boundry is not allowed a = t.shrink(((0, 3), (0, 8))) + a.schedule() with self.assertRaises(AssertionError): # cannot shrink sharded and non-sharded axis at the same time a = t.shrink(((0, 2), (2, 4))) + a.schedule() a = t.shrink(((0, 2), (0, 8))) + a.schedule() assert a.shape == (2, 8) - assert a.lazydata.real == [True, False, False, False] + assert a.lazydata.real == (True, False, False, False) with self.assertRaises(AssertionError): # cannot pad sharded and non-sharded axis at the same time p = a.pad(((0, 6), (0, 1))) + p.schedule() with self.assertRaises(AssertionError): # can only pad to whole axis p = a.pad(((1, 5), (0, 0))) + p.schedule() p = a.pad(((0, 6), (0, 0))) + p.schedule() assert p.shape == (8, 8) - assert p.lazydata.real == [True, True, True, True] + assert p.lazydata.real == (True, True, True, True) @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16])) def test_ops(self, dtype): @@ -804,8 +834,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((0+2*i,2+2*i),None)) b = Tensor(t.numpy()[0+2*i:2+2*i]) assert a.shape == b.shape == (2, 8) - assert a.lazydata.real == [i==j for j in range(4)] np.testing.assert_allclose(a.numpy(), b.numpy()) + assert a.lazydata.real == tuple(i==j for j in range(4)) # cast np.testing.assert_allclose(a.float().numpy(), b.float().numpy()) @@ -865,18 +895,20 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((2, 4), None)) b = t.shrink(((6, 8), None)) - self.assertEqual(a.lazydata.real, [False, True, False, False]) - self.assertEqual(b.lazydata.real, [False, False, False, True]) na = t.numpy()[2:4] nb = t.numpy()[6:8] np.testing.assert_equal(a.numpy(), na) np.testing.assert_equal(b.numpy(), nb) + self.assertEqual(a.lazydata.real, (False, True, False, False)) + self.assertEqual(b.lazydata.real, (False, False, False, True)) with self.assertRaises(AssertionError): # cannot add directly c = a + b + c.schedule() c = a.pad(((2, 4), None)) + b.pad(((6, 0), None)) - self.assertEqual(c.lazydata.real, [True, True, True, True]) + c.realize() + self.assertEqual(c.lazydata.real, (True, True, True, True)) expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb]) np.testing.assert_equal(c.numpy(), expected) @@ -937,8 +969,9 @@ class TestBatchNorm(unittest.TestCase): def __call__(self, x:Tensor): bn_ts = [] - for bound, bn in zip(x.lazydata.bounds, self.bns): - xi = x.shrink((bound, None, None, None)) + each = x.shape[0]//len(self.bns) + for i, bn in enumerate(self.bns): + xi = x.shrink(((each*(i), each*(i+1)), None, None, None)) bni = bn(xi) bn_ts.append(bni) return bn_ts[0].cat(*bn_ts[1:]) diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index f718ef311e..afb9d726a4 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap from tinygrad.device import Buffer, Compiled, Device from tinygrad.dtype import DType -from tinygrad.ops import UOp, Variable, sym_infer +from tinygrad.ops import UOp, Variable, sym_infer, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates from tinygrad.engine.memory import _internal_memory_planner @@ -194,7 +194,8 @@ def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] if tensors: Tensor.realize(*tensors) - lbs: list[UOp] = flatten([t.lazydata.lbs for t in tensors]) + # TODO: should we be unpacking multi here? + lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors]) input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None] assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT" st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d8236d47bf..7be4360ac1 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -104,6 +104,7 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}") dtype = buf.dtype.base # ASSIGN already has a target buffer, otherwise we create a new one + assert isinstance(buf.device, str), f"buf device is str, not {buf.device}" buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # track the underlying tensor uop for this buffer @@ -418,7 +419,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) return x.view(unwrap(view.st)) def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if not b.device.startswith("DISK"): return None + if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 756a9c9785..f64f443858 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([ # TODO: this cast can be removed by putting the casts around the EXPAND (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)), - + (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src), # there's no gradient for bitcast (UPat(Ops.BITCAST), lambda ctx: (None,)), ]) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index e6e165bb25..8a7b9d04a0 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -1,8 +1,7 @@ from __future__ import annotations import functools, itertools, operator from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv -from tinygrad.dtype import DType -from tinygrad.ops import Ops, MathTrait, UOp, sint +from tinygrad.ops import Ops, UOp, sint def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]: assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}" @@ -40,133 +39,130 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}") return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))] -class MultiLazyBuffer(MathTrait): - def __init__(self, lbs:list[UOp], axis:int|None, real:list[bool]|None=None): - assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them" - assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}" - self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs) +# ***** multi functions ***** - @property - def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)) +from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites - @property - def size(self): return sum(x.size for x in self.real_lbs) +def alu_multi(root:UOp): + msrcs = root.src + assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}" + assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" - @property - def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r] + # NOTE: they all have to share an axis, we always choose [-1] + axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) + srcs:list[list[UOp]] = [] + not_all_real = not all(all(mlb.real) for mlb in msrcs) + new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else msrcs[0].real + assert any(new_real), "output contains no real lb" + for mlb in msrcs: + if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src)) + else: + assert axis is not None and bounds is not None + if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds)) + else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds)) + new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(root.op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} + # NOTE: const dtype should match real + new_dtype = next(iter(new_real_lbs.values())).dtype + new_lbs = [new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))] + return UOp.multi(*new_lbs, axis=axis, real=new_real) - @property - def bounds(self): - if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") - return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0))) +def reduce_multi(root:UOp, multi:UOp): + op, axis = root.arg + if multi.axis is not None and multi.axis in axis: + # all-reduce on sharded axes + reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)] + # if all partitions are real, do all_reduce + if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None) + # only one partition is real, keep it + return UOp.multi(*reduced_parts, axis=None, real=multi.real) + # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct + return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real) - def __repr__(self): return f"" +def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: + return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) - def copy_to_device(self, device:str) -> UOp: - # if we already have a copy on the device, return that - if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device)) - # copy lbs to device, pad to final shape, and sum - llbs:list[UOp] = [] - for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds): - if not real: continue - pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape))) - llbs.append(lb.copy_to_device(device).pad(pad_arg)) - return functools.reduce(operator.add, llbs) +def reshape_multi(root:UOp, multi:UOp): + arg = root.arg + if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real) + assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" + arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) + # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards + # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? + new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1 + assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \ + f"reshape cannot move items between shards {multi.shape} -> {root.arg=}" + lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src] + return UOp.multi(*lbs, axis=new_axis, real=multi.real) - # passthroughs - @property - def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs) - def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real) - def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real) - def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real) - def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real) - def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real) - def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real) - def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real) - @property - def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort} +def expand_multi(root:UOp, multi:UOp): + # NOTE: this assert isn't needed, sharded axis can have dim 1 + assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}" + return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real) - # elementwise is simple - def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer: - msrcs = (self,)+in_srcs - assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}" - assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" +def pad_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}" + # pad on shard axis -> fill others with zeros and set real to all True + if multi.axis is not None and root.arg[multi.axis] != (0,0): + # pad back to whole axis, remove real mask + assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time" + dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)] + assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" + return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis) + return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - # NOTE: they all have to share an axis, we always choose [-1] - axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) - srcs:list[list[UOp]] = [] - not_all_real = not all(all(mlb.real) for mlb in msrcs) - new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real - assert any(new_real), "output contains no real lb" - for mlb in msrcs: - if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs) - else: - assert axis is not None and bounds is not None - if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds)) - else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds)) - new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} - # NOTE: const dtype should match real - new_dtype = next(iter(new_real_lbs.values())).dtype - return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real) +def permute_multi(root:UOp, multi:UOp): + # all permutes supported! + return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real) - def r(self, op:Ops, axis:tuple[int, ...]) -> MultiLazyBuffer: - if self.axis is not None and self.axis in axis: - # all-reduce on sharded axes - reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)] - # if all partitions are real, do all_reduce - if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) - # only one partition is real, keep it - return MultiLazyBuffer(reduced_parts, None, self.real) - # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real) +def shrink_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \ + f"shrinking not supported for {root.arg=}" + if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]): + assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \ + "cannot shrink sharded and non-sharded axis at the same time" + # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real + idx = multi.bounds.index(root.arg[multi.axis]) + # zero out other lbs to not create lb reference + return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)], + axis=multi.axis, real=[i==idx for i in range(len(multi.src))]) + return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], + axis=multi.axis, real=multi.real) - # *** movement ops *** +def stride_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis" + return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - def _shape_to_single_shard(self, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: - return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) +def copy_multi(multi:UOp, device:UOp): + # if we already have a copy on the device, return that + if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg)) + # copy lbs to device, pad to final shape, and sum + llbs:list[UOp] = [] + for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds): + if not real: continue + pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape))) + llbs.append(lb.copy_to_device(device.arg).pad(pad_arg)) + return functools.reduce(operator.add, llbs) - def reshape(self, arg:tuple[sint, ...]): - if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real) - assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)" - arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) - # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards - # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? - new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1 - assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}" - lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs] - return MultiLazyBuffer(lbs, new_axis, self.real) +def assign_multi(dest:UOp, src:UOp): + assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}" + return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real) - def pad(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}" - # pad on shard axis -> fill others with zeros and set real to all True - if self.axis is not None and arg[self.axis] != (0,0): - # pad back to whole axis, remove real mask - assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time" - dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)] - assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" - return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis) - return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real) +def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real) - def expand(self, arg:tuple[sint, ...]): - # NOTE: this assert isn't needed, sharded axis can have dim 1 - assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}" - return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real) +# NOTE: this is the same pattern as Ops.UNROLL +multi_pm = PatternMatcher([ + (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi), + (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi), + (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi), + (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi), + (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi), + (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi), + (UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi), + (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi), + (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi), + (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), +]) - def permute(self, arg:tuple[int, ...]): - # all permutes supported! - return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real) - - def shrink(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}" - if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]): - assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time" - # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real - idx = self.bounds.index(arg[self.axis]) - # zero out other lbs to not create lb reference - return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))]) - return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs], - self.axis, self.real) - - def stride(self, arg:tuple[int, ...]): - assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis" - return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real) +@track_rewrites(named=True) +def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v} diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 9400dd2e82..99032cfe63 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T from tinygrad.shape.view import strides_for_shape -from tinygrad.multi import MultiLazyBuffer class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): @@ -152,9 +151,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr continue if v.shape != state_dict[k].shape: raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.') - if isinstance((mlb:=v.lazydata), MultiLazyBuffer): - if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize() - else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize() + if isinstance(v.device, tuple): + if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize() + else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize() else: v.replace(state_dict[k].to(v.device)).realize() if consume: del state_dict[k] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0225603cc4..cf40c1fbed 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -150,6 +150,7 @@ class Ops(FastEnum): # device DEVICE = auto() + MULTI = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} @@ -281,6 +282,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def st(self) -> ShapeTracker|None: + from tinygrad.shape.shapetracker import ShapeTracker + if self.op is Ops.MULTI: + return ShapeTracker.from_shape( + tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))) # these ops define a ShapeTracker from the arg if self.op is Ops.VIEW: return self.arg if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg) @@ -294,7 +299,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # only reduce ops are allowed to change shape, everything else derives shape from sources elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg) else: shape = src_sts[0].shape - from tinygrad.shape.shapetracker import ShapeTracker return ShapeTracker.from_shape(shape) @functools.cached_property @@ -350,7 +354,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source - return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) + if self._device is None: return UOp.const(self.dtype, b) + if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None) + return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -389,7 +395,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): new_shape = unwrap(self.st).reduce(axis) # TODO: can we split symbolic shape if the reduce axis is not symbolic? - if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \ + # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi + if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, axis) @@ -410,6 +417,45 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def contiguous(self): return self.alu(Ops.CONTIGUOUS) + # *** from MultiLazyBuffer *** + + def multi(self, *more:UOp, axis:int|None, real:list[bool]|None=None): + parents = (self,)+more + assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype" + return UOp(Ops.MULTI, self.dtype, parents, (axis, tuple(real if real is not None else [True]*len(parents)))) + + @property + def bounds(self): + if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") + return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0))) + + @property + def axis(self): + assert self.op is Ops.MULTI + return self.arg[0] + + @property + def real(self): + assert self.op is Ops.MULTI + return self.arg[1] + + @property + def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r] + + def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp: + if axis is None: lbs = [self] * len(devices) + else: + if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") + sz = self.shape[axis] // len(devices) + sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] + lbs, off = [], 0 + for sz in sizes: + lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape)))) + off += sz + sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] + # NOTE: this contiguous is making it impossible for the scheduler to do late const folding + return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis) + # *** from LazyBuffer *** @staticmethod @@ -426,7 +472,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val) # otherwise it's just a VIEW(BUFFER) return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st) - def copy_to_device(self, device:str, clone:bool=False) -> UOp: + def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp: # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) @@ -440,8 +486,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property - def lbs(self): return [self] - @property def metadata(self): return all_metadata.get(self, None) # *** uop movement ops *** @@ -470,10 +514,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @staticmethod def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) @property - def device(self) -> str: return unwrap(self._device) + def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device)) @functools.cached_property - def _device(self) -> Optional[str]: + def _device(self) -> Optional[str|tuple[str, ...]]: if self.op is Ops.DEVICE: return self.arg + if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src) return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: @@ -489,6 +534,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" if (cret:=buffers.get(self)) is not None: return cret from tinygrad.device import Buffer + assert isinstance(self.device, str), f"buffer not supported on multi {self.device}" buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property @@ -496,7 +542,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized return self.buffer if self.op is Ops.BUFFER else None @property - def is_realized(self) -> bool: return self.base.realized is not None + def is_realized(self) -> bool: + return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None # *** uop Variable stuff *** @@ -639,7 +686,7 @@ def print_uops(uops:list[UOp]): def get_location() -> tuple[str, int]: frm = sys._getframe(1) # find the real frame in the file that has the UPat, TODO: is there a better way to do this? - while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", + while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py", "lowerer.py", "cstyle.py", "linearize.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 45a9b23749..2c47162c26 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -6,10 +6,10 @@ from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Seque from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap -from tinygrad.multi import MultiLazyBuffer +from tinygrad.multi import get_multi_map from tinygrad.gradient import compute_gradient from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element -from tinygrad.device import Device, Buffer, BufferSpec +from tinygrad.device import Device, BufferSpec from tinygrad.engine.realize import run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars @@ -30,18 +30,17 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: # link the found UOps back to Tensors. exit early if there's no Tensors to realize # NOTE: this uses all_tensors, but it's fast - fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)] + fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops] if len(fixed_tensors): # potentially rewrite all the discovered Tensors - sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors]) + sink = UOp.sink(*[t.lazydata for t in fixed_tensors]) new_sink = sink.substitute(applied_map) # set the relevant lazydata to the realized UOps for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): if s is ns: continue - if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src) - else: t.lazydata = ns + t.lazydata = ns # **** start with two base classes, Tensor and Function **** @@ -68,7 +67,7 @@ import tinygrad.function as F def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None): if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) - return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None) + return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None) def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 import numpy as np @@ -159,7 +158,7 @@ class Tensor(SimpleMathTrait): return instance def __del__(self): all_tensors.discard(weakref.ref(self)) - def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 + def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None): if dtype is not None: dtype = to_dtype(dtype) if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None @@ -176,7 +175,7 @@ class Tensor(SimpleMathTrait): self._ctx: Optional[Function] = None # create a LazyBuffer from the different types of inputs - if isinstance(data, (UOp, MultiLazyBuffer)): + if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" # NOTE: this is here because LazyBuffer = UOp if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data) @@ -199,12 +198,12 @@ class Tensor(SimpleMathTrait): data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a LazyBuffer - if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") + if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") # data might be on a different device - if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device) + if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device) # if device is a tuple, we should have/construct a MultiLazyBuffer - elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata + elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata else: assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" self.lazydata = data @@ -224,8 +223,8 @@ class Tensor(SimpleMathTrait): def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev def __repr__(self): - if isinstance(ld:=self.lazydata, MultiLazyBuffer): ld_repr = f"{ld!r}" - else: ld_repr = f"" + ld = self.lazydata + ld_repr = f"" return f"" # Python has a non moving GC, so this should be okay @@ -254,7 +253,14 @@ class Tensor(SimpleMathTrait): NOTE: A Tensor can only be scheduled once. """ - big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst])) + big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) + + # TODO: move this to scheduler tensor_map pass + if any(x.op is Ops.MULTI for x in big_sink.toposort): + # multi fixup + _apply_map_to_tensors(get_multi_map(big_sink)) + big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst])) + schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) _apply_map_to_tensors(becomes_map) return memory_planner(schedule), var_vals @@ -293,7 +299,6 @@ class Tensor(SimpleMathTrait): assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" - assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer" assert not x.requires_grad # self requires_grad is okay? if not self.lazydata.is_realized: return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) @@ -309,7 +314,8 @@ class Tensor(SimpleMathTrait): if 0 in self.shape: return memoryview(bytearray(0)) # NOTE: this realizes on the object from as_buffer being a Python object cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize() - buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized) + buf = cast(UOp, cpu.lazydata).base.realized + assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized" if self.device != "CLANG": buf.options = BufferSpec(nolru=True) return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False) @@ -405,18 +411,9 @@ class Tensor(SimpleMathTrait): print(t.shard((t.device, t.device), axis=1).lazydata) ``` """ - assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer" + assert isinstance(self.device, str), "can't shard a MultiLazyBuffer" devices = tuple(Device.canonicalize(x) for x in devices) - if axis is None: lbs = [self.lazydata] * len(devices) - else: - axis = self._resolve_dim(axis) - if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") - sz = self.shape[axis] // len(devices) - sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] - lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)] - sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] - # NOTE: this contiguous is making it impossible for the scheduler to do late const folding - mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis) + mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None): @@ -439,7 +436,7 @@ class Tensor(SimpleMathTrait): def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs): dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): - return Tensor(MultiLazyBuffer([UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None), + return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None), device, dtype, **kwargs) return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @@ -750,12 +747,12 @@ class Tensor(SimpleMathTrait): ``` """ dtype = kwargs.pop("dtype", self.dtype) - if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer): + if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) contiguous = kwargs.pop("contiguous", True) - rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs] - return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) + rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src] + return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops ***** @@ -921,18 +918,15 @@ class Tensor(SimpleMathTrait): assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) rets = [] - for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)): - target_uops = [x.lazydata.lbs[i] for x in targets] - grads = compute_gradient(uop, grad, set(target_uops)) - ret = [] - for x in target_uops: - if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}") - ret.append(y) - rets.append(ret) + target_uops = [x.lazydata for x in targets] + grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) + ret = [] + for x in target_uops: + if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") + ret.append(y) + rets.append(ret) # create returned Tensors - if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] - return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real), - device=t.device) for t,u in zip(targets, zip(*rets))] + return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] def _deepwalk(self) -> list[Tensor]: def _walk(node:Tensor, visited:set[Tensor]): @@ -977,8 +971,7 @@ class Tensor(SimpleMathTrait): for t, g in zip(ctx.parents, grads): if g is not None and t.requires_grad: assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" - assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \ - f"grad uop must have a path from self\ngrad uop: {t.lazydata}" + assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}" t.grad = g if t.grad is None else (t.grad + g) if not retain_graph: del t0._ctx return self @@ -1278,7 +1271,7 @@ class Tensor(SimpleMathTrait): self._getitem(indices).assign(v) return # NOTE: check that setitem target is valid first - if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous") + if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index c95fb5f800..fdceca9c0b 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -13,7 +13,7 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"} From 07069b99886616aa53e9999b52e80e0ec5604402 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 24 Jan 2025 06:42:25 -0500 Subject: [PATCH 18/44] rename to tensor_uop [pr] (#8737) --- tinygrad/engine/schedule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 7be4360ac1..9e4ff9d0d5 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -329,7 +329,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # maybe fuse arange with its children for rbuf in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue + if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del ctx.realizes[tr] @@ -512,7 +512,7 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: - for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) + for tensor_uop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed for k,v in tensor_map.items(): From 0814a79cb45af182bf6e5ab877e8c5ed85ae45fd Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 24 Jan 2025 09:49:54 -0500 Subject: [PATCH 19/44] cleanup the merge_views upats [pr] (#8738) --- tinygrad/engine/schedule.py | 4 ++-- tinygrad/ops.py | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 9e4ff9d0d5..7f6b141345 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -153,9 +153,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) -# push VIEW to stores +# push VIEW to children view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view() + # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index cf40c1fbed..c028c82745 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1316,18 +1316,19 @@ Variable = UOp ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]] -# *** uop swizzling *** +# *** UOp merge views and swizzling *** merge_views = PatternMatcher([ - (UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)), - (UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None), + # VIEW(VIEW) merges to a single VIEW + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), ]) -# push VIEW to loads +# push VIEW to parents view_left = merge_views+PatternMatcher([ - # VIEW before elementwise ops - (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), - lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), - # early merge VIEW buffer ops - (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), + # VIEW before elementwise/buffer ops + (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), + lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), + (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)), + lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ]) From 7a2223a6c692b4497e668288a92420ee1c6334ad Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:45:11 +0900 Subject: [PATCH 20/44] add merge views to ops_folding [pr] (#8051) Co-authored-by: qazal --- tinygrad/engine/schedule.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 7f6b141345..58ca954abf 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -478,14 +478,12 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, # **** movement ops -remove_movement_ops = PatternMatcher([ +remove_movement_ops = merge_views+PatternMatcher([ # NOTE: movement ops are always applied to base (UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))), # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) (UPat(Ops.VIEW, name="view"), lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), - # merge one src views. - (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)), # merge unmasked const views (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), From dc10187fc0dbdafcce176bc58fdba43be56d8263 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Fri, 24 Jan 2025 20:16:19 +0300 Subject: [PATCH 21/44] am: add am_smi (#8739) * am: start monitor * cleanups * fixes * hmm * progress * cleanup --- extra/amdpci/am_smi.py | 167 +++++++++++++++++++++++++++ test/external/external_test_am.py | 2 +- tinygrad/runtime/ops_amd.py | 6 +- tinygrad/runtime/support/am/amdev.py | 10 +- tinygrad/runtime/support/am/ip.py | 21 +++- 5 files changed, 194 insertions(+), 12 deletions(-) create mode 100644 extra/amdpci/am_smi.py diff --git a/extra/amdpci/am_smi.py b/extra/amdpci/am_smi.py new file mode 100644 index 0000000000..9c82db3ca9 --- /dev/null +++ b/extra/amdpci/am_smi.py @@ -0,0 +1,167 @@ +import time, mmap, sys, shutil, os, glob +from tinygrad.helpers import to_mv, DEBUG, colored, ansilen +from tinygrad.runtime.autogen import libc +from tinygrad.runtime.autogen.am import smu_v13_0_0 +from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager +from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA + +AM_VERSION = 0xA0000002 + +def bold(s): return f"\033[1m{s}\033[0m" + +def color_temp(temp): + if temp >= 87: return colored(f"{temp:>4}", "red") + elif temp >= 80: return colored(f"{temp:>4}", "yellow") + return colored(f"{temp:>4}", "white") + +def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan") + +def draw_bar(percentage, width=40, fill='â–ˆ', empty='â–‘'): + filled_width = int(width * percentage) + bar = fill * filled_width + empty * (width - filled_width) + return f'[{bar}] {percentage*100:.1f}%' + +def same_line(strs:list[list[str]], split=8) -> list[str]: + ret = [] + max_width_in_block = [max(ansilen(line) for line in block) for block in strs] + max_height = max(len(block) for block in strs) + for i in range(max_height): + line = [] + for bid, block in enumerate(strs): + if i < len(block): line.append(block[i] + ' ' * (split + max_width_in_block[bid] - ansilen(block[i]))) + else: line.append(' ' * (split + max_width_in_block[bid])) + ret.append(' '.join(line)) + return ret + +def get_bar0_size(pcibus): + resource_file = f"/sys/bus/pci/devices/{pcibus}/resource" + if not os.path.exists(resource_file): raise FileNotFoundError(f"Resource file not found: {resource_file}") + + with open(resource_file, "r") as f: lines = f.readlines() + bar0_info = lines[0].split() + if len(bar0_info) < 3: raise ValueError("Unexpected resource file format for BAR0.") + + start_hex, end_hex, _flags = bar0_info + return int(end_hex, 16) - int(start_hex, 16) + 1 + +class AMSMI(AMDev): + def __init__(self, pcibus, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): + self.pcibus = pcibus + self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar + + self._run_discovery() + self._build_regs() + + if self.reg("regSCRATCH_REG7").read() != AM_VERSION: + raise Exception(f"Unsupported AM version: {self.reg('regSCRATCH_REG7').read():x}") + + self.is_booting, self.smi_dev = True, True + self.partial_boot = True # do not init anything + self.mm = AMMemoryManager(self, self.vram_size) + + # Initialize IP blocks + self.soc21:AM_SOC21 = AM_SOC21(self) + self.gmc:AM_GMC = AM_GMC(self) + self.ih:AM_IH = AM_IH(self) + self.psp:AM_PSP = AM_PSP(self) + self.smu:AM_SMU = AM_SMU(self) + +class SMICtx: + def __init__(self): + self.devs = [] + self.opened_pcidevs = [] + self.opened_pci_resources = {} + self.prev_lines_cnt = 0 + + def _open_am_device(self, pcibus): + if pcibus not in self.opened_pci_resources: + bar_fds = {bar: os.open(f"/sys/bus/pci/devices/{pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]} + bar_size = {0: get_bar0_size(pcibus), 2: os.fstat(bar_fds[2]).st_size, 5: os.fstat(bar_fds[5]).st_size} + + def map_pci_range(bar): + return to_mv(libc.mmap(0, bar_size[bar], mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, bar_fds[bar], 0), bar_size[bar]) + self.opened_pci_resources[pcibus] = (map_pci_range(0), None, map_pci_range(5).cast('I')) + + try: + self.devs.append(AMSMI(pcibus, *self.opened_pci_resources[pcibus])) + except Exception as e: + if DEBUG >= 2: print(f"Failed to open AM device {pcibus}: {e}") + return + + self.opened_pcidevs.append(pcibus) + if DEBUG >= 2: print(f"Opened AM device {pcibus}") + + def rescan_devs(self): + pattern = os.path.join('/tmp', 'am_*.lock') + for d in [f[8:-5] for f in glob.glob(pattern)]: + if d not in self.opened_pcidevs: + self._open_am_device(d) + + for d in self.devs: + if d.reg("regSCRATCH_REG7").read() != AM_VERSION: + self.devs.remove(d) + self.opened_pcidevs.remove(d.pcibus) + os.system('clear') + if DEBUG >= 2: print(f"Removed AM device {d.pcibus}") + + def collect(self): return {d: d.smu.read_metrics() for d in self.devs} + + def draw(self): + terminal_width, _ = shutil.get_terminal_size() + + dev_metrics = self.collect() + dev_content = [] + for dev, metrics in dev_metrics.items(): + device_line = [f"PCIe device: {bold(dev.pcibus)}"] + [""] + activity_line = [f"GFX Activity {draw_bar(metrics.SmuMetrics.AverageGfxActivity / 100, 50)}"] \ + + [f"UCLK Activity {draw_bar(metrics.SmuMetrics.AverageUclkActivity / 100, 50)}"] + [""] + + # draw_metrics_table(metrics, dev) + temps_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_TEMP_e__enumvalues.items() + if k < smu_v13_0_0.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0] + temps_table = ["=== Temps (C) ==="] + [f"{name:<15}: {color_temp(metrics.SmuMetrics.AvgTemperature[k])}" for k, name in temps_keys] + + voltage_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_SVI_PLANE_e__enumvalues.items() if k < smu_v13_0_0.SVI_PLANE_COUNT] + power_table = ["=== Power ==="] \ + + [f"Fan Speed: {metrics.SmuMetrics.AvgFanRpm} RPM"] \ + + [f"Fan Power: {metrics.SmuMetrics.AvgFanPwm} %"] \ + + [f"Power: {metrics.SmuMetrics.AverageSocketPower}W " + + draw_bar(metrics.SmuMetrics.AverageSocketPower / metrics.SmuMetrics.dGPU_W_MAX, 16)] \ + + ["", "=== Voltages ==="] + [f"{name:<24}: {color_voltage(metrics.SmuMetrics.AvgVoltage[k])}" for k, name in voltage_keys] + + frequency_table = ["=== Frequencies ===", + f"GFXCLK Target : {metrics.SmuMetrics.AverageGfxclkFrequencyTarget} MHz", + f"GFXCLK PreDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPreDs} MHz", + f"GFXCLK PostDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPostDs} MHz", + f"FCLK PreDs : {metrics.SmuMetrics.AverageFclkFrequencyPreDs} MHz", + f"FCLK PostDs : {metrics.SmuMetrics.AverageFclkFrequencyPostDs} MHz", + f"MCLK PreDs : {metrics.SmuMetrics.AverageMemclkFrequencyPreDs} MHz", + f"MCLK PostDs : {metrics.SmuMetrics.AverageMemclkFrequencyPostDs} MHz", + f"VCLK0 : {metrics.SmuMetrics.AverageVclk0Frequency} MHz", + f"DCLK0 : {metrics.SmuMetrics.AverageDclk0Frequency} MHz", + f"VCLK1 : {metrics.SmuMetrics.AverageVclk1Frequency} MHz", + f"DCLK1 : {metrics.SmuMetrics.AverageDclk1Frequency} MHz"] + + dev_content.append(device_line + activity_line + same_line([temps_table, power_table, frequency_table])) + + raw_text = 'AM Monitor'.center(terminal_width) + "\n" + "=" * terminal_width + "\n\n" + for i in range(0, len(dev_content), 2): + if i + 1 < len(dev_content): raw_text += '\n'.join(same_line([dev_content[i], dev_content[i+1]])) + else: raw_text += '\n'.join(dev_content[i]) + if i + 2 < len(dev_content): raw_text += "\n" + "=" * terminal_width + "\n\n" + + sys.stdout.write(f'\033[{self.prev_lines_cnt}A') + sys.stdout.flush() + print(raw_text) + + self.prev_lines_cnt = len(raw_text.splitlines()) + 2 + +if __name__ == "__main__": + try: + os.system('clear') + smi_ctx = SMICtx() + while True: + smi_ctx.rescan_devs() + smi_ctx.draw() + time.sleep(1) + except KeyboardInterrupt: print("Exiting...") diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index b526d7cf5c..5d918156ca 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -16,7 +16,7 @@ class FakePCIDev: class FakeAM: def __init__(self): - self.is_booting = True + self.is_booting, self.smi_dev = True, False self.pcidev = FakePCIDev() self.vram = memoryview(bytearray(4 << 30)) self.gmc = FakeGMC() diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 9c20c5b298..e1a821d624 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -487,11 +487,11 @@ class PCIIface: self.pagemap = HWInterface("/proc/self/pagemap", os.O_RDONLY) self.bar_fds = {bar: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]} - self.adev = AMDev(self.pcidev, self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I')) + self.adev = AMDev(self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I')) self.doorbell_cpu_addr = mv_address(dbell) - libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND) - libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND) + libpciaccess.pci_device_cfg_read_u16(self.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND) + libpciaccess.pci_device_cfg_write_u16(self.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND) # TODO: this is for 7900xtx, the only tested card. self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32, diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 328a1280fe..2f541d8766 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -171,7 +171,7 @@ class AMMemoryManager: self.adev, self.vram_size = adev, vram_size self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device - self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) + self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1) def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping: assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" @@ -231,8 +231,8 @@ class AMMemoryManager: def pfree(self, paddr:int): self.pa_allocator.free(paddr) class AMDev: - def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): - self.pcidev, self.devfmt = pcidev, devfmt + def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): + self.devfmt = devfmt self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions @@ -256,8 +256,8 @@ class AMDev: # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for # all blocks that are initialized only during the initial AM boot. # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag. - self.is_booting = True # During boot only boot memory can be allocated. This flag is to validate this. - self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1) + self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this. + self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1) # Memory manager & firmware self.mm = AMMemoryManager(self, self.vram_size) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index f29cf3e2e1..b27b30579b 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -102,7 +102,13 @@ class AM_GMC(AM_IP): if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}") class AM_SMU(AM_IP): + def __init__(self, adev): + super().__init__(adev) + self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True) + def init(self): + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) for clck in [0x00000C94, 0x000204E1, 0x000105DC, 0x00050B76, 0x00070B76, 0x00040898, 0x00060898, 0x000308FD]: @@ -118,6 +124,11 @@ class AM_SMU(AM_IP): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms + def read_table(self, table_t, cmd): + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) + return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t))) + def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS) + def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg @@ -240,8 +251,9 @@ class AM_IH(AM_IP): def __init__(self, adev): super().__init__(adev) self.ring_size = 512 << 10 - self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), - (self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] + def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True), + self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)) + self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)] def interrupt_handler(self): _, rwptr_vm, suf, _ = self.rings[0] @@ -318,6 +330,9 @@ class AM_PSP(AM_IP): self.ring_size = 0x10000 self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True) + self.max_tmr_size = 0x1300000 + self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True) + def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0 def init(self): sos_components_load_order = [ @@ -362,7 +377,7 @@ class AM_PSP(AM_IP): # Load TOC and calculate TMR size self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC]) self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size - self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) + assert self.tmr_size <= self.max_tmr_size def _ring_create(self): # If the ring is already created, destroy it From e0e176efbc10b82d50d724b0bf2c566c8aa9410d Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 24 Jan 2025 13:56:51 -0500 Subject: [PATCH 22/44] failed test case for multi rand_like [pr] (#8740) new multi broke multi device dropout --- test/test_multitensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 25f863568d..c1e257aea0 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -651,6 +651,12 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + def test_rand_like_from_alu(self): + a = Tensor.ones(4, 4).shard(devices_2, axis=0) + # TODO: fix this, which will also fix multi device dropout + with self.assertRaises(AssertionError): + (a + a).rand_like() + @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1) From 0c759e1ff6e8621d86268be5eaf2bc91d6f19053 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 24 Jan 2025 14:45:11 -0500 Subject: [PATCH 23/44] add bert to bechmark ci (#8741) with `DISABLE_DROPOUT=1 BERT_LAYERS=2` for now --- .github/workflows/benchmark.yml | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 265971b24a..cf5b117aad 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -299,6 +299,10 @@ jobs: run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt + - name: Run 10 MLPerf Bert training steps (6 gpu) + # TODO: remove DISABLE_DROPOUT once dropout is fixed + # TODO: remove BERT_LAYERS once scheduler is fast + run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (NVIDIA Training) @@ -309,9 +313,10 @@ jobs: train_cifar_bf16.txt train_cifar_wino.txt train_cifar_one_gpu.txt + train_cifar_six_gpu.txt train_resnet.txt train_resnet_one_gpu.txt - train_cifar_six_gpu.txt + train_bert.txt - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py @@ -492,6 +497,10 @@ jobs: run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt + - name: Run 10 MLPerf Bert training steps (6 gpu) + # TODO: remove DISABLE_DROPOUT once dropout is fixed + # TODO: remove BERT_LAYERS once scheduler is fast + run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD Training) @@ -502,9 +511,10 @@ jobs: train_cifar_bf16.txt train_cifar_wino.txt train_cifar_one_gpu.txt + train_cifar_six_gpu.txt train_resnet.txt train_resnet_one_gpu.txt - train_cifar_six_gpu.txt + train_bert.txt - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py From 2f06eccf1da766306626e9d5ee36042e26c135e2 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 25 Jan 2025 00:33:00 +0300 Subject: [PATCH 24/44] am: script and vfio msg (#8742) * am: script and vfio msg * use sysfs bars always for now * tiny chnages --- extra/amdpci/am_smi.py | 4 ++-- extra/amdpci/setup_python_cap.sh | 3 +++ extra/amdpci/setup_vfio.sh | 2 ++ tinygrad/runtime/ops_amd.py | 7 ++----- 4 files changed, 9 insertions(+), 7 deletions(-) create mode 100755 extra/amdpci/setup_python_cap.sh create mode 100755 extra/amdpci/setup_vfio.sh diff --git a/extra/amdpci/am_smi.py b/extra/amdpci/am_smi.py index 9c82db3ca9..f5ffbc6da3 100644 --- a/extra/amdpci/am_smi.py +++ b/extra/amdpci/am_smi.py @@ -12,14 +12,14 @@ def bold(s): return f"\033[1m{s}\033[0m" def color_temp(temp): if temp >= 87: return colored(f"{temp:>4}", "red") elif temp >= 80: return colored(f"{temp:>4}", "yellow") - return colored(f"{temp:>4}", "white") + return f"{temp:>4}" def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan") def draw_bar(percentage, width=40, fill='â–ˆ', empty='â–‘'): filled_width = int(width * percentage) bar = fill * filled_width + empty * (width - filled_width) - return f'[{bar}] {percentage*100:.1f}%' + return f'[{bar}] {percentage*100:5.1f}%' def same_line(strs:list[list[str]], split=8) -> list[str]: ret = [] diff --git a/extra/amdpci/setup_python_cap.sh b/extra/amdpci/setup_python_cap.sh new file mode 100755 index 0000000000..2ef3c5f6b7 --- /dev/null +++ b/extra/amdpci/setup_python_cap.sh @@ -0,0 +1,3 @@ +#!/bin/bash +PYTHON_PATH=$(readlink -f $(which python3)) +sudo setcap 'cap_dac_override,cap_sys_rawio,cap_sys_admin=ep' $PYTHON_PATH diff --git a/extra/amdpci/setup_vfio.sh b/extra/amdpci/setup_vfio.sh new file mode 100755 index 0000000000..6ca786edaa --- /dev/null +++ b/extra/amdpci/setup_vfio.sh @@ -0,0 +1,2 @@ +#!/bin/bash +sudo modprobe vfio-pci disable_idle_d3=1 diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index e1a821d624..360e408cf3 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -464,7 +464,7 @@ class PCIIface: iommu_group = HWInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1] except OSError: - if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (not inserted or no-iommu mode is not supported).") + if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (run `sudo modprobe vfio-pci`).") PCIIface.vfio = False # Init vfio for the device @@ -498,10 +498,7 @@ class PCIIface: 'array_count': 12, 'simd_arrays_per_engine': 2, 'lds_size_in_kb': 64} def _map_pci_range(self, bar, off=0, addr=0, size=None): - if PCIIface.vfio: - vfio.VFIO_DEVICE_GET_REGION_INFO(self.vfio_dev, reg:=vfio.struct_vfio_region_info(argsz=ctypes.sizeof(vfio.struct_vfio_region_info), index=bar)) - fd, sz, off = self.vfio_dev, size or reg.size, reg.offset + off - else: fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size + fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size return to_mv(fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz) def alloc(self, size:int, host=False, uncached=False, cpu_access=False): From cb0978b3778d8f809b9c8ebd37d0d2d3398297e0 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 25 Jan 2025 07:28:43 +0900 Subject: [PATCH 25/44] add Ops.CONTIGUOUS_BACKWARD (#8743) --- tinygrad/engine/schedule.py | 6 +++--- tinygrad/function.py | 2 +- tinygrad/gradient.py | 1 + tinygrad/ops.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 58ca954abf..801ffdcfd9 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -37,7 +37,7 @@ tensor_uop_spec = PatternMatcher([ # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes - (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype), + (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype), # COPY # NOTE: the arg here specifies clone=True, which prevents folding same device copy @@ -366,8 +366,8 @@ sym = symbolic_simple+PatternMatcher([ # UOp with size 0 is zero (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # DETACH is a NOOP here - (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]), + # DETACH and CONTIGUOUS_BACKWARD are NOOPs here + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), # reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), diff --git a/tinygrad/function.py b/tinygrad/function.py index 5527870711..73a963dd48 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -10,7 +10,7 @@ class Contiguous(Function): def backward(self, grad_output:UOp) -> UOp: return grad_output class ContiguousBackward(Function): - def forward(self, x:UOp) -> UOp: return x + def forward(self, x:UOp) -> UOp: return x.contiguous_backward() def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous() class Cast(Function): diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index f64f443858..23df056299 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -28,6 +28,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))), (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient), (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)), + (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)), (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)), (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)), (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c028c82745..b9291dcc27 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait): # the order of these Ops controls the order of the toposort class Ops(FastEnum): # uops that aren't rendered - SINK = auto(); CONTIGUOUS = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 + SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 # TODO: empty continues to exist because of tensor EMPTY = auto() @@ -416,6 +416,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def contiguous(self): return self.alu(Ops.CONTIGUOUS) + def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD) # *** from MultiLazyBuffer *** From e2b380b743a3e938a8505a4df652765c9dae74ce Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 24 Jan 2025 20:47:27 -0500 Subject: [PATCH 26/44] make UOp.multi real a tuple instead of list [pr] (#8744) tuple is immutable. also updated test_rand_like_from_alu test --- test/test_multitensor.py | 6 +++++- tinygrad/multi.py | 4 ++-- tinygrad/ops.py | 4 ++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index c1e257aea0..5e30eadcf7 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -652,11 +652,15 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.lazydata.axis, t2.lazydata.axis) def test_rand_like_from_alu(self): - a = Tensor.ones(4, 4).shard(devices_2, axis=0) # TODO: fix this, which will also fix multi device dropout + a = Tensor.ones(4, 4).shard(devices_2, axis=0) with self.assertRaises(AssertionError): (a + a).rand_like() + b = Tensor.empty(4, 4).shard(devices_2, axis=None) + with self.assertRaises(AssertionError): + (a + b).rand_like() + @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 8a7b9d04a0..6cf587c01f 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -52,7 +52,7 @@ def alu_multi(root:UOp): axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) srcs:list[list[UOp]] = [] not_all_real = not all(all(mlb.real) for mlb in msrcs) - new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else msrcs[0].real + new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real assert any(new_real), "output contains no real lb" for mlb in msrcs: if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src)) @@ -124,7 +124,7 @@ def shrink_multi(root:UOp, multi:UOp): idx = multi.bounds.index(root.arg[multi.axis]) # zero out other lbs to not create lb reference return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)], - axis=multi.axis, real=[i==idx for i in range(len(multi.src))]) + axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src)))) return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], axis=multi.axis, real=multi.real) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index b9291dcc27..bb0832b1ab 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -420,10 +420,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # *** from MultiLazyBuffer *** - def multi(self, *more:UOp, axis:int|None, real:list[bool]|None=None): + def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None): parents = (self,)+more assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype" - return UOp(Ops.MULTI, self.dtype, parents, (axis, tuple(real if real is not None else [True]*len(parents)))) + return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents))) @property def bounds(self): From a037201168eb3ce2f3b7dae6b388ae2c456b0bfc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 25 Jan 2025 07:33:31 -0500 Subject: [PATCH 27/44] test_viz cleanups + move to /unit directory (#8746) * test_viz cleanups + move to /unit directory * lint --- test/{ => unit}/test_viz.py | 125 ++++++++++++++---------------------- 1 file changed, 47 insertions(+), 78 deletions(-) rename test/{ => unit}/test_viz.py (61%) diff --git a/test/test_viz.py b/test/unit/test_viz.py similarity index 61% rename from test/test_viz.py rename to test/unit/test_viz.py index 2723a411c4..f65f347ebc 100644 --- a/test/test_viz.py +++ b/test/unit/test_viz.py @@ -1,72 +1,52 @@ -from typing import Dict, List, Optional import unittest, decimal, json from tinygrad.dtype import dtypes -from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic -from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys -from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic +from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry -from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto +from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto -@track_rewrites(named=True) -def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs) - -def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]: - rewrite(sink, pm, **kwargs) - assert len(contexts) == 1 - assert len(contexts[0]) == 1 - k = get_metadata(keys, contexts)[0][0] - g = get_details(*k) - return g.uops[1:] +# NOTE: VIZ tests always use the tracked PatternMatcher instance +symbolic = TrackedPatternMatcher(symbolic.patterns) class TestViz(unittest.TestCase): def setUp(self): + # clear the global context contexts.clear() keys.clear() + _name_cnt.clear() self.tms = TRACK_MATCH_STATS.value TRACK_MATCH_STATS.value = 2 def tearDown(self): TRACK_MATCH_STATS.value = self.tms def test_viz_simple(self): - pm = PatternMatcher([ - (UPat.var("x")*1, lambda x:x), - ]) - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - uops = helper_test_viz(a*1, pm) - self.assertEqual(len(uops), 1) - self.assertEqual(uops[0], a) + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(sink): return graph_rewrite(sink, symbolic) + test(a*1) + ret = get_metadata(keys, contexts) + self.assertEqual(len(ret), 1) + self.assertEqual(ret[0][0][0], "test_1") + self.assertEqual(len(ret[0][0][2].upats), 1) - def test_rewrite_twice(self): - pm = PatternMatcher([ - (UPat.var("x")+UPat.var("x"), lambda x:x*2), - (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))), - ]) - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - uops = helper_test_viz(a+a, pm) - self.assertEqual(len(uops), 2) - self.assertEqual(uops[0], a*2) - self.assertEqual(uops[1], graph_rewrite(a+a, pm)) - - def test_rewrite_with_ctx(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), ShapeTracker.from_shape((1, 1)).to_uop())) - b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), ShapeTracker.from_shape((1, 1)).to_uop())) - def store_load(ctx:Dict[UOp, None], glbl, st) -> Optional[UOp]: - if glbl in ctx: return None - ctx[glbl] = None - return UOp.store(glbl, ShapeTracker.from_shape(st.shape).to_uop()) - pm = PatternMatcher([ - (UPat.load(UPat(Ops.DEFINE_GLOBAL, name="glbl"), UPat.var("st")), store_load), - ]) - uops = helper_test_viz(a+b, pm, ctx={}) - self.assertEqual(len(uops), 2) - self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {})) + def test_track_two_rewrites(self): + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(sink): return graph_rewrite(sink, symbolic) + test((a+a)*1) + ret = get_metadata(keys, contexts) + self.assertEqual(len(ret), 1) # one context + self.assertEqual(len(ret[0]), 1) # one graph_rewrite call in context + key, _, val = ret[0][0] + self.assertEqual(key, "test_1") + self.assertEqual(len(val.upats), 2) # two upats applied def test_track_rewrites(self): - simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites(named=True) - def do_rewrite(x:UOp): return graph_rewrite(x, simple) - ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) - do_rewrite(ld*1) - do_rewrite(ld*2) + def do_rewrite(x:UOp): return graph_rewrite(x, symbolic) + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 4) + do_rewrite(a*1) + do_rewrite(a*b) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 2) key, _, m = ret[0][0] @@ -77,42 +57,32 @@ class TestViz(unittest.TestCase): self.assertEqual(len(m.upats), 0) def test_track_rewrites_with_exception(self): - simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites() def do_rewrite(x:UOp): - x = graph_rewrite(x, simple) # NOTE: viz tracks this + x = graph_rewrite(x, symbolic) # NOTE: viz tracks this raise Exception("test") - ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) - with self.assertRaises(Exception): do_rewrite(ld*1) + a = UOp.variable("a", 0, 10) + with self.assertRaises(Exception): do_rewrite(a*1) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - def test_fold_const(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - graph = uop_to_json(a) - assert not any(v[0].startswith("CONST") for v in graph.values()) - assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 + # NOTE: CONST UOps do not get nodes in the graph + def test_dont_create_const_nodes(self): + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 4) + self.assertEqual(len(uop_to_json(a*1)), 2) + self.assertEqual(len(uop_to_json(a*b)), 3) @unittest.skip("TODO: bring this back with better testing") def test_bottom_up_rewrite(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - n1 = a.sin() - uop = n1.sin() - pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) - ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True) - self.assertEqual(len(ret), 2) - self.assertIs(ret[0], a.sin().sqrt()) # first rewrite - self.assertIs(ret[1], a.sqrt().sqrt()) # second one - - def test_top_down_rewrite(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - n1 = a.sin() - uop = n1.sin() - pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) - # if it wasn't bottom_up, it's rewritten once - ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False) + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 10) + c = UOp.variable("c", 0, 10) + UOp.substitute(a+b, {a+b:c}) + ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - self.assertIs(ret[0], a.sqrt().sin()) # only rewrite + _, _, vals = ret[0][0] + self.assertEqual(len(vals.upats), 1) # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ def test_rewrite_without_context(self): @@ -211,6 +181,5 @@ class TextVizProfiler(unittest.TestCase): self.assertEqual(j['traceEvents'][7]['dur'], 4) self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid']) - if __name__ == "__main__": unittest.main() From c74c5901a81d8c69a09058b6e7c2d2e38f3ac273 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Sat, 25 Jan 2025 19:06:35 +0300 Subject: [PATCH 28/44] am disable bind (#8747) --- tinygrad/runtime/ops_amd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 360e408cf3..cb08f4aa2d 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -198,12 +198,12 @@ class AMDCopyQueue(HWQueue): return self def bind(self, dev:AMDDevice): - if not dev.driverless: return + if not getenv("AMD_SDMA_BIND", 0) or not dev.driverless: return self.binded_device = dev self.hw_page = dev.allocator.alloc((qsz:=round_up(len(self._q), 8)) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True)) hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I") - for i, value in enumerate(self._q): hw_view[i] = value + for i in range(qsz): hw_view[i] = self._q[i] if i < len(self._q) else 0 self.indirect_cmd = [amd_gpu.SDMA_OP_INDIRECT | amd_gpu.SDMA_PKT_INDIRECT_HEADER_VMID(0), *data64_le(self.hw_page.va_addr), qsz, *data64_le(0)] self._q, self.cmd_sizes = hw_view, [len(self.indirect_cmd)] From 0e42befc6e1a0e568645a02f4481af3a0ca54f2a Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 25 Jan 2025 12:41:57 -0500 Subject: [PATCH 29/44] viz cleanups 2 [pr] (#8748) * viz cleanups 2 [pr] * test_viz updates --- test/unit/test_viz.py | 36 +++++++++++----- tinygrad/ops.py | 9 ++-- tinygrad/viz/index.html | 12 +++--- tinygrad/viz/serve.py | 92 ++++++++++++++--------------------------- 4 files changed, 68 insertions(+), 81 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index f65f347ebc..ff84813d1e 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -25,8 +25,9 @@ class TestViz(unittest.TestCase): test(a*1) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - self.assertEqual(ret[0][0][0], "test_1") - self.assertEqual(len(ret[0][0][2].upats), 1) + key, val = ret[0] + self.assertEqual(key, "test_1") + self.assertEqual(val[0]["match_count"], 1) def test_track_two_rewrites(self): a = UOp.variable("a", 0, 10) @@ -34,11 +35,26 @@ class TestViz(unittest.TestCase): def test(sink): return graph_rewrite(sink, symbolic) test((a+a)*1) ret = get_metadata(keys, contexts) - self.assertEqual(len(ret), 1) # one context - self.assertEqual(len(ret[0]), 1) # one graph_rewrite call in context - key, _, val = ret[0][0] + key, val = ret[0] + self.assertEqual(len(ret), 1) # one context + self.assertEqual(len(val), 1) # one graph_rewrite call in context self.assertEqual(key, "test_1") - self.assertEqual(len(val.upats), 2) # two upats applied + self.assertEqual(val[0]["match_count"], 2) # two upats applied + + def test_track_multiple_calls_one_ctx(self): + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(a, b): + a = graph_rewrite(a, symbolic) + b = graph_rewrite(b, symbolic) + test(a*1, a*5) + ret = get_metadata(keys, contexts) + key, val = ret[0] + self.assertEqual(len(ret), 1) # one context + self.assertEqual(len(val), 2) # two graph_rewrite calls in context + self.assertEqual(key, "test_1") + self.assertEqual(val[0]["match_count"], 1) # one rewrite for a*0 + self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5 def test_track_rewrites(self): @track_rewrites(named=True) @@ -49,12 +65,12 @@ class TestViz(unittest.TestCase): do_rewrite(a*b) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 2) - key, _, m = ret[0][0] + key, m = ret[0] self.assertEqual(key, "do_rewrite_1") - self.assertEqual(len(m.upats), 1) - key, _, m = ret[1][0] + self.assertEqual(m[0]["match_count"], 1) + key, m = ret[1] self.assertEqual(key, "do_rewrite_2") - self.assertEqual(len(m.upats), 0) + self.assertEqual(m[0]["match_count"], 0) def test_track_rewrites_with_exception(self): @track_rewrites() diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bb0832b1ab..4549ce903e 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -823,9 +823,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) match_stats:dict[UPat, list[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedGraphRewrite: - loc: tuple[str, int] # location that called graph_rewrite - sink: UOp # the sink input to graph_rewrite - matches: list[tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # before+after of all the matches + loc: tuple[str, int] # location that called graph_rewrite + sink: UOp # the sink input to graph_rewrite + matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches tracked_keys:list[Any] = [] tracked_ctxs:list[list[TrackedGraphRewrite]] = [] _name_cnt:dict[str, int] = {} @@ -856,10 +856,9 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p, et)) + if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st - if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0 and len(tracked_ctxs[-1]) != 0: tracked_ctxs[-1][-1].matches.append((uop, None, None, 0)) return None if TRACK_MATCH_STATS: diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 0f94ebe655..08b7f4d00a 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -301,17 +301,17 @@ const kernelListParent = document.querySelector(".container.kernel-list-parent"); const kernelList = document.querySelector(".container.kernel-list"); kernelList.innerHTML = ""; - kernels.forEach((k, i) => { + kernels.forEach(([key, items], i) => { const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "", style: "overflow-x: auto; cursor: initial;" }); if (i === currentKernel) { requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" })); } - const p = Object.assign(document.createElement("p"), { id: `kernel-${k[0].kernel_name}`, innerText: k[0].kernel_name ?? "UNPARENTED", + const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerText: key ?? "UNPARENTED", style: "cursor: pointer;"}); kernelUl.appendChild(p) - k.forEach((u, j) => { - const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.upats.length}`, key: `uop-rewrite-${j}`, + items.forEach((u, j) => { + const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" }) if (j === currentUOp) { requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" })); @@ -460,7 +460,7 @@ event.preventDefault() currentUOp = 0; currentRewrite = 0; - currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1) + currentKernel = Math.min(kernels.length-1, currentKernel+1); return main() } } @@ -486,7 +486,7 @@ if (event.key == "ArrowDown") { event.preventDefault() currentRewrite = 0; - const totalUOps = kernels[currentKernel].length-1; + const totalUOps = kernels[currentKernel][1].length-1; currentUOp = Math.min(totalUOps, currentUOp+1) main() } diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index fdceca9c0b..02dbee0a5e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -2,8 +2,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse -from dataclasses import asdict, dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable, TypedDict from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp from tinygrad.codegen.kernel import Kernel @@ -17,50 +16,26 @@ uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB" **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"} -# ** API spec +# VIZ API -@dataclass -class GraphRewriteMetadata: - """Overview of a tracked rewrite to viz the sidebar""" - loc: tuple[str, int] - """File_path, Lineno""" - code_line: str - """The Python line calling graph_rewrite""" - kernel_name: str - """The kernel calling graph_rewrite""" - upats: list[tuple[tuple[str, int], str, float]] - """List of all the applied UPats""" +class GraphRewriteMetadata(TypedDict): + loc: tuple[str, int] # [path, lineno] calling graph_rewrite + match_count: int # total match count in this context -@dataclass class GraphRewriteDetails(GraphRewriteMetadata): - """Full details about a single call to graph_rewrite""" - uops: list[UOp] - graphs: list[dict] - """Sink at every step of graph_rewrite + the json serialized version""" - diffs: list[list[str]] - """.diff style before and after of the rewritten UOp child""" - changed_nodes: list[list[int]] - """Nodes that changed at every step of graph_rewrite""" - kernel_code: Optional[str] - """The program after all rewrites""" - -# ** API functions + graphs: list[dict] # JSON serialized UOp at every rewrite step + uops: list[str] # strigified UOp at every rewrite step + diffs: list[list[str]] # string diff of the single UOp that changed + changed_nodes: list[list[int]] # the changed UOp id + all its parents ids + code_line: str # source code calling graph_rewrite + kernel_code: str|None # optionally render the final kernel code + upats: list[tuple[tuple[str, int], str]] # NOTE: if any extra rendering in VIZ fails, we don't crash def pcall(fxn:Callable[..., str], *args, **kwargs) -> str: try: return fxn(*args, **kwargs) except Exception as e: return f"ERROR: {e}" -def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]: - kernels: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {} - for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"): - name = to_function_name(k.name) if isinstance(k, Kernel) else str(k) - for ctx in ctxs: - if ctx.sink.op is Ops.CONST: continue - upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None] - kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) - return list(kernels.values()) - def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]: assert isinstance(x, UOp) graph: dict[int, tuple[str, str, list[int], str, str]] = {} @@ -80,27 +55,27 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]: else: label += f"\n{x.op.name}{idx} {x.arg}" graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x not in excluded], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph + +def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]: + return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), + [{"loc": v.loc, "match_count": len(v.matches)} for v in vals]) for k,vals in zip(keys, contexts)] + @functools.lru_cache(None) def _prg(k:Kernel): return k.to_program().src def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), uops=[ctx.sink], diffs=[], changed_nodes=[], - kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[]) + ret:GraphRewriteDetails = {"uops":[pcall(str, sink:=ctx.sink)], "graphs":[uop_to_json(sink)], "code_line":lines(ctx.loc[0])[ctx.loc[1]-1].strip(), + "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "diffs":[], "upats":[], "changed_nodes":[], **metadata} replaces: dict[UOp, UOp] = {} - g.graphs.append(uop_to_json(sink:=g.uops[0])) - for i,(u0,u1,upat,_) in enumerate(tqdm(ctx.matches)): - # if the match didn't result in a rewrite we move forward - if u1 is None: continue + for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches)): replaces[u0] = u1 - # first, rewrite this UOp with the current rewrite + all the matches in replaces new_sink = sink.substitute(replaces) - # sanity check - if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") - # update ret data - g.graphs.append(new_sink_js:=uop_to_json(new_sink)) - g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js]) - g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) - g.uops.append(sink:=new_sink) - return g + ret["graphs"].append(new_sink_js:=uop_to_json(new_sink)) + ret["changed_nodes"].append([id(x) for x in u1.toposort if id(x) in new_sink_js]) + ret["diffs"].append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) + ret["upats"].append((upat.location, upat.printable())) + # TODO: this is O(n^2)! + ret["uops"].append(str(sink:=new_sink)) + return ret # Profiler API devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {} @@ -150,13 +125,9 @@ class Handler(BaseHTTPRequestHandler): elif url.path == "/kernels": query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: - g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) - # TODO: this is O(n^2)! - uops_strs = [pcall(str,x) for x in tqdm(g.uops)] - # NOTE: don't use asdict because it's reserializing the uops - jret: Any = {"loc": g.loc, "code_line": g.code_line, "kernel_name": g.kernel_name, "upats": g.upats, - "uops": uops_strs, "graphs": g.graphs, "diffs": g.diffs, "changed_nodes": g.changed_nodes, "kernel_code": g.kernel_code} - else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels] + kidx, ridx = int(qkernel[0]), int(query["idx"][0]) + jret:Any = get_details(contexts[0][kidx], contexts[1][kidx][ridx], kernels[int(qkernel[0])][1][int(query["idx"][0])]) + else: jret = kernels ret, content_type = json.dumps(jret).encode(), "application/json" elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json" else: status_code = 404 @@ -198,10 +169,11 @@ if __name__ == "__main__": contexts, profile = load_pickle(args.kernels), load_pickle(args.profile) + # NOTE: this context is a tuple of list[keys] and list[values] kernels = get_metadata(*contexts) if contexts is not None else [] if getenv("FUZZ_VIZ"): - ret = [get_details(*args) for v in tqdm(kernels) for args in v] + ret = [get_details(contexts[0][i], contexts[1][i][j], args) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])] print(f"fuzzed {len(ret)} rewrite details") perfetto_profile = to_perfetto(profile) if profile is not None else None From 0ffd572e1e562f5580945d46fb0b127f7af70916 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 26 Jan 2025 08:41:00 +0900 Subject: [PATCH 30/44] fix multi with no real srcs (#8749) --- test/test_multitensor.py | 9 +++------ tinygrad/multi.py | 7 ++----- tinygrad/ops.py | 4 ++-- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 5e30eadcf7..a423d79731 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -344,9 +344,7 @@ class TestMultiTensor(unittest.TestCase): # NOTE: this is failing on LLVM CI, no idea why. Works locally. @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow") def test_data_parallel_resnet(self): - import sys, pathlib - sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) - from resnet import ResNet18 + from extra.models.resnet import ResNet18 fake_image = Tensor.rand((2, 3, 224//8, 224//8)) fake_image_sharded = fake_image.shard(devices_2, axis=0) @@ -363,9 +361,7 @@ class TestMultiTensor(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM") def test_data_parallel_resnet_train_step(self): - import sys, pathlib - sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) - from resnet import ResNet18 + from extra.models.resnet import ResNet18 from tinygrad.nn.optim import LARS fake_image = Tensor.rand((2, 3, 224//8, 224//8)) @@ -899,6 +895,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): np.testing.assert_equal((a+a).numpy(), na+na) np.testing.assert_equal((b+b).numpy(), nb+nb) + @unittest.skip("why didn't this work?") def test_add_two_partitions(self): t = Tensor.arange(64).reshape(8, 8).contiguous().realize() t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0) diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 6cf587c01f..c6741b0c98 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -53,17 +53,14 @@ def alu_multi(root:UOp): srcs:list[list[UOp]] = [] not_all_real = not all(all(mlb.real) for mlb in msrcs) new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real - assert any(new_real), "output contains no real lb" for mlb in msrcs: if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src)) else: assert axis is not None and bounds is not None if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds)) else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds)) - new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(root.op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} - # NOTE: const dtype should match real - new_dtype = next(iter(new_real_lbs.values())).dtype - new_lbs = [new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))] + new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)] + new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed? return UOp.multi(*new_lbs, axis=axis, real=new_real) def reduce_multi(root:UOp, multi:UOp): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 4549ce903e..950589a80a 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -896,7 +896,7 @@ class RewriteContext: self.replace: dict[UOp, UOp] = {} def top_down_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn - new_src = tuple(map(self.top_down_rewrite, n.src)) + new_src = tuple([self.top_down_rewrite(x) for x in n.src]) new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg) self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n) return ret @@ -904,7 +904,7 @@ class RewriteContext: if (rn := self.replace.get(n)) is not None: return rn new_n: UOp|None = n while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx) - new_src = tuple(map(self.bottom_up_rewrite, last_n.src)) + new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src]) self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg)) return ret From b4bf6a7dea52dc33cdf8cf278d6d801a78d5b7d8 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 26 Jan 2025 09:12:16 +0900 Subject: [PATCH 31/44] switch backward to use gradient [pr] (#8235) * switch backward to use gradient [pr] * set device correctly, dedup * why does that fail? * add noop cast * simple backward * fix beautiful_mnist * touchups * set in compute_gradient * uop_count * uop_count was wrong * collections * no note * skip that test * update sched kernel counts * train mnist is 65 * fix metadata and gc * fixes * materialize_grads * no pathlib stuff * add contiguous_backward, fix bugs * add some realize * fix multi --- test/models/test_real_world.py | 2 +- test/test_arange.py | 5 +-- test/test_gc.py | 2 ++ test/test_linearizer.py | 1 - test/test_linearizer_failures.py | 2 +- test/test_ops.py | 6 ++++ test/test_schedule.py | 20 +++++++++--- test/test_tensor.py | 4 +-- test/unit/test_gradient.py | 4 +-- tinygrad/gradient.py | 5 +-- tinygrad/tensor.py | 56 +++++++++----------------------- 11 files changed, 50 insertions(+), 57 deletions(-) diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index a213e5ec88..7aa3253a07 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase): loss.backward() optimizer.step() - helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63) + helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65) @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow") def test_train_cifar(self): diff --git a/test/test_arange.py b/test/test_arange.py index 07512ae1b6..2229ac847f 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase): # llama3 is 128256 vocab_size, embed_size = (10, 3) if CI else (32000, 4096) emb = nn.Embedding(vocab_size, embed_size) - emb_w = emb.weight.numpy() + # TODO: why is a new realize needed here + emb_w = emb.weight.realize().numpy() x = Tensor([1,2,3,4]) with Context(NOOPT=noopt, FUSE_ARANGE=1): GlobalCounters.reset() z = emb(x).realize() self.assertLessEqual(GlobalCounters.global_ops, op_limit) - self.assertEqual(GlobalCounters.kernel_count, 3) + self.assertEqual(GlobalCounters.kernel_count, 2) if getenv("CHECK", 1): import torch with torch.no_grad(): diff --git a/test/test_gc.py b/test/test_gc.py index 010f59039f..0929a81394 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -8,9 +8,11 @@ from tinygrad.ops import UOp from tinygrad.tensor import Tensor def tensors_allocated(): + gc.collect() return sum([isinstance(x, Tensor) for x in gc.get_objects()]) def bufs_allocated(): + gc.collect() return sum([isinstance(x, Buffer) for x in gc.get_objects()]) class TestGC(unittest.TestCase): diff --git a/test/test_linearizer.py b/test/test_linearizer.py index da2995d37d..c0bdc2c6f8 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase): # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 - @unittest.expectedFailure # requires contiguous folding def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 80ca7fd6e8..8570dac761 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] - helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"]) + helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02) # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") diff --git a/test/test_ops.py b/test/test_ops.py index d342a71df4..020f168b8d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -352,12 +352,15 @@ class TestOps(unittest.TestCase): def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y) def test_cmp_ne_backwards(self): + # new grad zeroes these out + """ t1 = torch.ones(4, requires_grad=True) t2 = torch.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (t1 != t2).sum().backward) tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward) + """ tt = Tensor.randn(4, requires_grad=True) (tt*(tt != 0)).sum().backward() t = torch.tensor(tt.numpy(), requires_grad=True) @@ -365,12 +368,15 @@ class TestOps(unittest.TestCase): np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5) def test_cmp_lt_backwards(self): + # new grad zeroes these out + """ t1 = torch.ones(4, requires_grad=True) t2 = torch.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (t1 < t2).sum().backward) tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + """ tt = Tensor.randn(4, requires_grad=True) (tt*(tt < 0)).sum().backward() t = torch.tensor(tt.numpy(), requires_grad=True) diff --git a/test/test_schedule.py b/test/test_schedule.py index d852d2345e..dee2a7ae78 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase): check_schedule(out, 2) # multireduce spec + @unittest.skip("these two Tensors are the same") def test_example_matmul(self): x = Tensor.eye(64, requires_grad=True) y = Tensor.eye(64, requires_grad=True) @@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.ones((64,64))) + def test_example_matmul_contig(self): + x = Tensor.eye(64, requires_grad=True).contiguous().realize() + y = Tensor.eye(64, requires_grad=True).contiguous().realize() + z = y.matmul(x).sum() + z.backward() + out = x.grad.contiguous() + run_schedule(check_schedule(out, 2)) + np.testing.assert_allclose(out.numpy(), np.ones((64,64))) + def test_example_matmul_same(self): x = Tensor.eye(64, requires_grad=True) z = x.matmul(x).sum() @@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 13) + check_schedule(opt.schedule_step(), 14) def test_sgd_conv_fuse(self): with Tensor.train(): @@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2])) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 8) + check_schedule(opt.schedule_step(), 9) def test_fold_2convs_sgd_nesterov_momentum_wd(self): with Tensor.train(): @@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 10) + check_schedule(opt.schedule_step(), 11) def test_sgd_4convs_fuse(self): with Tensor.train(): @@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 18) + check_schedule(opt.schedule_step(), 21) def test_sgd_4convs_fuse_conv_bw(self): with Tensor.train(): @@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15) + with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_prefer_half_buffer(self): diff --git a/test/test_tensor.py b/test/test_tensor.py index d73e55e36a..b1b90ea44b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -770,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) def test_complex_backward(self): - x = Tensor.rand(3, requires_grad=True) - y = Tensor.rand(3, requires_grad=True) + x = Tensor.rand(3, requires_grad=True).realize() + y = Tensor.rand(3, requires_grad=True).realize() out = (x.relu() * y.sigmoid()).sum() self.assertEqual(out.lazydata.metadata.name, "sum") out.backward() diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index a9a41eace0..f7a99982fa 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase): def _test_one_input_function(self, f:Callable, jf:Callable|None=None): x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float) - gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x] + gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x] gf = jax.grad(f if jf is None else jf) for val in [-5., -2.0, 0.0, 2.0, 5.]: @@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase): def _test_two_input_function(self, f:Callable, jf:Callable|None=None): x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float) y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float) - grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y]) + grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y])) gx, gy = grads[x], grads[y] gf = jax.grad(f if jf is None else jf, argnums=(0, 1)) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 23df056299..86c9f0fa63 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -1,7 +1,7 @@ from typing import cast, Iterator -import math, functools +import math, functools, dataclasses from tinygrad.dtype import dtypes, sum_acc_dtype -from tinygrad.ops import UOp, PatternMatcher, UPat, Ops +from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata from tinygrad.helpers import argsort def reduce_gradient(ctx:UOp, ret:UOp): @@ -66,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp if v is None: continue if k in grads: grads[k] = grads[k] + v else: grads[k] = v + if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True) return grads diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 2c47162c26..826ade9242 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,11 +1,11 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref +import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref from contextlib import ContextDecorator from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup -from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap +from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap from tinygrad.multi import get_multi_map from tinygrad.gradient import compute_gradient from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element @@ -293,7 +293,6 @@ class Tensor(SimpleMathTrait): self.contiguous().realize().lazydata.base.realized.copyin(x._data()) return self if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) - if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}") if self.lazydata is x.lazydata: return self # a self assign is a NOOP # NOTE: we allow cross device assign assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" @@ -901,7 +900,7 @@ class Tensor(SimpleMathTrait): # ***** toposort and backward pass ***** - def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]: + def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]: """ Compute the gradient of the targets with respect to self. @@ -922,23 +921,14 @@ class Tensor(SimpleMathTrait): grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) ret = [] for x in target_uops: - if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") + if (y:=grads.get(x)) is None: + if materialize_grads: y = x.const_like(0) + else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") ret.append(y) rets.append(ret) # create returned Tensors return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] - def _deepwalk(self) -> list[Tensor]: - def _walk(node:Tensor, visited:set[Tensor]): - visited.add(node) - # if tensor is not leaf, reset grad - if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None - if ctx: - for i in cast(Function, node._ctx).parents: - if i not in visited: yield from _walk(i, visited) - yield node - return list(_walk(self, set())) - def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor: """ Propagates the gradient of a tensor backwards through the computation graph. @@ -950,30 +940,14 @@ class Tensor(SimpleMathTrait): print(t.grad.numpy()) ``` """ - toposorted = self._deepwalk() - if gradient is None: - assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" - # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous - # this is "implicit gradient creation" - gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) - - toposort_uop = self.lazydata.toposort - assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}" - self.grad = gradient - for t0 in reversed(toposorted): - if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad") - ctx = cast(Function, t0._ctx) - token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None) - grads = ctx.backward(t0.grad.lazydata) - _METADATA.reset(token) - grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None - for g in ([grads] if len(ctx.parents) == 1 else grads)] - for t, g in zip(ctx.parents, grads): - if g is not None and t.requires_grad: - assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" - assert t.lazydata in toposort_uop, f"grad uop must have a path from self\ngrad uop: {t.lazydata}" - t.grad = g if t.grad is None else (t.grad + g) - if not retain_graph: del t0._ctx + all_uops = self.lazydata.toposort + tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ + t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] + # clear contexts + for t in tensors_need_grad: t._ctx = None + for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): + assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" + t.grad = g if t.grad is None else (t.grad + g) return self # ***** movement low level ops ***** @@ -3993,5 +3967,5 @@ def _metadata_wrapper(fn): if TRACEMETA >= 1: for name, fn in inspect.getmembers(Tensor, inspect.isfunction): - if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue + if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn))) From 1b4618e2575f455263561760263468dee431409d Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 26 Jan 2025 09:30:55 +0900 Subject: [PATCH 32/44] gradient cleanup (#8750) * switch backward to use gradient [pr] * set device correctly, dedup * why does that fail? * add noop cast * simple backward * fix beautiful_mnist * touchups * set in compute_gradient * uop_count * uop_count was wrong * collections * no note * skip that test * update sched kernel counts * train mnist is 65 * fix metadata and gc * fixes * materialize_grads * no pathlib stuff * add contiguous_backward, fix bugs * add some realize * fix multi * remove unused backward passes [pr] * lower line count --- .github/workflows/test.yml | 4 +- tinygrad/function.py | 135 ++++++------------------------------- 2 files changed, 22 insertions(+), 117 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c95325349e..be82930b62 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -243,8 +243,8 @@ jobs: run: | PYTHONPATH="." python test/external/fuzz_shapetracker.py PYTHONPATH="." python test/external/fuzz_shapetracker_math.py - - name: Repo line count < 11100 lines - run: MAX_LINE_COUNT=11100 python sz.py + - name: Repo line count < 11000 lines + run: MAX_LINE_COUNT=11000 python sz.py testopencl: strategy: diff --git a/tinygrad/function.py b/tinygrad/function.py index 73a963dd48..af5cecb8eb 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -1,86 +1,49 @@ """This is where the forwards and backwards passes live.""" import math -from tinygrad.helpers import argsort -from tinygrad.dtype import dtypes, DType, sum_acc_dtype -from tinygrad.ops import Ops, resolve, sint, UOp +from tinygrad.dtype import DType +from tinygrad.ops import Ops, sint, UOp from tinygrad.tensor import Function class Contiguous(Function): def forward(self, x:UOp) -> UOp: return x.contiguous() - def backward(self, grad_output:UOp) -> UOp: return grad_output class ContiguousBackward(Function): def forward(self, x:UOp) -> UOp: return x.contiguous_backward() - def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous() class Cast(Function): - def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: - self.input_dtype, self.bitcast = x.dtype, bitcast - return x.bitcast(dtype) if self.bitcast else x.cast(dtype) - - def backward(self, grad_output:UOp) -> UOp: - if self.bitcast: raise RuntimeError("bitcast cannot backward") - return grad_output.cast(self.input_dtype) + def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: return x.bitcast(dtype) if bitcast else x.cast(dtype) # ************* unary ops ************* class Reciprocal(Function): - def forward(self, x:UOp) -> UOp: - self.ret = x.reciprocal() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return -grad_output * self.ret * self.ret + def forward(self, x:UOp) -> UOp: return x.reciprocal() class Sin(Function): - def forward(self, x:UOp) -> UOp: - self.x = x - return x.sin() - - def backward(self, grad_output:UOp) -> UOp: return (math.pi/2 - self.x).sin() * grad_output + def forward(self, x:UOp) -> UOp: return x.sin() class Relu(Function): - def forward(self, x:UOp) -> UOp: - self.ret = (x>0).where(x, 0) - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output + def forward(self, x:UOp) -> UOp: return (x>0).where(x, 0) class Log(Function): - def forward(self, x:UOp) -> UOp: - self.x = x - return x.log2() * math.log(2) - - def backward(self, grad_output:UOp) -> UOp: return grad_output / self.x + def forward(self, x:UOp) -> UOp: return x.log2() * math.log(2) class Exp(Function): - def forward(self, x:UOp) -> UOp: - self.ret = (x * (1/math.log(2))).exp2() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return self.ret * grad_output + def forward(self, x:UOp) -> UOp: return (x * (1/math.log(2))).exp2() class Sqrt(Function): - def forward(self, x:UOp) -> UOp: - self.ret = x.sqrt() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return grad_output / (self.ret*2) + def forward(self, x:UOp) -> UOp: return x.sqrt() class Sign(Function): # NOTE: the x*0 is to match torch behavior without function.py def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0 - # backward always return 0 to match torch - def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0) # ************* binary ops ************* class Less(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x tuple[UOp|None, UOp|None]: return None, None class Neq(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y) - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None class Xor(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x^y @@ -97,18 +60,8 @@ class Threefry(Function): class Add(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x+y - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: - return grad_output if self.needs_input_grad[0] else None, \ - grad_output if self.needs_input_grad[1] else None - class Mul(Function): - def forward(self, x:UOp, y:UOp) -> UOp: - self.x, self.y = x, y - return x * y - - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: - return (self.y * grad_output) if self.needs_input_grad[0] else None, \ - (self.x * grad_output) if self.needs_input_grad[1] else None + def forward(self, x:UOp, y:UOp) -> UOp: return x * y class IDiv(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x // y @@ -119,85 +72,37 @@ class Mod(Function): # ************* ternary ops ************* class Where(Function): - def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: - self.x = x - return self.x.where(y, z) + def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: return x.where(y, z) - def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]: - return None, \ - self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \ - self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None # ************* reduce ops ************* class Sum(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.input_shape = x.shape - return x.r(Ops.ADD, axis) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.expand(self.input_shape) + def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.ADD, axis) class Prod(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.x, self.ret = x, x.r(Ops.MUL, axis) - return self.ret - - def backward(self, grad_output:UOp) -> UOp: - return (grad_output * self.ret).expand(self.x.shape) / self.x + def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MUL, axis) class Max(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis - return self.ret - - def backward(self, grad_output:UOp) -> UOp: - # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype) - div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape) - return (max_is_1s/div) * grad_output.expand(self.x.shape) + def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MAX, axis) # ************* movement ops ************* # NOTE: this is sum in reverse class Expand(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: - self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so)) - return x.expand(shape) - - def backward(self, grad_output:UOp) -> UOp: - return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype) + def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.expand(shape) class Reshape(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: - self.input_shape = x.shape - return x.reshape(shape) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.reshape(self.input_shape) + def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.reshape(shape) class Permute(Function): - def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: - self.input_order = order - return x.permute(order) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.permute(argsort(self.input_order)) + def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: return x.permute(order) class Pad(Function): - def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: - self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) - return x.pad(arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.shrink(self.narg) + def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: return x.pad(arg) class Shrink(Function): - def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: - self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) - return x.shrink(arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.pad(self.narg) + def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: return x.shrink(arg) class Flip(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))]) - return x.stride(self.arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.stride(self.arg) + def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.stride(tuple([-1 if i in axis else 1 for i in range(len(x.shape))])) From 06b58aa7ecb4d86f70f2b70fa0d828cf0ba2f270 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 03:36:15 -0500 Subject: [PATCH 33/44] move unneeded fields out of ScheduleContext [pr] (#8752) --- tinygrad/engine/schedule.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 801ffdcfd9..c1aa937e36 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -84,7 +84,6 @@ class ScheduleContext: contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) - becomes_map: dict[UOp, UOp] = field(default_factory=dict) # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. @@ -473,7 +472,6 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop) for x in op.base.src: if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None - buf_uop.buffer.ref(1) create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)]) # **** movement ops @@ -502,24 +500,27 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(sink, break_sched, ctx) - # preschedule realize groups + # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] + becomes_map: dict[UOp, UOp] = {} for store_uops in store_groups: small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}") prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: - for tensor_uop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + # increment refcount for this buffer + buf_uop.buffer.ref(1) # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed for k,v in tensor_map.items(): # NOOP if k.base is v.base: continue # NOTE: only the base tensors get a BUFFER UOp - if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) + if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st)) # otherwise if it simplified to a CONST the UOp just becomes that CONST - elif v.op is Ops.CONST: ctx.becomes_map[k] = v + elif v.op is Ops.CONST: becomes_map[k] = v # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} @@ -548,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # confirm everything was scheduled correctly if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") - return schedule, ctx.var_vals, ctx.becomes_map + return schedule, ctx.var_vals, becomes_map From b53fe7c2fcf2408eeb07e41f33f688ab24e31697 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 26 Jan 2025 17:59:15 +0900 Subject: [PATCH 34/44] remove unused ctx [pr] (#8751) * remove unused ctx [pr] * fix test --- test/test_gc.py | 2 +- tinygrad/tensor.py | 12 +----------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/test/test_gc.py b/test/test_gc.py index 0929a81394..cf90dc6201 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -33,7 +33,7 @@ class TestGC(unittest.TestCase): base = tensors_allocated() a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) b = Tensor.rand(4, 4, requires_grad=True) - assert (tensors_allocated()-base == 5) + assert (tensors_allocated()-base == 4) (a*b).mean().backward() assert (tensors_allocated()-base == 6) del b diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 826ade9242..1bcd22af53 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -53,14 +53,12 @@ class Function: self.metadata = metadata def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}") - def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}") @classmethod def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: ctx = fxn(x[0].device, *x, metadata=_METADATA.get()) ret = Tensor.__new__(Tensor) ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None - ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine return ret import tinygrad.function as F @@ -147,8 +145,7 @@ class Tensor(SimpleMathTrait): np.set_printoptions(precision=4) ``` """ - __slots__ = "lazydata", "requires_grad", "grad", "_ctx" - __deletable__ = ('_ctx',) + __slots__ = "lazydata", "requires_grad", "grad" training: ClassVar[bool] = False no_grad: ClassVar[bool] = False @@ -171,9 +168,6 @@ class Tensor(SimpleMathTrait): # None (the default) will be updated to True if it's put in an optimizer self.requires_grad: Optional[bool] = requires_grad - # internal variable used for autograd graph construction - self._ctx: Optional[Function] = None - # create a LazyBuffer from the different types of inputs if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" @@ -281,7 +275,6 @@ class Tensor(SimpleMathTrait): Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match. """ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) - assert getattr(self, '_ctx', None) is None assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}" self.lazydata = x.lazydata return self @@ -378,7 +371,6 @@ class Tensor(SimpleMathTrait): """ ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.clone() - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor: @@ -390,7 +382,6 @@ class Tensor(SimpleMathTrait): if not isinstance(device, str): return self.shard(device) ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.to(device) - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to_(self, device:Optional[Union[str, tuple[str, ...]]]): @@ -944,7 +935,6 @@ class Tensor(SimpleMathTrait): tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] # clear contexts - for t in tensors_need_grad: t._ctx = None for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" t.grad = g if t.grad is None else (t.grad + g) From ac70f63d4bad41048f0b39b2a866fa90ce97cd83 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 04:41:54 -0500 Subject: [PATCH 35/44] tensor_map cleanups [pr] (#8754) * tensor_map cleanups [pr] * update test_schedule too --- test/test_schedule.py | 4 ++-- tinygrad/engine/schedule.py | 34 +++++++++++++++++----------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index dee2a7ae78..969842af8f 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp from tinygrad.codegen.kernel import verify_ast -from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym +from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) @track_rewrites(named=True) -def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext()) +def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {}) class TestSchedule(unittest.TestCase): def test_basic_binop_fusion(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index c1aa937e36..f8f06b821b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -81,7 +81,6 @@ class ScheduleContext: realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata - contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) @@ -353,12 +352,12 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: case _: return None return reduce.const_like(ret) -def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp): - if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti) -def replace_contiguous(ctx:ScheduleContext, alu:UOp): +def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) +def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): new_src = list(alu.src) for i,s in enumerate(alu.src): - if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src + if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) sym = symbolic_simple+PatternMatcher([ @@ -490,11 +489,22 @@ remove_movement_ops = merge_views+PatternMatcher([ @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext()) + tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={}) + # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed + becomes_map: dict[UOp, UOp] = {} + for k,v in tensor_map.items(): + # NOOP + if k.base is v.base: continue + # NOTE: only the base tensors get a BUFFER UOp + if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st)) + # otherwise if it simplified to a CONST the UOp just becomes that CONST + elif v.op is Ops.CONST: becomes_map[k] = v + + # we group the rest of UOps into ScheduleItems rev_tensor_map: dict[UOp, list[UOp]] = {} for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) # add BUFFER uops - sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={}) + sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) # add realizes sink = graph_rewrite(sink, do_realize+create_ctx, ctx) # group realizes into kernels @@ -502,7 +512,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu graph_rewrite(sink, break_sched, ctx) # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] - becomes_map: dict[UOp, UOp] = {} for store_uops in store_groups: small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) if not all(x.op is Ops.STORE for x in small_sink.src): raise RuntimeError(f"expected all realized BUFFERs to get a STORE {sink}") @@ -513,15 +522,6 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # increment refcount for this buffer buf_uop.buffer.ref(1) - # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed - for k,v in tensor_map.items(): - # NOOP - if k.base is v.base: continue - # NOTE: only the base tensors get a BUFFER UOp - if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st)) - # otherwise if it simplified to a CONST the UOp just becomes that CONST - elif v.op is Ops.CONST: becomes_map[k] = v - # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list) From a6e496b1950732ebc127f33b5ded1bd08f309b1b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 26 Jan 2025 18:58:02 +0900 Subject: [PATCH 36/44] remove Function class [pr] (#8753) * remove Function class [pr] * actually remove function * fix docs --- docs/developer/developer.md | 2 +- docs/developer/function.md | 33 ----------- mkdocs.yml | 1 - tinygrad/function.py | 108 ---------------------------------- tinygrad/tensor.py | 112 +++++++++++++++++------------------- 5 files changed, 53 insertions(+), 203 deletions(-) delete mode 100644 docs/developer/function.md delete mode 100644 tinygrad/function.py diff --git a/docs/developer/developer.md b/docs/developer/developer.md index b40d715af0..39e9e0901b 100644 --- a/docs/developer/developer.md +++ b/docs/developer/developer.md @@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not ## Frontend -Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md). +Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md). The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There's two types of UOps, base and view. base contains compute into a contiguous buffer, and view is a view (specified by a ShapeTracker). Inputs to a base can be either base or view, inputs to a view can only be a single base. diff --git a/docs/developer/function.md b/docs/developer/function.md deleted file mode 100644 index 9f1b85f8cd..0000000000 --- a/docs/developer/function.md +++ /dev/null @@ -1,33 +0,0 @@ -::: tinygrad.function - options: - members: [ - "Contiguous", - "ContiguousBackward", - "Cast", - "Neg", - "Reciprocal", - "Sin", - "Relu", - "Log", - "Exp", - "Sqrt", - "Sigmoid", - "Sign", - "Less", - "Eq", - "Xor", - "Add", - "Sub", - "Mul", - "Div", - "Where", - "Sum", - "Max", - "Expand", - "Reshape", - "Permute", - "Pad", - "Shrink", - "Flip", - ] - show_source: false diff --git a/mkdocs.yml b/mkdocs.yml index 291998dac5..38419a5708 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,7 +22,6 @@ nav: - Runtime: runtime.md - Developer: - Intro: developer/developer.md - - Function (autodiff): developer/function.md - UOp: developer/uop.md - Runtime: - developer/runtime.md diff --git a/tinygrad/function.py b/tinygrad/function.py deleted file mode 100644 index af5cecb8eb..0000000000 --- a/tinygrad/function.py +++ /dev/null @@ -1,108 +0,0 @@ -"""This is where the forwards and backwards passes live.""" -import math -from tinygrad.dtype import DType -from tinygrad.ops import Ops, sint, UOp -from tinygrad.tensor import Function - -class Contiguous(Function): - def forward(self, x:UOp) -> UOp: return x.contiguous() - -class ContiguousBackward(Function): - def forward(self, x:UOp) -> UOp: return x.contiguous_backward() - -class Cast(Function): - def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: return x.bitcast(dtype) if bitcast else x.cast(dtype) - -# ************* unary ops ************* - -class Reciprocal(Function): - def forward(self, x:UOp) -> UOp: return x.reciprocal() - -class Sin(Function): - def forward(self, x:UOp) -> UOp: return x.sin() - -class Relu(Function): - def forward(self, x:UOp) -> UOp: return (x>0).where(x, 0) - -class Log(Function): - def forward(self, x:UOp) -> UOp: return x.log2() * math.log(2) - -class Exp(Function): - def forward(self, x:UOp) -> UOp: return (x * (1/math.log(2))).exp2() - -class Sqrt(Function): - def forward(self, x:UOp) -> UOp: return x.sqrt() - -class Sign(Function): - # NOTE: the x*0 is to match torch behavior without function.py - def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0 - -# ************* binary ops ************* - -class Less(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x UOp: return x.ne(y) - -class Xor(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x^y - -class BitwiseAnd(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x&y - -class BitwiseOr(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x|y - -class Threefry(Function): - def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed) - -class Add(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x+y - -class Mul(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x * y - -class IDiv(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x // y - -class Mod(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x % y - -# ************* ternary ops ************* - -class Where(Function): - def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: return x.where(y, z) - - -# ************* reduce ops ************* - -class Sum(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.ADD, axis) - -class Prod(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MUL, axis) - -class Max(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.r(Ops.MAX, axis) - -# ************* movement ops ************* - -# NOTE: this is sum in reverse -class Expand(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.expand(shape) - -class Reshape(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: return x.reshape(shape) - -class Permute(Function): - def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: return x.permute(order) - -class Pad(Function): - def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: return x.pad(arg) - -class Shrink(Function): - def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: return x.shrink(arg) - -class Flip(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: return x.stride(tuple([-1 if i in axis else 1 for i in range(len(x.shape))])) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 1bcd22af53..1b81c74601 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2,7 +2,7 @@ from __future__ import annotations import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref from contextlib import ContextDecorator -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex +from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap @@ -42,26 +42,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: if s is ns: continue t.lazydata = ns -# **** start with two base classes, Tensor and Function **** - -class Function: - def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None): - self.device = device - self.needs_input_grad = [t.requires_grad for t in tensors] - self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False - if self.requires_grad: self.parents = tensors - self.metadata = metadata - - def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}") - - @classmethod - def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: - ctx = fxn(x[0].device, *x, metadata=_METADATA.get()) - ret = Tensor.__new__(Tensor) - ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None - return ret - -import tinygrad.function as F +# **** Tensor helper functions **** def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None): if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) @@ -239,6 +220,17 @@ class Tensor(SimpleMathTrait): @property def dtype(self) -> DType: return self.lazydata.dtype + def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor: + ret = Tensor.__new__(Tensor) + needs_input_grad = [t.requires_grad for t in (self,)+x] + ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None + ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs) + return ret + + def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor: + lhs,rhs = self._broadcasted(x, reverse) + return lhs._apply_uop(fxn, rhs) + # ***** data handlers **** def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]: @@ -497,7 +489,7 @@ class Tensor(SimpleMathTrait): @staticmethod def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor): x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) - x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) + x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) return counts0.cat(counts1) @@ -961,7 +953,7 @@ class Tensor(SimpleMathTrait): # resolve -1 if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) - return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self + return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self def expand(self, shape, *args) -> Tensor: """ @@ -994,7 +986,7 @@ class Tensor(SimpleMathTrait): """ order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") - return F.Permute.apply(self, order=order_arg) + return self._apply_uop(UOp.permute, arg=order_arg) def flip(self, axis, *args) -> Tensor: """ @@ -1014,7 +1006,7 @@ class Tensor(SimpleMathTrait): """ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") - return F.Flip.apply(self, axis=axis_arg) + return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))])) def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor: """ @@ -1034,7 +1026,7 @@ class Tensor(SimpleMathTrait): ``` """ if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self - return F.Shrink.apply(self, arg=tuple(shrink_arg)) + return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg)) def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor: """ @@ -1078,7 +1070,8 @@ class Tensor(SimpleMathTrait): if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}") X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if mode == "constant": - def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v) + def _constant(x:Tensor,px,v): + return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v)) return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value) assert all_int(self.shape), f"does not support symbolic shape {self.shape}" @@ -1568,10 +1561,10 @@ class Tensor(SimpleMathTrait): # ***** reduce ops ***** - def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: + def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) if self.ndim == 0: axis = () - ret = fxn.apply(self, axis=axis) + ret = self._apply_uop(UOp.r, op=op, axis=axis) return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): @@ -1598,7 +1591,7 @@ class Tensor(SimpleMathTrait): print(t.sum(axis=1).numpy()) ``` """ - ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim) + ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim) return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): @@ -1625,7 +1618,7 @@ class Tensor(SimpleMathTrait): print(t.prod(axis=1).numpy()) ``` """ - return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim) + return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim) def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ @@ -1648,7 +1641,7 @@ class Tensor(SimpleMathTrait): print(t.max(axis=1, keepdim=True).numpy()) ``` """ - return self._reduce(F.Max, axis, keepdim) + return self._reduce(Ops.MAX, axis, keepdim) def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() @@ -2485,7 +2478,7 @@ class Tensor(SimpleMathTrait): print(Tensor([False, True]).logical_not().numpy()) ``` """ - return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True)) + return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True) def neg(self): """ Negates the tensor element-wise. @@ -2499,12 +2492,12 @@ class Tensor(SimpleMathTrait): """ Returns a contiguous tensor. """ - return F.Contiguous.apply(self) + return self._apply_uop(UOp.contiguous) def contiguous_backward(self): """ Inserts a contiguous operation in the backward pass. """ - return F.ContiguousBackward.apply(self) + return self._apply_uop(UOp.contiguous_backward) def log(self): """ Computes the natural logarithm element-wise. @@ -2515,7 +2508,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 4., 8.]).log().numpy()) ``` """ - return F.Log.apply(self.cast(least_upper_float(self.dtype))) + return self.log2()*math.log(2) def log2(self): """ Computes the base-2 logarithm element-wise. @@ -2526,7 +2519,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 4., 8.]).log2().numpy()) ``` """ - return self.log()/math.log(2) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2) def exp(self): """ Computes the exponential function element-wise. @@ -2537,7 +2530,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., 1., 2., 3.]).exp().numpy()) ``` """ - return F.Exp.apply(self.cast(least_upper_float(self.dtype))) + return self.mul(1/math.log(2)).exp2() def exp2(self): """ Computes the base-2 exponential function element-wise. @@ -2548,8 +2541,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., 1., 2., 3.]).exp2().numpy()) ``` """ - return F.Exp.apply(self*math.log(2)) - + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) def relu(self): """ Applies the Rectified Linear Unit (ReLU) function element-wise. @@ -2560,7 +2552,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) ``` """ - return F.Relu.apply(self) + return (self>0).where(self, 0) def sigmoid(self): """ @@ -2596,7 +2588,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 3., 4.]).sqrt().numpy()) ``` """ - return F.Sqrt.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) def rsqrt(self): """ Computes the reciprocal of the square root of the tensor element-wise. @@ -2614,7 +2606,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) ``` """ - return F.Sin.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin) def cos(self): """ Computes the cosine of the tensor element-wise. @@ -2773,7 +2765,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) ``` """ - return F.Sign.apply(self) + return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0 def abs(self): """ Computes the absolute value of the tensor element-wise. @@ -2791,7 +2783,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 3., 4.]).reciprocal().numpy()) ``` """ - return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal) # ***** activation functions ***** @@ -3069,7 +3061,7 @@ class Tensor(SimpleMathTrait): # for each dimension, check either dim is 1, or it does not change if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") - return F.Expand.apply(self.reshape(shape), shape=new_shape) + return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape) def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]: x: Tensor = self @@ -3113,7 +3105,7 @@ class Tensor(SimpleMathTrait): print(t.add(Tensor([[2.0], [3.5]])).numpy()) ``` """ - return F.Add.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.add, x, reverse) def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3154,7 +3146,7 @@ class Tensor(SimpleMathTrait): print(t.mul(Tensor([[-1.0], [2.0]])).numpy()) ``` """ - return F.Mul.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.mul, x, reverse) def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3167,7 +3159,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy()) ``` """ - return F.IDiv.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.idiv, x, reverse) def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3202,7 +3194,7 @@ class Tensor(SimpleMathTrait): ``` """ a, b = self._broadcasted(x, reverse) - return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0))) + return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0))) def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3218,7 +3210,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.Xor.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.xor, x, reverse) def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3233,7 +3225,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.BitwiseAnd.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse) def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3248,7 +3240,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.BitwiseOr.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse) def bitwise_not(self) -> Tensor: """ @@ -3379,7 +3371,7 @@ class Tensor(SimpleMathTrait): elif isinstance(y, Tensor): y, x = y._broadcasted(x) cond, x = self._broadcasted(x, match_dtype=False) cond, y = cond._broadcasted(y, match_dtype=False) - return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y)) + return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y)) def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self) @@ -3409,9 +3401,9 @@ class Tensor(SimpleMathTrait): def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) - def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False)) - def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True)) - def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) + def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False) + def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True) + def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False) def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override] @@ -3757,8 +3749,8 @@ class Tensor(SimpleMathTrait): """ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype): # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around - return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt) - return self if self.dtype == dt else F.Cast.apply(self, dtype=dt) + return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt) + return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt) def bitcast(self, dtype:DTypeLike) -> Tensor: """ @@ -3783,7 +3775,7 @@ class Tensor(SimpleMathTrait): tmp = self.bitcast(old_uint) if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype) return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype) - return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self + return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self def float(self) -> Tensor: """ From bbb2dd8141ca5d6033c61a91b1f34da1fb1f13fc Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 09:58:05 -0500 Subject: [PATCH 37/44] move VALID creation after merging the views (#8757) * do valid creation later * work for view_left * only view(const) makes valids in view_left * cleaner bind diff --- tinygrad/engine/schedule.py | 13 +++++-------- tinygrad/ops.py | 5 +++++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f8f06b821b..3b6828688d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -199,6 +199,8 @@ to_si = PatternMatcher([ (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), + # CONST(VIEW) becomes VALID too, TODO: doesn't have to + (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel @@ -438,11 +440,11 @@ do_realize = PatternMatcher([ (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), ]) -# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp +# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp): - assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}" - ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg + assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}" + ctx.var_vals[ret:=var.replace(src=())] = val.const_arg return ret.valid(unwrap(bind.st)) def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): @@ -456,8 +458,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) break_sched = PatternMatcher([ - # CONST is always fused and generated - (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), (UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable), # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized), @@ -481,9 +481,6 @@ remove_movement_ops = merge_views+PatternMatcher([ # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) (UPat(Ops.VIEW, name="view"), lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), - # merge unmasked const views - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), - lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) @track_rewrites(named=True) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 950589a80a..7f30ebe5a0 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1322,10 +1322,15 @@ merge_views = PatternMatcher([ # VIEW(VIEW) merges to a single VIEW (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), + # merge unmasked const views + (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), + lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) # push VIEW to parents view_left = merge_views+PatternMatcher([ + # VIEW(CONST) becomes VALID + (UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)), # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), From d488bbb1ecd95f20eea87bf19c3814feddef3355 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 26 Jan 2025 10:41:54 -0500 Subject: [PATCH 38/44] share merge_views/valid creation for CONST/DEFINE_VAR (#8758) * share valid creation behavior for CONST/DEFINE_VAR * work --- tinygrad/engine/schedule.py | 6 +++--- tinygrad/ops.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3b6828688d..ae30efe400 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -200,7 +200,7 @@ to_si = PatternMatcher([ # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), # CONST(VIEW) becomes VALID too, TODO: doesn't have to - (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), + (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel @@ -444,8 +444,8 @@ do_realize = PatternMatcher([ def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp): assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}" - ctx.var_vals[ret:=var.replace(src=())] = val.const_arg - return ret.valid(unwrap(bind.st)) + ctx.var_vals[var.replace(src=())] = val.const_arg + return var def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 7f30ebe5a0..731244918b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1323,14 +1323,14 @@ merge_views = PatternMatcher([ (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), # merge unmasked const views - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), + (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) # push VIEW to parents view_left = merge_views+PatternMatcher([ # VIEW(CONST) becomes VALID - (UPat(Ops.VIEW, name="vm", src=(UPat.cvar("x"),)), lambda vm,x: UOp.const(x.dtype, x.const_arg).valid(vm.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)), # VIEW before elementwise/buffer ops (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), From 2454bf01c331fceba235df0eb4ec05fe0be74586 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 27 Jan 2025 07:20:21 +0900 Subject: [PATCH 39/44] hotfix: remove shapetracker spam in viz --- tinygrad/shape/shapetracker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 4fff017555..6897a6cc68 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -115,8 +115,9 @@ class ShapeTracker: def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] def axis_is_masked(self, axis:int) -> bool: - _, valid = self.to_indexed_uops() - return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] + with Context(TRACK_MATCH_STATS=0): + _, valid = self.to_indexed_uops() + return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: From a9d9f98d05fc01ffb745dd28d04995dfcbfcab5c Mon Sep 17 00:00:00 2001 From: George Hotz Date: Mon, 27 Jan 2025 07:53:48 +0900 Subject: [PATCH 40/44] hotfix: those tests fail locally on mac due to buffer count --- test/models/test_train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/models/test_train.py b/test/models/test_train.py index 6020a8e777..605e6f6de1 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -40,6 +40,7 @@ class TestTrain(unittest.TestCase): check_gc() @unittest.skipIf(CI, "slow") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_efficientnet(self): model = EfficientNet(0) X = np.zeros((BS,3,224,224), dtype=np.float32) @@ -56,6 +57,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_transformer(self): # this should be small GPT-2, but the param count is wrong # (real ff_dim is 768*4) From efc79710906bbb0df7ffc96e52cfce5a62dc9a5f Mon Sep 17 00:00:00 2001 From: b1tg <33436708+b1tg@users.noreply.github.com> Date: Mon, 27 Jan 2025 13:53:21 +0800 Subject: [PATCH 41/44] add windows test to ci (#8761) Co-authored-by: b1tg --- .github/workflows/test.yml | 38 ++++++++++++++++++++++++++++++++++++++ tinygrad/device.py | 8 ++++---- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index be82930b62..537cc04e6e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -619,6 +619,44 @@ jobs: if: matrix.backend=='amd' run: python -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20 + wintests: + strategy: + fail-fast: false + matrix: + backend: [llvm] + + name: Tests on Windows (${{ matrix.backend }}) + runs-on: windows-latest + timeout-minutes: 45 + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: 3.12 + - name: Cache python packages + uses: actions/cache@v4 + with: + path: ${{ env.Python3_ROOT_DIR }}\Lib\site-packages + key: windows-${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }} + - name: Install dependencies + run: pip install --user -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu + - name: Check Device.DEFAULT and print some source + env: + DEBUG: 5 + LLVM: 1 + PYTHONPATH: ${{ github.workspace }} + run: | + python3 test/test_ops.py TestOps.test_add + - name: Run pytest + env: + DEBUG: 5 + LLVM: 1 + run: python -m pytest -n=auto test/test_tiny.py --durations=20 + #testunicorn: # name: ARM64 unicorn Test # runs-on: ubuntu-latest diff --git a/tinygrad/device.py b/tinygrad/device.py index 7fea67aa00..04182a3c33 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,7 +3,6 @@ from dataclasses import dataclass, replace from collections import defaultdict from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time -from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ cpu_time_execution, colored, Context from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes @@ -221,9 +220,10 @@ MAP_JIT = 0x0800 # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to class CPUProgram: - helper_handle = ctypes.CDLL(ctypes.util.find_library('System') if OSX else 'libgcc_s.so.1') - + helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32' if sys.platform == "win32" else 'gcc_s')) def __init__(self, name:str, lib:bytes): + assert sys.platform != "win32", "clang is not supported for windows yet" + from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/ # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np) self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC) @@ -329,4 +329,4 @@ if __name__ == "__main__": result = f"{colored('FAIL', 'yellow')} {e}" except Exception as e: result = f"{colored('FAIL', 'red')} {e}" - print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}") \ No newline at end of file + print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}") From 96bff0b4f74c33eef42833f93bd5bd65cf63acc3 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 27 Jan 2025 15:19:11 +0900 Subject: [PATCH 42/44] contiguous is no longer needed in SGD [pr] (#8760) * contiguous is no longer needed in SGD [pr] * add allow condition --- test/test_schedule.py | 12 ++++++------ tinygrad/engine/schedule.py | 9 ++++++--- tinygrad/nn/optim.py | 3 --- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/test_schedule.py b/test/test_schedule.py index 969842af8f..f79f5bd378 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase): def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 15)]: + for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -1070,7 +1070,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters(c1)) opt.zero_grad() c1(img).relu().sum().backward() - check_schedule(opt.schedule_step(), 5) + check_schedule(opt.schedule_step(), 3) def test_sgd_2convs_fuse(self): with Tensor.train(): @@ -1081,7 +1081,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2])) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 9) + check_schedule(opt.schedule_step(), 7) def test_fold_2convs_sgd_nesterov_momentum_wd(self): with Tensor.train(): @@ -1092,7 +1092,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 11) + check_schedule(opt.schedule_step(), 9) def test_sgd_4convs_fuse(self): with Tensor.train(): @@ -1105,7 +1105,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 21) + check_schedule(opt.schedule_step(), 17) def test_sgd_4convs_fuse_conv_bw(self): with Tensor.train(): @@ -1118,7 +1118,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 18) + with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_prefer_half_buffer(self): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index ae30efe400..c94f18302b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -222,10 +222,13 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: assign_preloads[x.buf_uop] = None # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous: + # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine + if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass # if it has a single view and it's equal when you shrink a contig, it's fine - if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask): - raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass + # otherwise, it's not fine + else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) # capture process replay if CAPTURE_PROCESS_REPLAY: with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index db1d84b345..b7cb9f4359 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -77,9 +77,6 @@ class LARS(Optimizer): def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]: for i, (t, g) in enumerate(zip(self.params, grads)): - # contiguous is needed since the grads can allegedly form a "diamond" - # TODO: fix this in lazy.py - g = g.contiguous() if self.tcoef != 0: r1 = t.detach().square().sum().sqrt() r2 = g.square().sum().sqrt() From bf041659a5004cc40294366254837ab9f129f296 Mon Sep 17 00:00:00 2001 From: Ignacio Sica Date: Mon, 27 Jan 2025 11:36:47 -0300 Subject: [PATCH 43/44] rename Opt amt to arg (#8767) --- extra/mcts_search.py | 2 +- test/external/external_debug_metal_sd_conv.py | 2 +- .../external_test_hcq_fuzz_failures.py | 2 +- test/external/external_test_nv.py | 4 +- test/external/external_test_train_gpt2.py | 4 +- test/external/external_test_valid_remove.py | 4 +- test/test_arange.py | 8 +- test/test_hcq.py | 2 +- test/test_linearizer.py | 32 ++--- test/test_linearizer_dumb.py | 12 +- test/test_linearizer_failures.py | 122 +++++++++--------- test/test_linearizer_overflows.py | 18 +-- test/test_search.py | 10 +- tinygrad/codegen/kernel.py | 14 +- tinygrad/engine/search.py | 22 ++-- 15 files changed, 129 insertions(+), 129 deletions(-) diff --git a/extra/mcts_search.py b/extra/mcts_search.py index 54189fabb3..9090c902ff 100644 --- a/extra/mcts_search.py +++ b/extra/mcts_search.py @@ -162,7 +162,7 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel: if node.n == 0: return for parent in node.parents: G.add_edge(parent, node) gopts = node.kernel.applied_opts - edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].amt}" if len(gopts) else "ROOT" + edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT" G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}", fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '') if node.children is not None: diff --git a/test/external/external_debug_metal_sd_conv.py b/test/external/external_debug_metal_sd_conv.py index 2e9315c5ba..f0f3db7971 100644 --- a/test/external/external_debug_metal_sd_conv.py +++ b/test/external/external_debug_metal_sd_conv.py @@ -29,7 +29,7 @@ ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()), x17,)),)),)),)) -opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)] +opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)] k = Kernel(ast) for opt in opts: k.apply_opt(opt) diff --git a/test/external/external_test_hcq_fuzz_failures.py b/test/external/external_test_hcq_fuzz_failures.py index 9370b7661d..3f434e65bb 100644 --- a/test/external/external_test_hcq_fuzz_failures.py +++ b/test/external/external_test_hcq_fuzz_failures.py @@ -55,7 +55,7 @@ class TestHCQFuzzFailures(unittest.TestCase): def test_failure_1(self): ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=1, src=()), x39:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=((0, 1), (0, 6)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), x39,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=3, src=()), x46:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-6, mask=((0, 1), (6, 12)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), x46,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=5, src=()), x54:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (12, 13)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), x54,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-13, mask=((0, 1), (13, 17)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-17, mask=((0, 1), (17, 21)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=9, src=()), x68:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (21, 22)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), x68,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-22, mask=((0, 1), (22, 26)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-26, mask=((0, 1), (26, 30)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=13, src=()), x82:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (30, 31)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), x82,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=15, src=()), x90:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (31, 32)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), x90,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=17, src=()), x98:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (32, 33)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), x98,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=19, src=()), x106:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (33, 34)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=20, src=()), x106,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=21, src=()), x114:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (34, 35)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=22, src=()), x114,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=23, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-35, mask=((0, 1), (35, 39)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=24, src=()), x125:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (39, 40)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=25, src=()), x125,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=26, src=()), x133:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (40, 41)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=27, src=()), x133,)),)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=28, src=()), x140:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-41, mask=((0, 1), (41, 47)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=29, src=()), x140,)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=30, src=()), x147:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-47, mask=((0, 1), (47, 53)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=31, src=()), x147,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=32, src=()), x155:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (53, 54)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=33, src=()), x155,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=34, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-54, mask=((0, 1), (54, 58)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=35, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-58, mask=((0, 1), (58, 62)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=36, src=()), x169:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (62, 63)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=37, src=()), x169,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=38, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-63, mask=((0, 1), (63, 67)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=39, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-67, mask=((0, 1), (67, 71)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=40, src=()), x183:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (71, 72)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=41, src=()), x183,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=42, src=()), x191:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (72, 73)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=43, src=()), x191,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=44, src=()), x199:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (73, 74)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=45, src=()), x199,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=46, src=()), x207:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (74, 75)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=47, src=()), x207,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=48, src=()), x215:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (75, 76)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=49, src=()), x215,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=50, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-76, mask=((0, 1), (76, 80)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=51, src=()), x226:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (80, 81)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=52, src=()), x226,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=53, src=()), x234:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (81, 82)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=54, src=()), x234,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=55, src=()), x243:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (82, 83)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=56, src=()), x243,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=57, src=()), x250:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (83, 84)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=58, src=()), x250,)),)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 128, 4)), arg=59, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-84, mask=((0, 1), (84, 596)), contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], validate_device=Device["GPU"]) if __name__ == '__main__': diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index 26b976e99b..702f9c805c 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -26,12 +26,12 @@ class TestNV(unittest.TestCase): def test_oor_kernels(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501 + opts = [Opt(op=OptOps.TC, axis=6, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"]) def test_error_on_huge_dims(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501 + opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501 with self.assertRaises(RuntimeError) as cm: lin = Kernel(ast) for opt in opts: lin.apply_opt(opt) diff --git a/test/external/external_test_train_gpt2.py b/test/external/external_test_train_gpt2.py index f59b878253..df7546e6b2 100644 --- a/test/external/external_test_train_gpt2.py +++ b/test/external/external_test_train_gpt2.py @@ -26,7 +26,7 @@ class TestTrainGpt2Kernel(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(38633472), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(0, 0, 768, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) run_linearizer(kernel) @@ -46,7 +46,7 @@ class TestTrainGpt2Kernel(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(205852672), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(51463168, 50257, 1, 0), offset=0, mask=((0, 4), (0, 1024), (0, 50257), (0, 768)), contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=4)] + opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) run_linearizer(kernel) diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index 1467139daf..d3cc22fca9 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -51,7 +51,7 @@ class TestOpenpilotValidhack(unittest.TestCase): x19,)), x29,)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)] + opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) @@ -108,7 +108,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) diff --git a/test/test_arange.py b/test/test_arange.py index 2229ac847f..4d0bb79dc5 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -41,7 +41,7 @@ class TestArange(unittest.TestCase): def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1) @unittest.skip("doesn't work yet") - def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)]) + def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)]) def test_all_opts(self, opts=None, exclude=None): k = Kernel(Tensor.arange(256).schedule()[-1].ast) @@ -59,11 +59,11 @@ class TestArange(unittest.TestCase): self.test_complexity(opts) def test_all_opts_w_local(self): with contextlib.suppress(KernelOptError): - return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, amt=32)]) + return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)]) def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)]) - def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)]) + def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) def test_all_opts_w_upcast_and_unroll(self): - return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)]) + return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) class TestIndexing(unittest.TestCase): # update: passing after CAST_BEFORE_VIEW=1 deletion diff --git a/test/test_hcq.py b/test/test_hcq.py index 476354129a..13ef71bd87 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -160,7 +160,7 @@ class TestHCQ(unittest.TestCase): b = a + 1 si = b.schedule()[-1] k = Kernel(si.ast, opts=TestHCQ.d0.renderer) - for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=3)) + for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, arg=3)) runner = CompiledRunner(k.to_program()) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c0bdc2c6f8..55b4463a58 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1300,10 +1300,10 @@ class TestLinearizer(unittest.TestCase): helper_linearizer_opt(a, [ [Opt(OptOps.GROUP, 0, 32)], [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(op=OptOps.LOCAL, axis=0, amt=8)], - [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0)], - [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8)], - [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], # noqa: E501 + [Opt(op=OptOps.LOCAL, axis=0, arg=8)], + [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)], + [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501 ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -1363,8 +1363,8 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 opt = [ - Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), - Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2) + Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), + Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2) ] k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is Ops.STORE][0] @@ -1381,9 +1381,9 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 - opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), - Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), - Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] + opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8), + Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), + Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is Ops.STORE][0] assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1 @@ -1606,9 +1606,9 @@ class TestFloat4(unittest.TestCase): # TODO: fix this, expected might change but should be positive for expected, opts in [ - ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), - ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), - ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]), + ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) @@ -1637,8 +1637,8 @@ class TestFloat4(unittest.TestCase): UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501 for expected, opts in [ - (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]), - (4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]), + (1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]), + (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) @@ -1660,8 +1660,8 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501 for expected, opts in [ - (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501 - (4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]), + (16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501 + (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 5cd581aa04..2408bc2626 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -35,7 +35,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.half, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)] + opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)] k = Kernel(ast, opts=Device["METAL"].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -70,7 +70,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 1000, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -88,7 +88,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -155,7 +155,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() @@ -186,7 +186,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() @@ -210,7 +210,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)] + opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 8570dac761..53c6ca5e5d 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -68,7 +68,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] + opts = [Opt(op=OptOps.LOCAL, axis=0, arg=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_3(self): @@ -80,7 +80,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=32)] # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)} helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -101,7 +101,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)] # EXEC_ERROR, it has no global_size helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -116,7 +116,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 10, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0)] # COMPILE FAILED, KeyError: UOps.CONST helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -129,7 +129,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] # test/test_linearizer_failures.py Fatal Python error: Segmentation fault helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -156,7 +156,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 1e-06, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)] # fatal error: bracket nesting level exceeded maximum of 256 # note: use -fbracket-depth=N to increase maximum nesting level helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -174,7 +174,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_10(self): @@ -290,7 +290,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.GROUP, axis=0, arg=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @unittest.skip("found implicit expand") @@ -315,7 +315,7 @@ class TestLinearizerFailures(unittest.TestCase): x6,)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( x5,)),)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.GROUP, axis=0, arg=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # both kernels are correct from a code standpoint, but generate different results due to precision errors (switching to float results in output matches) @@ -336,7 +336,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=19584, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"]) def test_failure_14(self): @@ -357,7 +357,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)] # COMPILE_ERROR on METAL in fuzz_linearizer: unused variables and undeclared variables helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -395,7 +395,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=16)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=16)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 115: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -411,7 +411,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 0.0009765625, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4)] # COMPILE_ERROR on METAL/GPU (probably HIP/CUDA too) in fuzz_linearizer ast 154: bracket nesting level exceeded maximum of 256 helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -428,7 +428,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(188160, 0, 0, 784, 28, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=1, amt=4)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.GROUPTOP, axis=0, arg=16), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=1, arg=4)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 178: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -453,7 +453,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUPTOP, axis=0, arg=256), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 239: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -470,7 +470,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(252, 0, 0, 63, 7, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3)] + opts = [Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=7), Opt(op=OptOps.UPCAST, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 379: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -485,7 +485,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=0)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_21(self): @@ -495,7 +495,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, amt=32)] + opts = [Opt(op=OptOps.PADTO, axis=0, arg=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) #@unittest.skipIf(Device.DEFAULT in ("LLVM", "METAL", "CLANG"), "flaky") @@ -603,7 +603,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_24(self): @@ -614,7 +614,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # this is the cause of the GPT2 BEAM instability. bisects to PR#3530 O(n) arange attempt @@ -629,7 +629,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UNROLL, axis=0, arg=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # COMPARE_ERROR from GPT2 kernel - stems from uops.py self.simplify_phi_loops @@ -645,17 +645,17 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) all_failing_opts = [ - [Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0)], - [Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4)], - [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0)], - [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], - [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4)], - [Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], - [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)], - [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], - [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=4)], - [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], - [Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.GROUPTOP, axis=0, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=0)], + [Opt(op=OptOps.GROUPTOP, axis=0, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4)], + [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4)], + [Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)], + [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], + [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=4)], + [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0)], ] for opts in all_failing_opts: helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -683,7 +683,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) all_failing_opts = [ - [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=7), Opt(op=OptOps.UPCAST, axis=0, arg=0)], ] for opts in all_failing_opts: helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -740,7 +740,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=1), Opt(op=OptOps.PADTO, axis=2, amt=32)] + opts = [Opt(op=OptOps.TC, axis=0, arg=1), Opt(op=OptOps.PADTO, axis=2, arg=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0) def test_failure_30(self): @@ -758,7 +758,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(0, 0, 12, 0, 0, 4, 2, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=3, amt=32), Opt(op=OptOps.LOCAL, axis=3, amt=32), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0)] + opts = [Opt(op=OptOps.PADTO, axis=3, arg=32), Opt(op=OptOps.LOCAL, axis=3, arg=32), Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from METAL=1 fuzz_linearizer command in test.yml @@ -779,7 +779,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.4426950408889634, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] + opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipIf(CI, "for real AMD GPU") @@ -800,7 +800,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=16)] + opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05) def test_failure_33(self): @@ -841,7 +841,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), x10,)),)),)),)),)) - opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=16)] + opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from fuzzing on metal @@ -861,7 +861,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] if unroll else [Opt(op=OptOps.TC, axis=0, amt=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_35(self): self.test_failure_34(True) @@ -881,7 +881,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), ast_const(dtypes.uint, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # BEGIN METAL=1 ./examples/beautiful_mnist.py failures @@ -910,7 +910,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_38(self): @@ -930,7 +930,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) for axis in [0,1,3,4]: - opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skip("very slow, similar to test_failure_37") @@ -958,7 +958,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] + opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_40(self): @@ -975,7 +975,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for amt in [16,32]: - opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)] + opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=amt), Opt(op=OptOps.UNROLL, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # END METAL=1 ./examples/beautiful_mnist.py failures @@ -996,7 +996,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] + opts=[Opt(op=OptOps.TC, axis=5, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02) # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile @@ -1011,7 +1011,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=0, amt=32)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.PADTO, axis=0, arg=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @@ -1025,7 +1025,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @@ -1039,7 +1039,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)] k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) assert k is not None ifs = [u for u in k.uops if u.op is Ops.IF] @@ -1084,7 +1084,7 @@ class TestLinearizerFailures(unittest.TestCase): x19,)),)), x21,)),)),)),)),)),)),)) # ValueError: size mismatched, can't reshape self.shape=(6, 2, 3, 3) -> new_shape=(6, 2, 3, 1, 2) - opts = [Opt(op=OptOps.UNROLL, axis=2, amt=0)] + opts = [Opt(op=OptOps.UNROLL, axis=2, arg=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_46(self): @@ -1117,7 +1117,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_47(self): @@ -1132,7 +1132,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=3)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=3)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(not CI and Device.DEFAULT in ("NV", "CUDA"), "for real NV") @@ -1151,7 +1151,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_49(self): @@ -1168,7 +1168,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_50(self): @@ -1195,7 +1195,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)] + opts = [Opt(op=OptOps.UPCAST, axis=1, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_51(self): @@ -1236,7 +1236,7 @@ class TestLinearizerFailures(unittest.TestCase): x6, UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()), x9,)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=2)] + opts = [Opt(op=OptOps.TC, axis=0, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1258,7 +1258,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16)] + opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_53(self): @@ -1294,7 +1294,7 @@ class TestLinearizerFailures(unittest.TestCase): x22, UOp(Ops.CONST, dtypes.bool, arg=True, src=()), UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.GROUPTOP, axis=1, amt=16)] + opts = [Opt(op=OptOps.GROUPTOP, axis=1, arg=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["AMD", "GPU", "METAL", "NV", "CUDA"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1315,7 +1315,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)] + opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1336,7 +1336,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] + opts = [Opt(op=OptOps.SWAP, axis=1, arg=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_56(self): @@ -1382,7 +1382,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=2, amt=32)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) def test_failure_57(self): @@ -1428,7 +1428,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] + opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) if __name__ == '__main__': diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 2e7265f652..f5fb749956 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -59,7 +59,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x16,)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0)] _test_overflow(ast, opts) # From BEAM on hlb_cifar.py @@ -76,7 +76,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)] _test_overflow(ast, opts) # from BEAM on default simple_conv.py (which is quite large): @@ -93,7 +93,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)] _test_overflow(ast, opts) # from BEAM on BS=4 simple_conv.py: @@ -110,7 +110,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] + opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)] _test_overflow(ast, opts) # from BEAM on BS=2 simple_conv.py: @@ -127,7 +127,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: @@ -144,7 +144,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: (alt) @@ -161,7 +161,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] + opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)] _test_overflow(ast, opts) @unittest.skipIf(Device.DEFAULT not in {"GPU", "HSA", "CUDA", "METAL"}, "only backends with locals") @@ -177,7 +177,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)] _test_overflow(ast, opts) def test_overflow_2(self): BS = 2 @@ -189,7 +189,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) - opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=4, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=5, arg=2)] _test_overflow(ast, opts) if __name__ == '__main__': diff --git a/test/test_search.py b/test/test_search.py index d22e03bc59..d0d6cf9114 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -92,15 +92,15 @@ class TestBEAM(unittest.TestCase): # ensure amt=0 are not duplicated if Opt(OptOps.UPCAST, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, amt=4)]) == 0, "did not de-dup UPCAST" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, arg=4)]) == 0, "did not de-dup UPCAST" if Opt(OptOps.LOCAL, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, amt=4)]) == 0, "did not de-dup LOCAL" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, arg=4)]) == 0, "did not de-dup LOCAL" if Opt(OptOps.UNROLL, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, amt=3)]) == 0, "did not de-dup UNROLL" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, arg=3)]) == 0, "did not de-dup UNROLL" if Opt(OptOps.GROUP, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, amt=3)]) == 0, "did not de-dup GROUP" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, arg=3)]) == 0, "did not de-dup GROUP" if Opt(OptOps.GROUPTOP, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP" def test_filter_global_buffer(self): # taken from https://github.com/tinygrad/tinygrad/issues/4612 diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index bfa5b54c31..e45f371b90 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -32,8 +32,8 @@ def check(cond:bool, msg:str=""): class Opt: op: OptOps axis: Optional[int] = None - amt: Optional[int] = None - def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" + arg: Optional[int] = None + def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})" def real_axis(self, k:Kernel): if self.axis is None: return -1 if self.op is OptOps.UNROLL: return k.first_reduce+self.axis @@ -353,18 +353,18 @@ class Kernel: if opt.op is OptOps.TC: check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine - check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt") + check(opt.axis is not None and opt.arg is not None, "tensor core opts must have an axis and arg") check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2") - check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available") + check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.arg)), "no tensor core available") self.applied_opts.append(opt) return axis = opt.real_axis(self) check(axis < len(self.full_shape), "invalid axis") - if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs - elif opt.amt is not None: - amt = opt.amt if opt.amt != 0 else self.full_shape[axis] + if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs + elif opt.arg is not None: + amt = opt.arg if opt.arg != 0 else self.full_shape[axis] check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless") if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}") else: amt = -1 diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index ecd5662bce..8ef92ba6ed 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -11,16 +11,16 @@ from tinygrad.tensor import Tensor from tinygrad.engine.realize import CompiledRunner from tinygrad.renderer import ProgramSpec -actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)] -actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)] -actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] -actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)] -actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)] -if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)] -actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)] -actions += [Opt(op=OptOps.TC, axis=0, amt=0)] -actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) -actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)] +actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)] +actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)] +actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] +actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)] +actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)] +if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)] +actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)] +actions += [Opt(op=OptOps.TC, axis=0, arg=0)] +actions += [Opt(op=OptOps.TC, axis=axis, arg=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) +actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] def _get_test_global_size(global_size, max_global_size, var_vals): @@ -104,7 +104,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) for i,a in enumerate(actions): if a.axis is not None and a.op is not OptOps.TC: - if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue + if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue lin2 = lin.copy() try: lin2.apply_opt(a) From 3ed146a5ff93d85405c76949ada38a717aa4676f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 27 Jan 2025 23:46:37 +0900 Subject: [PATCH 44/44] Revert "rename Opt amt to arg (#8767)" (#8769) This reverts commit bf041659a5004cc40294366254837ab9f129f296. --- extra/mcts_search.py | 2 +- test/external/external_debug_metal_sd_conv.py | 2 +- .../external_test_hcq_fuzz_failures.py | 2 +- test/external/external_test_nv.py | 4 +- test/external/external_test_train_gpt2.py | 4 +- test/external/external_test_valid_remove.py | 4 +- test/test_arange.py | 8 +- test/test_hcq.py | 2 +- test/test_linearizer.py | 32 ++--- test/test_linearizer_dumb.py | 12 +- test/test_linearizer_failures.py | 122 +++++++++--------- test/test_linearizer_overflows.py | 18 +-- test/test_search.py | 10 +- tinygrad/codegen/kernel.py | 14 +- tinygrad/engine/search.py | 22 ++-- 15 files changed, 129 insertions(+), 129 deletions(-) diff --git a/extra/mcts_search.py b/extra/mcts_search.py index 9090c902ff..54189fabb3 100644 --- a/extra/mcts_search.py +++ b/extra/mcts_search.py @@ -162,7 +162,7 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel: if node.n == 0: return for parent in node.parents: G.add_edge(parent, node) gopts = node.kernel.applied_opts - edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].arg}" if len(gopts) else "ROOT" + edge_lbl = f"{str(gopts[-1].op)[7:]} {gopts[-1].axis} {gopts[-1].amt}" if len(gopts) else "ROOT" G.add_node(node, label=f"{node.i+1}\n{node.tm:.2f} us\n{edge_lbl}\nt {node.t:.2f}\nn {node.n}", fillcolor="#80ff8080" if node.tm == best_tm else "#ffff8080", style='filled' if node.t == best_tm else '') if node.children is not None: diff --git a/test/external/external_debug_metal_sd_conv.py b/test/external/external_debug_metal_sd_conv.py index f0f3db7971..2e9315c5ba 100644 --- a/test/external/external_debug_metal_sd_conv.py +++ b/test/external/external_debug_metal_sd_conv.py @@ -29,7 +29,7 @@ ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=4, src=()), x17,)),)),)),)) -opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)] +opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.LOCAL, axis=2, amt=2)] k = Kernel(ast) for opt in opts: k.apply_opt(opt) diff --git a/test/external/external_test_hcq_fuzz_failures.py b/test/external/external_test_hcq_fuzz_failures.py index 3f434e65bb..9370b7661d 100644 --- a/test/external/external_test_hcq_fuzz_failures.py +++ b/test/external/external_test_hcq_fuzz_failures.py @@ -55,7 +55,7 @@ class TestHCQFuzzFailures(unittest.TestCase): def test_failure_1(self): ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=1, src=()), x39:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=((0, 1), (0, 6)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), x39,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=3, src=()), x46:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-6, mask=((0, 1), (6, 12)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), x46,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=5, src=()), x54:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (12, 13)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), x54,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-13, mask=((0, 1), (13, 17)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-17, mask=((0, 1), (17, 21)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=9, src=()), x68:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (21, 22)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), x68,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-22, mask=((0, 1), (22, 26)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-26, mask=((0, 1), (26, 30)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=13, src=()), x82:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (30, 31)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), x82,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=15, src=()), x90:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (31, 32)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), x90,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=17, src=()), x98:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (32, 33)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), x98,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=19, src=()), x106:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (33, 34)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=20, src=()), x106,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=21, src=()), x114:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (34, 35)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=22, src=()), x114,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=23, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-35, mask=((0, 1), (35, 39)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=24, src=()), x125:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (39, 40)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=25, src=()), x125,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=26, src=()), x133:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (40, 41)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=27, src=()), x133,)),)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=28, src=()), x140:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-41, mask=((0, 1), (41, 47)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=29, src=()), x140,)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=30, src=()), x147:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-47, mask=((0, 1), (47, 53)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=31, src=()), x147,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=32, src=()), x155:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (53, 54)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=33, src=()), x155,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=34, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-54, mask=((0, 1), (54, 58)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=35, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-58, mask=((0, 1), (58, 62)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=36, src=()), x169:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (62, 63)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=37, src=()), x169,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=38, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-63, mask=((0, 1), (63, 67)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=39, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-67, mask=((0, 1), (67, 71)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=40, src=()), x183:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (71, 72)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=41, src=()), x183,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=42, src=()), x191:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (72, 73)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=43, src=()), x191,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=44, src=()), x199:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (73, 74)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=45, src=()), x199,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=46, src=()), x207:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (74, 75)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=47, src=()), x207,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=48, src=()), x215:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (75, 76)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=49, src=()), x215,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=50, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-76, mask=((0, 1), (76, 80)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=51, src=()), x226:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (80, 81)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=52, src=()), x226,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=53, src=()), x234:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (81, 82)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=54, src=()), x234,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=55, src=()), x243:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (82, 83)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=56, src=()), x243,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=57, src=()), x250:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (83, 84)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=58, src=()), x250,)),)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 128, 4)), arg=59, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-84, mask=((0, 1), (84, 596)), contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], validate_device=Device["GPU"]) if __name__ == '__main__': diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index 702f9c805c..26b976e99b 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -26,12 +26,12 @@ class TestNV(unittest.TestCase): def test_oor_kernels(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - opts = [Opt(op=OptOps.TC, axis=6, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2)] # noqa: E501 + opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"]) def test_error_on_huge_dims(self): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 - opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501 + opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501 with self.assertRaises(RuntimeError) as cm: lin = Kernel(ast) for opt in opts: lin.apply_opt(opt) diff --git a/test/external/external_test_train_gpt2.py b/test/external/external_test_train_gpt2.py index df7546e6b2..f59b878253 100644 --- a/test/external/external_test_train_gpt2.py +++ b/test/external/external_test_train_gpt2.py @@ -26,7 +26,7 @@ class TestTrainGpt2Kernel(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(38633472), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(0, 0, 768, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=3), Opt(op=OptOps.LOCAL, axis=0, amt=2)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) run_linearizer(kernel) @@ -46,7 +46,7 @@ class TestTrainGpt2Kernel(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(205852672), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1024, 50304, 768), strides=(51463168, 50257, 1, 0), offset=0, mask=((0, 4), (0, 1024), (0, 50257), (0, 768)), contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)] + opts = [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=4)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) run_linearizer(kernel) diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index d3cc22fca9..1467139daf 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -51,7 +51,7 @@ class TestOpenpilotValidhack(unittest.TestCase): x19,)), x29,)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)] + opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) @@ -108,7 +108,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.imagef((1, 128, 4)), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=0, mask=((0, 1), (9, 10), (0, 512)), contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.NOLOCALS, axis=None, amt=None)] kernel = Kernel(ast) for opt in opts: kernel.apply_opt(opt) diff --git a/test/test_arange.py b/test/test_arange.py index 4d0bb79dc5..2229ac847f 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -41,7 +41,7 @@ class TestArange(unittest.TestCase): def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1) @unittest.skip("doesn't work yet") - def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)]) + def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)]) def test_all_opts(self, opts=None, exclude=None): k = Kernel(Tensor.arange(256).schedule()[-1].ast) @@ -59,11 +59,11 @@ class TestArange(unittest.TestCase): self.test_complexity(opts) def test_all_opts_w_local(self): with contextlib.suppress(KernelOptError): - return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, arg=32)]) + return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, amt=32)]) def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)]) - def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) + def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)]) def test_all_opts_w_upcast_and_unroll(self): - return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, arg=0)]) + return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)]) class TestIndexing(unittest.TestCase): # update: passing after CAST_BEFORE_VIEW=1 deletion diff --git a/test/test_hcq.py b/test/test_hcq.py index 13ef71bd87..476354129a 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -160,7 +160,7 @@ class TestHCQ(unittest.TestCase): b = a + 1 si = b.schedule()[-1] k = Kernel(si.ast, opts=TestHCQ.d0.renderer) - for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, arg=3)) + for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=3)) runner = CompiledRunner(k.to_program()) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 55b4463a58..c0bdc2c6f8 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1300,10 +1300,10 @@ class TestLinearizer(unittest.TestCase): helper_linearizer_opt(a, [ [Opt(OptOps.GROUP, 0, 32)], [Opt(OptOps.GROUPTOP, 0, 32)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8)], - [Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], # noqa: E501 + [Opt(op=OptOps.LOCAL, axis=0, amt=8)], + [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8)], + [Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], # noqa: E501 ]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.supports_float4, "test requires float4") @@ -1363,8 +1363,8 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 opt = [ - Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), - Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2) + Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), + Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2) ] k = helper_linearizer_ast(ast, [Tensor.randn(240*40).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is Ops.STORE][0] @@ -1381,9 +1381,9 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 - opt = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8), - Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), - Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), + Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), + Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] k = helper_linearizer_ast(ast, [Tensor.randn(8*32).realize()], opts=[opt])[-1] out = [u for u in k.uops if u.op is Ops.STORE][0] assert out.src[-1].op is Ops.VECTORIZE and out.src[-1].dtype.count != 1 @@ -1606,9 +1606,9 @@ class TestFloat4(unittest.TestCase): # TODO: fix this, expected might change but should be positive for expected, opts in [ - ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)]), - ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]), + ((7, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), + ((5, 0), [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)]), + ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, amt=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) @@ -1637,8 +1637,8 @@ class TestFloat4(unittest.TestCase): UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501 for expected, opts in [ - (1, [Opt(op=OptOps.UPCAST, axis=2, arg=4)]), - (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]), + (1, [Opt(op=OptOps.UPCAST, axis=2, amt=4)]), + (4, [Opt(op=OptOps.UPCAST, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) @@ -1660,8 +1660,8 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501 for expected, opts in [ - (16, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=3, arg=4)]), # noqa: E501 - (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]), + (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501 + (4, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2)]), ]: k = Kernel(ast) for opt in opts: k.apply_opt(opt) diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 2408bc2626..5cd581aa04 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -35,7 +35,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.half, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)] + opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=0)] k = Kernel(ast, opts=Device["METAL"].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -70,7 +70,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 1000, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -88,7 +88,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k.required_optimizations() for opt in opts: k.apply_opt(opt) @@ -155,7 +155,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() @@ -186,7 +186,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() @@ -210,7 +210,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)] + opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) for opt in opts: k.apply_opt(opt) prg = k.to_program() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 53c6ca5e5d..8570dac761 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -68,7 +68,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=0, arg=32)] + opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_3(self): @@ -80,7 +80,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=32)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)] # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)} helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -101,7 +101,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] # EXEC_ERROR, it has no global_size helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -116,7 +116,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 10, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)] # COMPILE FAILED, KeyError: UOps.CONST helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -129,7 +129,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=4)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] # test/test_linearizer_failures.py Fatal Python error: Segmentation fault helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -156,7 +156,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 1e-06, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)] # fatal error: bracket nesting level exceeded maximum of 256 # note: use -fbracket-depth=N to increase maximum nesting level helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -174,7 +174,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_10(self): @@ -290,7 +290,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.GROUP, axis=0, arg=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @unittest.skip("found implicit expand") @@ -315,7 +315,7 @@ class TestLinearizerFailures(unittest.TestCase): x6,)), UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( x5,)),)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.GROUP, axis=0, arg=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # both kernels are correct from a code standpoint, but generate different results due to precision errors (switching to float results in output matches) @@ -336,7 +336,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=19584, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"]) def test_failure_14(self): @@ -357,7 +357,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)] + opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] # COMPILE_ERROR on METAL in fuzz_linearizer: unused variables and undeclared variables helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -395,7 +395,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=16)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=16)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 115: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -411,7 +411,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 0.0009765625, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=1, amt=4)] # COMPILE_ERROR on METAL/GPU (probably HIP/CUDA too) in fuzz_linearizer ast 154: bracket nesting level exceeded maximum of 256 helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -428,7 +428,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(188160, 0, 0, 784, 28, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.GROUPTOP, axis=0, arg=16), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.LOCAL, axis=1, arg=4)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=1, amt=4)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 178: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -453,7 +453,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUPTOP, axis=0, arg=256), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=3)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 239: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -470,7 +470,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(252, 0, 0, 63, 7, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=7), Opt(op=OptOps.UPCAST, axis=2, arg=3), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=3)] + opts = [Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 379: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -485,7 +485,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=0)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_21(self): @@ -495,7 +495,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()), ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=0, arg=32)] + opts = [Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) #@unittest.skipIf(Device.DEFAULT in ("LLVM", "METAL", "CLANG"), "flaky") @@ -603,7 +603,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.LOCAL, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) def test_failure_24(self): @@ -614,7 +614,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=2), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # this is the cause of the GPT2 BEAM instability. bisects to PR#3530 O(n) arange attempt @@ -629,7 +629,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UNROLL, axis=0, arg=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) # COMPARE_ERROR from GPT2 kernel - stems from uops.py self.simplify_phi_loops @@ -645,17 +645,17 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) all_failing_opts = [ - [Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.GROUPTOP, axis=0, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=0)], - [Opt(op=OptOps.GROUPTOP, axis=0, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=0, arg=4)], - [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4)], - [Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=4)], - [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4)], - [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=4)], - [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.GROUP, axis=0, arg=16), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)], - [Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0)], + [Opt(op=OptOps.GROUPTOP, axis=0, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4)], + [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4)], + [Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=4)], + [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4)], + [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.UNROLL, axis=1, amt=4)], + [Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.GROUP, axis=0, amt=16), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)], + [Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UNROLL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=0, amt=0)], ] for opts in all_failing_opts: helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -683,7 +683,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) all_failing_opts = [ - [Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=7), Opt(op=OptOps.UPCAST, axis=0, arg=0)], + [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=0)], ] for opts in all_failing_opts: helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -740,7 +740,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=1), Opt(op=OptOps.PADTO, axis=2, arg=32)] + opts = [Opt(op=OptOps.TC, axis=0, amt=1), Opt(op=OptOps.PADTO, axis=2, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0) def test_failure_30(self): @@ -758,7 +758,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(0, 0, 12, 0, 0, 4, 2, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.PADTO, axis=3, arg=32), Opt(op=OptOps.LOCAL, axis=3, arg=32), Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0)] + opts = [Opt(op=OptOps.PADTO, axis=3, amt=32), Opt(op=OptOps.LOCAL, axis=3, amt=32), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from METAL=1 fuzz_linearizer command in test.yml @@ -779,7 +779,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.4426950408889634, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)] + opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipIf(CI, "for real AMD GPU") @@ -800,7 +800,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=16)] + opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05) def test_failure_33(self): @@ -841,7 +841,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 1.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=None, contiguous=False),)), src=()),)),)), x10,)),)),)),)),)) - opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=16)] + opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # from fuzzing on metal @@ -861,7 +861,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] if unroll else [Opt(op=OptOps.TC, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] if unroll else [Opt(op=OptOps.TC, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_35(self): self.test_failure_34(True) @@ -881,7 +881,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), ast_const(dtypes.uint, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # BEGIN METAL=1 ./examples/beautiful_mnist.py failures @@ -910,7 +910,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_38(self): @@ -930,7 +930,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) for axis in [0,1,3,4]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skip("very slow, similar to test_failure_37") @@ -958,7 +958,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for axis in [0,1,2,3,4,5]: - opts = [Opt(op=OptOps.TC, axis=axis, arg=2)] + opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_40(self): @@ -975,7 +975,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) for amt in [16,32]: - opts = [Opt(op=OptOps.GROUPTOP, axis=0, arg=amt), Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) # END METAL=1 ./examples/beautiful_mnist.py failures @@ -996,7 +996,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts=[Opt(op=OptOps.TC, axis=5, arg=2), Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02) # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile @@ -1011,7 +1011,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=2), Opt(op=OptOps.PADTO, axis=0, arg=32)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @@ -1025,7 +1025,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") @@ -1039,7 +1039,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) - opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)] + opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) assert k is not None ifs = [u for u in k.uops if u.op is Ops.IF] @@ -1084,7 +1084,7 @@ class TestLinearizerFailures(unittest.TestCase): x19,)),)), x21,)),)),)),)),)),)),)) # ValueError: size mismatched, can't reshape self.shape=(6, 2, 3, 3) -> new_shape=(6, 2, 3, 1, 2) - opts = [Opt(op=OptOps.UNROLL, axis=2, arg=0)] + opts = [Opt(op=OptOps.UNROLL, axis=2, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_46(self): @@ -1117,7 +1117,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) def test_failure_47(self): @@ -1132,7 +1132,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=3)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=3)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @unittest.skipUnless(not CI and Device.DEFAULT in ("NV", "CUDA"), "for real NV") @@ -1151,7 +1151,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_49(self): @@ -1168,7 +1168,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_50(self): @@ -1195,7 +1195,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=1, arg=2)] + opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_51(self): @@ -1236,7 +1236,7 @@ class TestLinearizerFailures(unittest.TestCase): x6, UOp(Ops.CONST, dtypes.half, arg=-1.4426950408889634, src=()), x9,)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2)] + opts = [Opt(op=OptOps.TC, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1258,7 +1258,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=0, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=16)] + opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_53(self): @@ -1294,7 +1294,7 @@ class TestLinearizerFailures(unittest.TestCase): x22, UOp(Ops.CONST, dtypes.bool, arg=True, src=()), UOp(Ops.CONST, dtypes.bool, arg=False, src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.GROUPTOP, axis=1, arg=16)] + opts = [Opt(op=OptOps.GROUPTOP, axis=1, amt=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["AMD", "GPU", "METAL", "NV", "CUDA"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1315,7 +1315,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.TC, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=7), Opt(op=OptOps.UPCAST, axis=1, arg=2)] + opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UPCAST, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["HIP", "AMD"]) @unittest.skipIf(CI and Device.DEFAULT in {"METAL"}, "hangs metal gpu CI") @@ -1336,7 +1336,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) - opts = [Opt(op=OptOps.SWAP, axis=1, arg=2)] + opts = [Opt(op=OptOps.SWAP, axis=1, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) def test_failure_56(self): @@ -1382,7 +1382,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=2, arg=32)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=2, amt=32)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) def test_failure_57(self): @@ -1428,7 +1428,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 16, 5, 2, 5, 2), strides=(1600, 100, 20, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(128, 16, 11, 11), strides=(1600, 100, 10, 1), offset=0, mask=((0, 128), (0, 16), (0, 10), (0, 10)), contiguous=False))), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32)] + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["METAL"]) if __name__ == '__main__': diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index f5fb749956..2e7265f652 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -59,7 +59,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x16,)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)] _test_overflow(ast, opts) # From BEAM on hlb_cifar.py @@ -76,7 +76,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=0, arg=0)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] _test_overflow(ast, opts) # from BEAM on default simple_conv.py (which is quite large): @@ -93,7 +93,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=4 simple_conv.py: @@ -110,7 +110,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)] + opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) # from BEAM on BS=2 simple_conv.py: @@ -127,7 +127,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: @@ -144,7 +144,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=3, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] _test_overflow(ast, opts) # from BEAM on BS=3 simple_conv.py: (alt) @@ -161,7 +161,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) - opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=2, arg=4)] + opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) @unittest.skipIf(Device.DEFAULT not in {"GPU", "HSA", "CUDA", "METAL"}, "only backends with locals") @@ -177,7 +177,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.LOCAL, axis=2, arg=2), Opt(op=OptOps.UPCAST, axis=0, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] _test_overflow(ast, opts) def test_overflow_2(self): BS = 2 @@ -189,7 +189,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) - opts = [Opt(op=OptOps.LOCAL, axis=3, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=2, arg=16), Opt(op=OptOps.UPCAST, axis=4, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=2), Opt(op=OptOps.UPCAST, axis=5, arg=2)] + opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)] _test_overflow(ast, opts) if __name__ == '__main__': diff --git a/test/test_search.py b/test/test_search.py index d0d6cf9114..d22e03bc59 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -92,15 +92,15 @@ class TestBEAM(unittest.TestCase): # ensure amt=0 are not duplicated if Opt(OptOps.UPCAST, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, arg=4)]) == 0, "did not de-dup UPCAST" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UPCAST, axis=0, amt=4)]) == 0, "did not de-dup UPCAST" if Opt(OptOps.LOCAL, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, arg=4)]) == 0, "did not de-dup LOCAL" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.LOCAL, axis=0, amt=4)]) == 0, "did not de-dup LOCAL" if Opt(OptOps.UNROLL, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, arg=3)]) == 0, "did not de-dup UNROLL" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.UNROLL, axis=0, amt=3)]) == 0, "did not de-dup UNROLL" if Opt(OptOps.GROUP, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, arg=3)]) == 0, "did not de-dup GROUP" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUP, axis=0, amt=3)]) == 0, "did not de-dup GROUP" if Opt(OptOps.GROUPTOP, 0, 0) in actions: - assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, arg=3)]) == 0, "did not de-dup GROUPTOP" + assert len([x for x in lins if x.applied_opts[0] == Opt(OptOps.GROUPTOP, axis=0, amt=3)]) == 0, "did not de-dup GROUPTOP" def test_filter_global_buffer(self): # taken from https://github.com/tinygrad/tinygrad/issues/4612 diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index e45f371b90..bfa5b54c31 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -32,8 +32,8 @@ def check(cond:bool, msg:str=""): class Opt: op: OptOps axis: Optional[int] = None - arg: Optional[int] = None - def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, arg={self.arg})" + amt: Optional[int] = None + def __repr__(self): return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})" def real_axis(self, k:Kernel): if self.axis is None: return -1 if self.op is OptOps.UNROLL: return k.first_reduce+self.axis @@ -353,18 +353,18 @@ class Kernel: if opt.op is OptOps.TC: check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: things like PADTO might be fine - check(opt.axis is not None and opt.arg is not None, "tensor core opts must have an axis and arg") + check(opt.axis is not None and opt.amt is not None, "tensor core opts must have an axis and amt") check((use_tensor_cores:=USE_TC.value) == 2 or len(self.opts.tensor_cores) > 0, "must have tensor cores or TC=2") - check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.arg)), "no tensor core available") + check(self._apply_tc_opt(use_tensor_cores, cast(int, opt.axis), cast(int, opt.amt)), "no tensor core available") self.applied_opts.append(opt) return axis = opt.real_axis(self) check(axis < len(self.full_shape), "invalid axis") - if opt.op is OptOps.SWAP: amt = cast(int, opt.arg) # arg is an axis in the SWAPs - elif opt.arg is not None: - amt = opt.arg if opt.arg != 0 else self.full_shape[axis] + if opt.op is OptOps.SWAP: amt = cast(int, opt.amt) # amt is an axis in the SWAPs + elif opt.amt is not None: + amt = opt.amt if opt.amt != 0 else self.full_shape[axis] check(isinstance(amt, int) and amt != 1, f"shift/padto of {amt=}, 1 or symbolic amount is meaningless") if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, f"no longer valid shift {self.full_shape[axis]=}, {amt=}") else: amt = -1 diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 8ef92ba6ed..ecd5662bce 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -11,16 +11,16 @@ from tinygrad.tensor import Tensor from tinygrad.engine.realize import CompiledRunner from tinygrad.renderer import ProgramSpec -actions = [Opt(op=OptOps.UPCAST, axis=axis, arg=amt) for amt in [0,2,3,4,5,7] for axis in range(6)] -actions += [Opt(op=OptOps.UNROLL, axis=axis, arg=amt) for amt in [0,4,7] for axis in range(5)] -actions += [Opt(op=OptOps.LOCAL, axis=axis, arg=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] -actions += [Opt(op=OptOps.GROUPTOP, axis=axis, arg=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)] -actions += [Opt(op=OptOps.GROUP, axis=axis, arg=amt) for amt in [0,4,8,16] for axis in range(3)] -if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, arg=amt) for amt in [32] for axis in range(7)] -actions += [Opt(op=OptOps.LOCAL, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=6, arg=2)] -actions += [Opt(op=OptOps.TC, axis=0, arg=0)] -actions += [Opt(op=OptOps.TC, axis=axis, arg=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) -actions += [Opt(op=OptOps.SWAP, axis=axis_0, arg=axis_1) for axis_0 in range(5) for axis_1 in range(axis_0+1, 5)] +actions = [Opt(op=OptOps.UPCAST, axis=axis, amt=amt) for amt in [0,2,3,4,5,7] for axis in range(6)] +actions += [Opt(op=OptOps.UNROLL, axis=axis, amt=amt) for amt in [0,4,7] for axis in range(5)] +actions += [Opt(op=OptOps.LOCAL, axis=axis, amt=amt) for amt in [2,3,4,8,13,16,29] for axis in range(6)] +actions += [Opt(op=OptOps.GROUPTOP, axis=axis, amt=amt) for amt in [13,16,28,29,32,49,64,256] for axis in range(3)] +actions += [Opt(op=OptOps.GROUP, axis=axis, amt=amt) for amt in [0,4,8,16] for axis in range(3)] +if getenv("BEAM_PADTO", 1): actions += [Opt(op=OptOps.PADTO, axis=axis, amt=amt) for amt in [32] for axis in range(7)] +actions += [Opt(op=OptOps.LOCAL, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=6, amt=2)] +actions += [Opt(op=OptOps.TC, axis=0, amt=0)] +actions += [Opt(op=OptOps.TC, axis=axis, amt=getenv("TC_OPT", 2)) for axis in range(9)] # covers resnet kernels (3 global * 3 reduce) +actions += [Opt(op=OptOps.SWAP, axis=axis, amt=amt) for axis in range(5) for amt in range(axis+1, 5)] if getenv("NOLOCALS"): actions += [Opt(op=OptOps.NOLOCALS)] def _get_test_global_size(global_size, max_global_size, var_vals): @@ -104,7 +104,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: acted_lins, max_up, max_lcl = {0:lin} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024) for i,a in enumerate(actions): if a.axis is not None and a.op is not OptOps.TC: - if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.arg and Opt(a.op, ax, 0) in actions): continue + if ((ax:=a.real_axis(lin)) >= lin.shape_len) or (lin.full_shape[ax] == a.amt and Opt(a.op, ax, 0) in actions): continue lin2 = lin.copy() try: lin2.apply_opt(a)