diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 265971b24a..cf5b117aad 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -299,6 +299,10 @@ jobs: run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt + - name: Run 10 MLPerf Bert training steps (6 gpu) + # TODO: remove DISABLE_DROPOUT once dropout is fixed + # TODO: remove BERT_LAYERS once scheduler is fast + run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (NVIDIA Training) @@ -309,9 +313,10 @@ jobs: train_cifar_bf16.txt train_cifar_wino.txt train_cifar_one_gpu.txt + train_cifar_six_gpu.txt train_resnet.txt train_resnet_one_gpu.txt - train_cifar_six_gpu.txt + train_bert.txt - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py @@ -492,6 +497,10 @@ jobs: run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt - name: Run 10 MLPerf ResNet50 training steps (6 gpu) run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt + - name: Run 10 MLPerf Bert training steps (6 gpu) + # TODO: remove DISABLE_DROPOUT once dropout is fixed + # TODO: remove BERT_LAYERS once scheduler is fast + run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt - uses: actions/upload-artifact@v4 with: name: Speed (AMD Training) @@ -502,9 +511,10 @@ jobs: train_cifar_bf16.txt train_cifar_wino.txt train_cifar_one_gpu.txt + train_cifar_six_gpu.txt train_resnet.txt train_resnet_one_gpu.txt - train_cifar_six_gpu.txt + train_bert.txt - name: Run process replay tests run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9fbc0bf5d9..4de691cc00 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -619,6 +619,44 @@ jobs: if: matrix.backend=='amd' run: python -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20 + wintests: + strategy: + fail-fast: false + matrix: + backend: [llvm] + + name: Tests on Windows (${{ matrix.backend }}) + runs-on: windows-latest + timeout-minutes: 45 + steps: + - name: Checkout Code + uses: actions/checkout@v4 + with: + fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: 3.12 + - name: Cache python packages + uses: actions/cache@v4 + with: + path: ${{ env.Python3_ROOT_DIR }}\Lib\site-packages + key: windows-${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }} + - name: Install dependencies + run: pip install --user -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu + - name: Check Device.DEFAULT and print some source + env: + DEBUG: 5 + LLVM: 1 + PYTHONPATH: ${{ github.workspace }} + run: | + python3 test/test_ops.py TestOps.test_add + - name: Run pytest + env: + DEBUG: 5 + LLVM: 1 + run: python -m pytest -n=auto test/test_tiny.py --durations=20 + #testunicorn: # name: ARM64 unicorn Test # runs-on: ubuntu-latest diff --git a/docs/developer/developer.md b/docs/developer/developer.md index b40d715af0..39e9e0901b 100644 --- a/docs/developer/developer.md +++ b/docs/developer/developer.md @@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not ## Frontend -Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md). +Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md). The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There's two types of UOps, base and view. base contains compute into a contiguous buffer, and view is a view (specified by a ShapeTracker). Inputs to a base can be either base or view, inputs to a view can only be a single base. diff --git a/docs/developer/function.md b/docs/developer/function.md deleted file mode 100644 index 9f1b85f8cd..0000000000 --- a/docs/developer/function.md +++ /dev/null @@ -1,33 +0,0 @@ -::: tinygrad.function - options: - members: [ - "Contiguous", - "ContiguousBackward", - "Cast", - "Neg", - "Reciprocal", - "Sin", - "Relu", - "Log", - "Exp", - "Sqrt", - "Sigmoid", - "Sign", - "Less", - "Eq", - "Xor", - "Add", - "Sub", - "Mul", - "Div", - "Where", - "Sum", - "Max", - "Expand", - "Reshape", - "Permute", - "Pad", - "Shrink", - "Flip", - ] - show_source: false diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 78c59bdb18..d2008ef87d 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit from tinygrad.nn.state import get_state_dict, get_parameters from tinygrad.nn import optim from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod -from tinygrad.multi import MultiLazyBuffer cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618] cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628] @@ -35,8 +34,6 @@ class UnsyncedBatchNorm: self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False) def __call__(self, x:Tensor): - if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices - xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32) batch_mean, batch_invstd = self.calc_stats(xr) ret = xr.batchnorm( diff --git a/examples/llama3.py b/examples/llama3.py index e331573d1a..c24ec6ea2a 100644 --- a/examples/llama3.py +++ b/examples/llama3.py @@ -247,11 +247,11 @@ if __name__ == "__main__": fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr") args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr") elif args.size == "70B": - subdir = "Llama-3.1-Nemotron-70B-Instruct-HF" - args.model = fetch("https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir) + subdir = "DeepSeek-R1-Distill-Llama-70B" + args.model = fetch("https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir) fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir) - for i in range(30): - fetch(f"https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model-{i+1:05d}-of-00030.safetensors?download=true", f"model-{i+1:05d}-of-00030.safetensors", subdir=subdir) + for i in range(17): + fetch(f"https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model-{i+1:05d}-of-000017.safetensors?download=true", f"model-{i+1:05d}-of-000017.safetensors", subdir=subdir) assert args.model is not None, "please provide --model option" diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py index 4929005fa2..ec017b6f30 100644 --- a/examples/mlperf/helpers.py +++ b/examples/mlperf/helpers.py @@ -204,7 +204,7 @@ def get_mlperf_bert_config(): "intermediate_size": 4096, "max_position_embeddings": 512, "num_attention_heads": 16, - "num_hidden_layers": 24, + "num_hidden_layers": getenv("BERT_LAYERS", 24), "type_vocab_size": 2, "vocab_size": 30522 } diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 4cc03d112b..89dc238914 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -786,7 +786,8 @@ def train_rnnt(): pass @TinyJit -def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): +def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, + masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): optimizer.zero_grad() lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids) @@ -802,18 +803,15 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te optimizer.step() scheduler.step() - return loss.realize() + return loss.realize(), global_norm.realize() @TinyJit -def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): +def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, + masked_lm_weights:Tensor, next_sentence_labels:Tensor): lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids) - masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) - return { - "masked_lm_accuracy": masked_lm_accuracy.realize(), - "next_sentence_accuracy": seq_relationship_accuracy.realize(), - "masked_lm_loss": masked_lm_loss.realize(), - "next_sentence_loss": next_sentence_loss.realize() - } + masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \ + model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) + return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize() def train_bert(): # NOTE: pip install tensorflow, wandb required @@ -949,7 +947,7 @@ def train_bert(): previous_step = None if ckpt:=getenv("RESUME", ""): load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt)) - start_step = int(scheduler_wd.epoch_counter.numpy().item()) + start_step = int(scheduler_wd.epoch_counter.item()) print(f"resuming from {ckpt} at step {start_step}") if RUNMLPERF: @@ -975,7 +973,7 @@ def train_bert(): BEAM.value = TRAIN_BEAM st = time.perf_counter() GlobalCounters.reset() - loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, + loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler, train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \ train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"]) @@ -992,7 +990,7 @@ def train_bert(): dt = time.perf_counter() device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}" - loss = loss.numpy().item() + loss = loss.item() cl = time.perf_counter() if BENCHMARK: step_times.append(cl - st) @@ -1002,7 +1000,7 @@ def train_bert(): f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, " f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS") if WANDB: - wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st, + wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st, "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt, "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS}) @@ -1037,12 +1035,10 @@ def train_bert(): GlobalCounters.reset() st = time.time() - eval_result: dict[str, Tensor] = eval_step_bert(model, + lm_acc, clsf_acc, lm_loss, clsf_loss = eval_step_bert(model, eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"], eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"]) - - lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item() - lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item() + lm_acc, clsf_acc, lm_loss, clsf_loss = lm_acc.item(), clsf_acc.item(), lm_loss.item(), clsf_loss.item() eval_lm_losses.append(lm_loss) eval_clsf_losses.append(clsf_loss) @@ -1059,7 +1055,7 @@ def train_bert(): return if getenv("RESET_STEP", 1): eval_step_bert.reset() - del eval_data, eval_result + del eval_data avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses) avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses) avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs) diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh index 99b99f7e89..41cc7d8050 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh index 70a3b6a6cb..bde86667b9 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_run.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh index a213f4d682..18ca38d147 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/run_and_time.sh @@ -5,7 +5,7 @@ export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_green" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512 +export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh index 08ffd354b3..5bd4f06778 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_beam.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh index c42c1f65b6..530b490ebf 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/dev_run.sh @@ -4,7 +4,7 @@ export PYTHONPATH="." export MODEL="bert" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh index d6ff9fd2cc..4d993fc1c2 100755 --- a/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh +++ b/examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_red/run_and_time.sh @@ -5,7 +5,7 @@ export MODEL="bert" export SUBMISSION_PLATFORM="tinybox_red" export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36 -export BEAM=3 +export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024 export IGNORE_JIT_FIRST_BEAM=1 export BASEDIR="/raid/datasets/wiki" diff --git a/extra/amdpci/am_smi.py b/extra/amdpci/am_smi.py new file mode 100644 index 0000000000..f5ffbc6da3 --- /dev/null +++ b/extra/amdpci/am_smi.py @@ -0,0 +1,167 @@ +import time, mmap, sys, shutil, os, glob +from tinygrad.helpers import to_mv, DEBUG, colored, ansilen +from tinygrad.runtime.autogen import libc +from tinygrad.runtime.autogen.am import smu_v13_0_0 +from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager +from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA + +AM_VERSION = 0xA0000002 + +def bold(s): return f"\033[1m{s}\033[0m" + +def color_temp(temp): + if temp >= 87: return colored(f"{temp:>4}", "red") + elif temp >= 80: return colored(f"{temp:>4}", "yellow") + return f"{temp:>4}" + +def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan") + +def draw_bar(percentage, width=40, fill='█', empty='░'): + filled_width = int(width * percentage) + bar = fill * filled_width + empty * (width - filled_width) + return f'[{bar}] {percentage*100:5.1f}%' + +def same_line(strs:list[list[str]], split=8) -> list[str]: + ret = [] + max_width_in_block = [max(ansilen(line) for line in block) for block in strs] + max_height = max(len(block) for block in strs) + for i in range(max_height): + line = [] + for bid, block in enumerate(strs): + if i < len(block): line.append(block[i] + ' ' * (split + max_width_in_block[bid] - ansilen(block[i]))) + else: line.append(' ' * (split + max_width_in_block[bid])) + ret.append(' '.join(line)) + return ret + +def get_bar0_size(pcibus): + resource_file = f"/sys/bus/pci/devices/{pcibus}/resource" + if not os.path.exists(resource_file): raise FileNotFoundError(f"Resource file not found: {resource_file}") + + with open(resource_file, "r") as f: lines = f.readlines() + bar0_info = lines[0].split() + if len(bar0_info) < 3: raise ValueError("Unexpected resource file format for BAR0.") + + start_hex, end_hex, _flags = bar0_info + return int(end_hex, 16) - int(start_hex, 16) + 1 + +class AMSMI(AMDev): + def __init__(self, pcibus, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): + self.pcibus = pcibus + self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar + + self._run_discovery() + self._build_regs() + + if self.reg("regSCRATCH_REG7").read() != AM_VERSION: + raise Exception(f"Unsupported AM version: {self.reg('regSCRATCH_REG7').read():x}") + + self.is_booting, self.smi_dev = True, True + self.partial_boot = True # do not init anything + self.mm = AMMemoryManager(self, self.vram_size) + + # Initialize IP blocks + self.soc21:AM_SOC21 = AM_SOC21(self) + self.gmc:AM_GMC = AM_GMC(self) + self.ih:AM_IH = AM_IH(self) + self.psp:AM_PSP = AM_PSP(self) + self.smu:AM_SMU = AM_SMU(self) + +class SMICtx: + def __init__(self): + self.devs = [] + self.opened_pcidevs = [] + self.opened_pci_resources = {} + self.prev_lines_cnt = 0 + + def _open_am_device(self, pcibus): + if pcibus not in self.opened_pci_resources: + bar_fds = {bar: os.open(f"/sys/bus/pci/devices/{pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]} + bar_size = {0: get_bar0_size(pcibus), 2: os.fstat(bar_fds[2]).st_size, 5: os.fstat(bar_fds[5]).st_size} + + def map_pci_range(bar): + return to_mv(libc.mmap(0, bar_size[bar], mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, bar_fds[bar], 0), bar_size[bar]) + self.opened_pci_resources[pcibus] = (map_pci_range(0), None, map_pci_range(5).cast('I')) + + try: + self.devs.append(AMSMI(pcibus, *self.opened_pci_resources[pcibus])) + except Exception as e: + if DEBUG >= 2: print(f"Failed to open AM device {pcibus}: {e}") + return + + self.opened_pcidevs.append(pcibus) + if DEBUG >= 2: print(f"Opened AM device {pcibus}") + + def rescan_devs(self): + pattern = os.path.join('/tmp', 'am_*.lock') + for d in [f[8:-5] for f in glob.glob(pattern)]: + if d not in self.opened_pcidevs: + self._open_am_device(d) + + for d in self.devs: + if d.reg("regSCRATCH_REG7").read() != AM_VERSION: + self.devs.remove(d) + self.opened_pcidevs.remove(d.pcibus) + os.system('clear') + if DEBUG >= 2: print(f"Removed AM device {d.pcibus}") + + def collect(self): return {d: d.smu.read_metrics() for d in self.devs} + + def draw(self): + terminal_width, _ = shutil.get_terminal_size() + + dev_metrics = self.collect() + dev_content = [] + for dev, metrics in dev_metrics.items(): + device_line = [f"PCIe device: {bold(dev.pcibus)}"] + [""] + activity_line = [f"GFX Activity {draw_bar(metrics.SmuMetrics.AverageGfxActivity / 100, 50)}"] \ + + [f"UCLK Activity {draw_bar(metrics.SmuMetrics.AverageUclkActivity / 100, 50)}"] + [""] + + # draw_metrics_table(metrics, dev) + temps_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_TEMP_e__enumvalues.items() + if k < smu_v13_0_0.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0] + temps_table = ["=== Temps (C) ==="] + [f"{name:<15}: {color_temp(metrics.SmuMetrics.AvgTemperature[k])}" for k, name in temps_keys] + + voltage_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_SVI_PLANE_e__enumvalues.items() if k < smu_v13_0_0.SVI_PLANE_COUNT] + power_table = ["=== Power ==="] \ + + [f"Fan Speed: {metrics.SmuMetrics.AvgFanRpm} RPM"] \ + + [f"Fan Power: {metrics.SmuMetrics.AvgFanPwm} %"] \ + + [f"Power: {metrics.SmuMetrics.AverageSocketPower}W " + + draw_bar(metrics.SmuMetrics.AverageSocketPower / metrics.SmuMetrics.dGPU_W_MAX, 16)] \ + + ["", "=== Voltages ==="] + [f"{name:<24}: {color_voltage(metrics.SmuMetrics.AvgVoltage[k])}" for k, name in voltage_keys] + + frequency_table = ["=== Frequencies ===", + f"GFXCLK Target : {metrics.SmuMetrics.AverageGfxclkFrequencyTarget} MHz", + f"GFXCLK PreDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPreDs} MHz", + f"GFXCLK PostDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPostDs} MHz", + f"FCLK PreDs : {metrics.SmuMetrics.AverageFclkFrequencyPreDs} MHz", + f"FCLK PostDs : {metrics.SmuMetrics.AverageFclkFrequencyPostDs} MHz", + f"MCLK PreDs : {metrics.SmuMetrics.AverageMemclkFrequencyPreDs} MHz", + f"MCLK PostDs : {metrics.SmuMetrics.AverageMemclkFrequencyPostDs} MHz", + f"VCLK0 : {metrics.SmuMetrics.AverageVclk0Frequency} MHz", + f"DCLK0 : {metrics.SmuMetrics.AverageDclk0Frequency} MHz", + f"VCLK1 : {metrics.SmuMetrics.AverageVclk1Frequency} MHz", + f"DCLK1 : {metrics.SmuMetrics.AverageDclk1Frequency} MHz"] + + dev_content.append(device_line + activity_line + same_line([temps_table, power_table, frequency_table])) + + raw_text = 'AM Monitor'.center(terminal_width) + "\n" + "=" * terminal_width + "\n\n" + for i in range(0, len(dev_content), 2): + if i + 1 < len(dev_content): raw_text += '\n'.join(same_line([dev_content[i], dev_content[i+1]])) + else: raw_text += '\n'.join(dev_content[i]) + if i + 2 < len(dev_content): raw_text += "\n" + "=" * terminal_width + "\n\n" + + sys.stdout.write(f'\033[{self.prev_lines_cnt}A') + sys.stdout.flush() + print(raw_text) + + self.prev_lines_cnt = len(raw_text.splitlines()) + 2 + +if __name__ == "__main__": + try: + os.system('clear') + smi_ctx = SMICtx() + while True: + smi_ctx.rescan_devs() + smi_ctx.draw() + time.sleep(1) + except KeyboardInterrupt: print("Exiting...") diff --git a/extra/amdpci/setup_python_cap.sh b/extra/amdpci/setup_python_cap.sh new file mode 100755 index 0000000000..2ef3c5f6b7 --- /dev/null +++ b/extra/amdpci/setup_python_cap.sh @@ -0,0 +1,3 @@ +#!/bin/bash +PYTHON_PATH=$(readlink -f $(which python3)) +sudo setcap 'cap_dac_override,cap_sys_rawio,cap_sys_admin=ep' $PYTHON_PATH diff --git a/extra/amdpci/setup_vfio.sh b/extra/amdpci/setup_vfio.sh new file mode 100755 index 0000000000..6ca786edaa --- /dev/null +++ b/extra/amdpci/setup_vfio.sh @@ -0,0 +1,2 @@ +#!/bin/bash +sudo modprobe vfio-pci disable_idle_d3=1 diff --git a/extra/models/bert.py b/extra/models/bert.py index c1eb33f85c..01a136c8e4 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -63,15 +63,15 @@ class BertForPretraining: def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor): valid = masked_lm_ids != 0 - masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1) - masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid + masked_lm_predictions = prediction_logits.argmax(-1) + masked_lm_correct = (masked_lm_predictions == masked_lm_ids) * valid masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights) - seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1) - seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels) + seq_relationship_predictions = seq_relationship_logits.argmax(-1) + seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels) next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels) - return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss + return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 4f745e680b..165de3154b 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -1,41 +1,15 @@ import functools, io, math from typing import cast, Literal from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr -from tinygrad.dtype import ImageDType, dtypes +from tinygrad.dtype import ImageDType, dtypes, DType from tinygrad.helpers import prod, flatten, make_tuple from extra.onnx import dtype_parse, _cached_to_python_const import numpy as np -# **************** Free Ops **************** - +# ***** Property/Graph Ops ***** def Identity(x:Tensor): return x -# TODO: fix buffer_parse -def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype) -def Sub(x:Tensor|int, other:Tensor): return x - other # some test has input as int -def Less(x:Tensor,y:Tensor): return x < y -def LessOrEqual(x:Tensor,y:Tensor): return x <= y -def Greater(x:Tensor,y:Tensor): return x > y -def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y -def Equal(x:Tensor,y:Tensor): return x == y -def BitwiseNot(x:Tensor): return ~x -def BitwiseOr(x:Tensor, y:Tensor): return x | y -def BitwiseAnd(x:Tensor, y:Tensor): return x & y -def BitwiseXor(x:Tensor, y:Tensor): return x ^ y -def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) -def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) -def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) -def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) -def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) -def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) - -# **************** Simple Ops **************** - -# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py -def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype) - -def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, - value_floats:list[float]|None=None, value_int:int|None=None, value_ints:list[int]|None=None, - value_string:str|None=None, value_strings:list[str]|None=None): +def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None, + value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None): if value is not None: return value if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False) if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False) @@ -44,21 +18,79 @@ def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float: if value_string is not None or value_strings is not None and sparse_value is not None: raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value') +def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) + +def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): + try: import PIL.Image + except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e + img = PIL.Image.open(io.BytesIO(encoded_stream)) + if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] + if pixel_format == "RGB": return Tensor(np.array(img)) + if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) + raise ValueError(f"pixel_format={pixel_format!r} is not supported.") + +def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): + ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) + return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) + +def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) +def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) +def ConstantOfShape(shape:list[int], value:Tensor|None=None): + if value is None: value = Tensor(0, dtype=dtypes.float32) + return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) + +def Size(data:Tensor): return data.numel() +def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) + +# ***** Unary Ops (math) ***** +def Not(x:Tensor): return x.logical_not() +def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): + return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) + +# ***** Unary Ops (activation) ***** +def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) +def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) +Softmax = {1:Softmax_1, 13:Softmax_13} def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1) def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf()) +def FastGelu(x:Tensor, bias:Tensor|None=None): + # this is tanh approximated + return (x + bias).gelu() if bias is not None else x.gelu() # TODO: fix this def PRelu(X:Tensor, slope:Tensor): - slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE + slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope return (X > 0).where(X, X * slope) def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha) def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0) -def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis) -def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis) -Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis) -def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002 - return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype) +def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() +# ***** Unary Ops (broadcasted) ***** +def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype) +def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int +def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype) +def Less(x:Tensor,y:Tensor): return x < y +def LessOrEqual(x:Tensor,y:Tensor): return x <= y +def Greater(x:Tensor,y:Tensor): return x > y +def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y +def Equal(x:Tensor,y:Tensor): return x == y +def And(x:Tensor,y:Tensor): return (x==y).where(x, False) +def Or(x:Tensor,y:Tensor): return (x==y).where(x, True) +def BitwiseAnd(x:Tensor,y:Tensor): return x & y +def BitwiseOr(x:Tensor,y:Tensor): return x | y +def BitwiseXor(x:Tensor,y:Tensor): return x ^ y +def BitwiseNot(x:Tensor): return ~x + +# ***** Casting Ops ***** +# TODO: saturate +def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to)) +def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype) + +# ***** Reduce Ops ***** +def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0) +def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0) +def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0) +def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0) def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None) def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims) @@ -80,27 +112,27 @@ def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_wit return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log() def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0): return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log() +def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): + if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) + return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) +def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): + return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) -def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) -def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0) -def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([]) - -def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) -def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta) -def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64) -def Size(data:Tensor): return data.numel() -def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) +# ***** Movement Ops ***** def Reshape(data:Tensor, shape:list[int], allowzero:int=0): return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)]) +def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1) def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape))) def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias) -def And(x:Tensor, y:Tensor): return (x==y).where(x, False) -def Or(x:Tensor, y:Tensor): return (x==y).where(x, True) -def Not(x:Tensor): return x.logical_not() +def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) -def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) +# TODO: add test for when axes is None +def Squeeze(data:Tensor, axes:list[int]|None=None): + return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) +def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) +def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats) +def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None): axes = axes or list(range(data.ndim)) steps = steps or [1]*data.ndim @@ -113,83 +145,9 @@ def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0) if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)] return data.split(split, axis) -# TODO: add test for when axes is None -def Squeeze(data:Tensor, axes:list[int]|None=None): - return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data) -def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data) - -def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float() - -def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0): - if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64) - return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64) -def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index) - -def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis) -def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm) - -def ConstantOfShape(shape:list[int], value:Tensor|None=None): - if value is None: value = Tensor(0, dtype=dtypes.float32) - return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1) - -# **************** Complex Ops **************** - -def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): - ret = alpha * (A.transpose(transA) @ B.transpose(transB)) - if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) - return ret - -def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) - -def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): - axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) - if reverse: X = X.flip(axis) - if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ - .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) - return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) - -# TODO: this is copied from tinygrad/nn/__init__.py -# spatial is from opset 7 and has since been removed -def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, - training_mode:int=0, spatial=1, is_test=0): - if training_mode: - x_detached = X.detach() - current_mean = x_detached.mean(axis=(0,2,3)) - y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) - current_var = (y*y).mean(axis=(0,2,3)) - current_invstd = current_var.add(epsilon).rsqrt() - - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var - invstd = (input_var + epsilon).rsqrt() - return X.batchnorm(scale, B, input_mean, invstd) - -def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): - axis = tuple(range(2, x.ndim)) - mean = x.mean(axis=axis, keepdim=True) - invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() - return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) - -def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): - assert stash_type == 1, "only float32 is supported" - axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) - mean = x.mean(axis=axes, keepdim=True) - return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() - -def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): - return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) - -# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) -def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])])) - -AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] -# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) -def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): - if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] - return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] - +def _onnx_pads_to_tiny_pads(pads): + # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...) + return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:]))))) def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None, mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0): value = constant_value or value @@ -198,6 +156,21 @@ def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[ for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)] return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value) +def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): + shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim + pad_arg:list[None|tuple[int,int]] = [None] * t.ndim + for s, x in zip(shape, axes or range(t.ndim)): + tx = t.shape[x] + if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) + elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) + return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) + +# ***** Processing Ops ***** +AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"] +def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS): + # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right) + if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))] + return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))] def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_)) if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2) @@ -206,25 +179,22 @@ def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS): def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad) + return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), + ceil_mode=ceil_mode, count_include_pad=count_include_pad) def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0, storage_order:int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad) - ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode) + ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode) # tests expect indices with int64 dtype # TODO: if there are repeated values, this is wrong indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape) return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, - kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): - pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad) - return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads)) + kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1): + return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, + padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)) -# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose -# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1, kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0, strides:list[int]|int=1): @@ -247,80 +217,28 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shap if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER") return ret.pad(_onnx_pads_to_tiny_pads(pads)) -def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): - return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) -def SpaceToDepth(X:Tensor, blocksize:int): - return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) +def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True) +def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True) -# Reimplemented here because you need legacy RNG for passing ONNX tests. -def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): - if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. - mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) - return data * mask * (1/(1.0 - ratio)), mask -# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx -def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) -Dropout = {6:Dropout_6, 7:Dropout_7} +def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0): + ret = alpha * (A.transpose(transA) @ B.transpose(transB)) + if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1])) + return ret -def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): - pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) - return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) +def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs) -def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): - return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) +def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0): + axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis) + if reverse: X = X.flip(axis) + if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\ + .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim))) + return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis) -def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - return x.nll_loss(target, weight, ignore_index, reduction) - -def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): - log_probs = scores.log_softmax(1) - return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs - -def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] - -def Gather(x:Tensor, indices:Tensor, axis:int=0): - if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices - x_sh = list(x.shape) - ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] - if indices.ndim > 1: indices = indices.flatten() - indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] - args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore - return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) - # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot - return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] -def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated - -def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): - if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] - x_shape, i_shape = x.shape, indices.shape - b = math.prod(x.shape[dim] for dim in range(batch_dims)) - # NOTE: each batched dim of both input and indices are equal - x = x.reshape(b, *x.shape[batch_dims:]) - indices = indices.reshape(b, *indices.shape[batch_dims:]) - b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) - ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] - return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) -def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): - assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] - x = x.contiguous() - for index, u in zip(indices.split(1, 0), updates.split(1, 0)): - i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) - u = u.squeeze(0) - if reduction == "none": x[i] = u - elif reduction == "add": x[i] += u - elif reduction == "mul": x[i] *= u - else: raise NotImplementedError("reduction doesn't support max or min") - return x - -def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) -def GatherElements(x:Tensor, indices:Tensor, axis:int): - indices = (indices < 0).where(x.shape[axis], 0) + indices - return x.gather(axis, indices) +def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k) def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0, - axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, - extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): + axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0, + extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'): def _apply_nearest_mode(index: Tensor, input_dim, mode: str): if mode == "round_prefer_floor": index = (index - 0.5).ceil() elif mode == "round_prefer_ceil": index = (index + 0.5).floor() @@ -377,116 +295,43 @@ def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, si X = X.gather(i, low).lerp(X.gather(i, high), perc) if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented") return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X - -def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None): - shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim - pad_arg:list[None|tuple[int,int]] = [None] * t.ndim - for s, x in zip(shape, axes or range(t.ndim)): - tx = t.shape[x] - if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2) - elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2) - return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg)) - -def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): - # Scalar or Rank 1 tensor containing exactly one element - depth = int(depth[0] if isinstance(depth, list) else depth) - indices = (indices < 0).where(indices+depth, indices) - return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) - -def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): - if axis is None: - inp = inp.flatten() - axis = 0 - if axis < 0: axis += inp.ndim - con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor - return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] - -def EyeLike(x:Tensor, dtype:int|None=None, k:int=0): - ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype) - return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape)) - def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated -def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): - if axis < 0: axis += x.ndim - if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) - if block_size == 0: - shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) - return scale.reshape(shape), zero_point.reshape(shape) - return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) +# ***** Neural Network Ops ***** +# TODO: try to factor out common implementations for these ops +# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0 +def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9, + training_mode:int=0, spatial=1, is_test=0): + if training_mode: + x_detached = X.detach() + current_mean = x_detached.mean(axis=(0,2,3)) + y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1])) + current_var = (y*y).mean(axis=(0,2,3)) + current_invstd = current_var.add(epsilon).rsqrt() -def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): - out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 - y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) - return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous() - -def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): - x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) - return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) - -def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor): - assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops" - y = (y / y_scale + y_zero_point).round() - return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype) - -def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1, - kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1): - x = x.int() - x_zero_point - w = w.int() - w_zero_point - y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides) - y_scale = y_scale / (x_scale * w_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, - y_zero_point:Tensor|int) -> Tensor: - a = a.int() - a_zero_point - b = b.int() - b_zero_point - y = Tensor.matmul(a, b, acc_dtype=dtypes.int32) - y_scale = y_scale / (a_scale * b_scale) - return _quantize_linear(y, y_scale, y_zero_point) - -def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, - auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None, - pads: int | list[int] = 0, strides: int | list[int] = 1) -> Tensor: - x_int = x.int() - x_zero_point - w_int = w.int() - w_zero_point - return Conv(x_int, w_int, B, auto_pad, dilations, group, kernel_shape, pads, strides) - -def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: - A_int = A.int() - a_zero_point - B_int = B.int() - b_zero_point - return Tensor.matmul(A_int, B_int, acc_dtype=dtypes.int32) - -# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py -def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"): - try: import PIL.Image - except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e - img = PIL.Image.open(io.BytesIO(encoded_stream)) - if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1] - if pixel_format == "RGB": return Tensor(np.array(img)) - if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1) - raise ValueError(f"pixel_format={pixel_format!r} is not supported.") - -def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): - N, _, *spatial_dims = size - def generate_grid(steps): - return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) - grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) - base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) - base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) - return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) - -# **************** com.microsoft Ops **************** + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var + invstd = (input_var + epsilon).rsqrt() + return X.batchnorm(scale, B, input_mean, invstd) +def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05): + axis = tuple(range(2, x.ndim)) + mean = x.mean(axis=axis, keepdim=True) + invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt() + return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1])) +def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1): + assert stash_type == 1, "only float32 is supported" + axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)) + mean = x.mean(axis=axes, keepdim=True) + return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt() +def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05): + return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape) +def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]): + return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9) def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12): x = x + skip + bias return x.layernorm(eps=epsilon) * gamma + beta, None, None, x - -def FastGelu(x:Tensor, bias:Tensor|None=None): - # this is tanh approximated - return (x + bias).gelu() if bias is not None else x.gelu() - def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None, position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0): @@ -513,6 +358,45 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embeddin out = embedding_sum.layernorm(eps=epsilon) * gamma + beta return out, None, embedding_sum +def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1): + # Scalar or Rank 1 tensor containing exactly one element + depth = int(depth[0] if isinstance(depth, list) else depth) + indices = (indices < 0).where(indices+depth, indices) + return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0]) + +def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"): + return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize) +def SpaceToDepth(X:Tensor, blocksize:int): + return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize) + +# Reimplemented here because you need legacy RNG for passing ONNX tests. +def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None): + if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's. + mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device) + return data * mask * (1/(1.0 - ratio)), mask +# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx +def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test) +Dropout = {6:Dropout_6, 7:Dropout_7} + +def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0): + pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1) + return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta) + +def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + return x.nll_loss(target, weight, ignore_index, reduction) +def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"): + log_probs = scores.log_softmax(1) + return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs + +def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0): + N, _, *spatial_dims = size + def generate_grid(steps): + return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) + def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None, relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None, mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None, @@ -552,37 +436,132 @@ def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past: out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1) return out, present if past is not None else out +# ***** Indexing Ops ***** +def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices] + +def Gather(x:Tensor, indices:Tensor, axis:int=0): + if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices + x_sh = list(x.shape) + ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:] + if indices.ndim > 1: indices = indices.flatten() + indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)] + args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore + return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape) + # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot + return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])] +def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated + +def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0): + if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))] + x_shape, i_shape = x.shape, indices.shape + b = math.prod(x.shape[dim] for dim in range(batch_dims)) + # NOTE: each batched dim of both input and indices are equal + x = x.reshape(b, *x.shape[batch_dims:]) + indices = indices.reshape(b, *indices.shape[batch_dims:]) + b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1]) + ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))] + return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:]) +def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'): + assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):] + x = x.contiguous() + for index, u in zip(indices.split(1, 0), updates.split(1, 0)): + i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1)) + u = u.squeeze(0) + if reduction == "none": x[i] = u + elif reduction == "add": x[i] += u + elif reduction == "mul": x[i] *= u + else: raise NotImplementedError("reduction doesn't support max or min") + return x + +def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction)) +def GatherElements(x:Tensor, indices:Tensor, axis:int): + indices = (indices < 0).where(x.shape[axis], 0) + indices + return x.gather(axis, indices) + +def Compress(inp:Tensor, condition:list[bool], axis:int|None=None): + if axis is None: + inp = inp.flatten() + axis = 0 + if axis < 0: axis += inp.ndim + con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor + return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))] + +# ***** Quantization Ops ***** +def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype) + +def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0): + if axis < 0: axis += x.ndim + if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape) + if block_size == 0: + shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim)) + return scale.reshape(shape), zero_point.reshape(shape) + return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis) + +def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1): + out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8 + y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size) + return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous() + +def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0): + x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size) + return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype) + +def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts): + adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)] + return op(*adjusted_inputs, **opts) + +def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in quantized int + out = _op_integer(op, inputs, zero_points, **opts) + assert dtypes.is_int(out.dtype), "quantized op should've done math in int" + out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts): + # op execution is done in float32 + dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)] + out = op(*dequantized_inputs, **opts) + assert dtypes.is_float(out.dtype), "op should've done math in float" + out_quantized = (out / out_scale).round() + out_zero_point + return _clamp_cast(out_quantized, out_zero_point.dtype) + +def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point: Tensor|int, B:Tensor|None=None, **opts): + return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts}) + +def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor, + y_zero_point:Tensor|int) -> Tensor: + return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point) + def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor): - a = a.int() - a_zero_point - b = b.int() - b_zero_point - c = (a * a_scale + b * b_scale) - return _quantize_linear(c, c_scale, c_zero_point) + return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point) def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int): - assert channels_last in {0, 1} - if channels_last == 1: X = X.permute(0, 2, 3, 1) - X = (X.int() - x_zero_point) * x_scale - y = GlobalAveragePool(X) - return _quantize_linear(y, y_scale, y_zero_point) + assert channels_last == 0, "unsure what this does" + return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point) -# **************** ai.onnx.preview.training Ops **************** +def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor: + return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts}) + +def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor: + return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point]) + +# ***** Training Ops ***** # NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested # NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we still can reuse optim.py code - -from tinygrad.nn.optim import Adam as TinyAdam -from tinygrad.nn.optim import SGD - -def onnx_training(input_group_size): - def _decorator(func): - def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): +def _onnx_training(input_group_size): + def __decorator(func): + def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs): R = R.detach() groups = len(inputs) // input_group_size ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))] return tuple(flatten(zip(*ret))) - return __wrapper - return _decorator + return ___wrapper + return __decorator -@onnx_training(3) +@_onnx_training(3) def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0): X, G, H = (i.detach() for i in inputs) grad = norm_coefficient * X + G @@ -592,9 +571,10 @@ def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:flo X.assign(X.detach() - r * up) return [X, H] -@onnx_training(4) +@_onnx_training(4) def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0, - norm_coefficient_post:float=0.0): + norm_coefficient_post:float=0.0): + from tinygrad.nn.optim import Adam as TinyAdam X, G, V, H = inputs G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches X.grad = norm_coefficient * X.detach() + G @@ -610,8 +590,9 @@ def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, eps X = (1 - norm_coefficient_post) * X return [X, V, H] -@onnx_training(3) +@_onnx_training(3) def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float): + from tinygrad.nn.optim import SGD X, G, V = inputs G, V = G.detach(), V.detach() X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1) @@ -620,6 +601,6 @@ def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, opt.step() return [X, V] -def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **__): +def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_): intermediate_tensors[y].backward() return tuple([t.grad for t in inputs]) diff --git a/mkdocs.yml b/mkdocs.yml index 291998dac5..38419a5708 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -22,7 +22,6 @@ nav: - Runtime: runtime.md - Developer: - Intro: developer/developer.md - - Function (autodiff): developer/function.md - UOp: developer/uop.md - Runtime: - developer/runtime.md diff --git a/test/external/external_benchmark_multitensor_allreduce.py b/test/external/external_benchmark_multitensor_allreduce.py index 6867bcaa8f..92fb5ead3c 100644 --- a/test/external/external_benchmark_multitensor_allreduce.py +++ b/test/external/external_benchmark_multitensor_allreduce.py @@ -44,8 +44,8 @@ def main(): else: sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.") - (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True) - (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False) + (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus) + (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus) print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s") print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s") diff --git a/test/external/external_fuzz_ampt.py b/test/external/external_fuzz_ampt.py index 72f8b3801a..f3d4c88e86 100644 --- a/test/external/external_fuzz_ampt.py +++ b/test/external/external_fuzz_ampt.py @@ -25,7 +25,7 @@ class AMPTFuzzer: _vaddr = va.va_addr + _offset for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) self.d.vram[pte['paddr']] = pattern # Mark this page assert pte['valid'] == 1 @@ -41,7 +41,7 @@ class AMPTFuzzer: frags_l = list(ctx.next(contig_range)) for f_offset, f_pt, f_pte_idx, f_n_ptes, f_pte_covers in frags_l: for j in range(f_n_ptes): - f_pte = helper_read_entry_components(f_pt.get_entry(f_pte_idx + j)) + f_pte = helper_read_entry_components(f_pt.entries[f_pte_idx + j]) assert f_pte['valid'] == 1 assert f_pte['paddr'] == start_paddr+f_offset+j*f_pte_covers, f"paddr {f_pte['paddr']:#x} not {start_paddr+f_offset+j*f_pte_covers:#x}" @@ -53,7 +53,7 @@ class AMPTFuzzer: def verify_memory(self, pages, pattern: int) -> bool: for _offset, _pt, _pte_idx, _n_ptes, _pte_covers in pages: for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) if self.d.vram[pte['paddr']] != pattern: return False if pte['valid'] == 0: return False diff --git a/test/external/external_test_am.py b/test/external/external_test_am.py index 985554623e..5d918156ca 100644 --- a/test/external/external_test_am.py +++ b/test/external/external_test_am.py @@ -3,7 +3,9 @@ from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraver from tinygrad.helpers import mv_address class FakeGMC: - def __init__(self): self.vm_base = 0x0 + def __init__(self): + self.vm_base = 0x0 + self.address_space_mask = (1 << 44) - 1 def flush_tlb(self, *args, **kwargs): pass class FakePCIRegion: @@ -14,7 +16,7 @@ class FakePCIDev: class FakeAM: def __init__(self): - self.is_booting = True + self.is_booting, self.smi_dev = True, False self.pcidev = FakePCIDev() self.vram = memoryview(bytearray(4 << 30)) self.gmc = FakeGMC() @@ -72,7 +74,7 @@ class TestAMPageTable(unittest.TestCase): for tup in results: _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == va + _offset + i * _pte_covers, f"Expected paddr {pte['paddr']:#x} to be {va + _offset + i * _pte_covers:#x}" assert pte['valid'] == 1 @@ -81,7 +83,7 @@ class TestAMPageTable(unittest.TestCase): for tup in results: _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == 0 assert pte['valid'] == 0 @@ -113,7 +115,7 @@ class TestAMPageTable(unittest.TestCase): for tup in ctx.next(0x100000): _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup for i in range(_n_ptes): - pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i)) + pte = helper_read_entry_components(_pt.entries[_pte_idx + i]) assert pte['paddr'] == 0xdead0000 + _offset + i * _pte_covers, f"paddr {pte['paddr']:#x} not {0xdead0000 + _offset + i * _pte_covers:#x}" assert pte['valid'] == 1 diff --git a/test/external/process_replay/process_replay.py b/test/external/process_replay/process_replay.py index 1bc597c6b9..1cfef40b9e 100755 --- a/test/external/process_replay/process_replay.py +++ b/test/external/process_replay/process_replay.py @@ -2,7 +2,7 @@ # compare kernels created by HEAD against master from collections import defaultdict import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings -from typing import Callable, List, Tuple, Union, cast +from typing import Callable, cast from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm from tinygrad.engine.schedule import ScheduleContext, schedule_uop from tinygrad.codegen.kernel import Kernel, Opt @@ -33,15 +33,15 @@ class ProcessReplayWarning(Warning): pass def recreate_sched(ast:UOp) -> UOp: # NOTE: process replay isn't meant to actually schedule anything return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast -def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str: +def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str: k = Kernel(ast, opts=opts) for opt in applied_opts: k.apply_opt(opt) # NOTE: replay with the captured renderer, not the one in master - return k.opts.render(name, cast(List,k.to_program().uops)) + return k.opts.render(name, cast(list,k.to_program().uops)) # *** diff a "good" recreation against the generated version -def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]: +def diff(offset:int, name:str, fxn:Callable) -> tuple[int, int]|bool: if early_stop.is_set(): return True conn = db_connection() cur = conn.cursor() @@ -95,7 +95,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None: cur.close() with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool: inputs = list(range(0, row_count, PAGE_SIZE)) - ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs))) + ret: list[tuple[int, int]|bool] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs))) pool.close() pool.join() pool.terminate() diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index a213e5ec88..7aa3253a07 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase): loss.backward() optimizer.step() - helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63) + helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65) @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow") def test_train_cifar(self): diff --git a/test/models/test_train.py b/test/models/test_train.py index 6020a8e777..605e6f6de1 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -40,6 +40,7 @@ class TestTrain(unittest.TestCase): check_gc() @unittest.skipIf(CI, "slow") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_efficientnet(self): model = EfficientNet(0) X = np.zeros((BS,3,224,224), dtype=np.float32) @@ -56,6 +57,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal") def test_transformer(self): # this should be small GPT-2, but the param count is wrong # (real ff_dim is 768*4) diff --git a/test/test_arange.py b/test/test_arange.py index 07512ae1b6..2229ac847f 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase): # llama3 is 128256 vocab_size, embed_size = (10, 3) if CI else (32000, 4096) emb = nn.Embedding(vocab_size, embed_size) - emb_w = emb.weight.numpy() + # TODO: why is a new realize needed here + emb_w = emb.weight.realize().numpy() x = Tensor([1,2,3,4]) with Context(NOOPT=noopt, FUSE_ARANGE=1): GlobalCounters.reset() z = emb(x).realize() self.assertLessEqual(GlobalCounters.global_ops, op_limit) - self.assertEqual(GlobalCounters.kernel_count, 3) + self.assertEqual(GlobalCounters.kernel_count, 2) if getenv("CHECK", 1): import torch with torch.no_grad(): diff --git a/test/test_gc.py b/test/test_gc.py index 010f59039f..cf90dc6201 100644 --- a/test/test_gc.py +++ b/test/test_gc.py @@ -8,9 +8,11 @@ from tinygrad.ops import UOp from tinygrad.tensor import Tensor def tensors_allocated(): + gc.collect() return sum([isinstance(x, Tensor) for x in gc.get_objects()]) def bufs_allocated(): + gc.collect() return sum([isinstance(x, Buffer) for x in gc.get_objects()]) class TestGC(unittest.TestCase): @@ -31,7 +33,7 @@ class TestGC(unittest.TestCase): base = tensors_allocated() a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True) b = Tensor.rand(4, 4, requires_grad=True) - assert (tensors_allocated()-base == 5) + assert (tensors_allocated()-base == 4) (a*b).mean().backward() assert (tensors_allocated()-base == 6) del b diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index 62fcb4a443..a7fd2ab94f 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -134,5 +134,39 @@ class TestImageDType(unittest.TestCase): print(lst) assert not np.any(np.isnan(lst)) +@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU") +class TestImageRealization(unittest.TestCase): + def test_image_dtype_expand(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize() + self.assertEqual(it_expanded.dtype, dtypes.float32) + + def test_image_dtype_expand_and_back(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)) + it2 = it_expanded.sum(3).realize() + self.assertEqual(it2.dtype, dtypes.imagef((9,27,4))) + + def test_image_alu_children(self): + data = Tensor.randn(9*27*4).realize() + it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize() + self.assertEqual(it.dtype, dtypes.imagef((9,27,4))) + it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous() + alu1 = it_expanded+1 + alu2 = it_expanded.sum(3) + it_expanded.realize() + # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image + self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4))) + alu1.realize() + self.assertEqual(alu1.dtype, dtypes.float32) + # alu2 is back in image because it fits the dtype again + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + alu2.realize() + self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4))) + if __name__ == '__main__': unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index da2995d37d..c0bdc2c6f8 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase): # should upcast the two Tensor.stacks assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2 - @unittest.expectedFailure # requires contiguous folding def test_masked_upcast_wino_full(self): with Context(WINO=1): x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize() diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 80ca7fd6e8..8570dac761 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] - helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"]) + helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02) # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local") diff --git a/test/test_multitensor.py b/test/test_multitensor.py index b34baced75..a423d79731 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,11 +1,11 @@ import unittest, functools, random from typing import List -from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import Ops +from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable +from tinygrad.ops import Ops, UOp from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule -from tinygrad.multi import all_reduce, MultiLazyBuffer +from tinygrad.multi import all_reduce import numpy as np from hypothesis import given, strategies as strat, settings from tinygrad.device import is_dtype_supported @@ -30,7 +30,7 @@ N = 128 def _test_allreduce(t:Tensor): aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize() ts = t.shard(devices_4, 0).realize() - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0)) b.realize() return aa, b @@ -39,7 +39,7 @@ class TestMultiTensor(unittest.TestCase): def test_to(self): X = Tensor.ones(256).contiguous().realize() X.to_(devices_2) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (256,) (X + X).realize() @@ -52,7 +52,7 @@ class TestMultiTensor(unittest.TestCase): def test_shard(self): X = Tensor.ones(256).contiguous().realize() X.shard_(devices_2, 0) - for lb in X.lazydata.lbs: + for lb in X.lazydata.src: assert lb.shape == (128,) (X + X).realize() @@ -218,9 +218,9 @@ class TestMultiTensor(unittest.TestCase): shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))]) t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0) with Context(RING=0): - a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) with Context(RING=2): - b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0)) + b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0)) diff = a - b mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy() max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy() @@ -344,9 +344,7 @@ class TestMultiTensor(unittest.TestCase): # NOTE: this is failing on LLVM CI, no idea why. Works locally. @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow") def test_data_parallel_resnet(self): - import sys, pathlib - sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) - from resnet import ResNet18 + from extra.models.resnet import ResNet18 fake_image = Tensor.rand((2, 3, 224//8, 224//8)) fake_image_sharded = fake_image.shard(devices_2, axis=0) @@ -356,16 +354,14 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(m): p.shard_(devices_2).realize() GlobalCounters.reset() shard_output = m(fake_image_sharded).log_softmax().realize() - assert shard_output.lazydata.lbs[0].shape == (1, 1000) - assert shard_output.lazydata.lbs[1].shape == (1, 1000) + assert shard_output.lazydata.src[0].shape == (1, 1000) + assert shard_output.lazydata.src[1].shape == (1, 1000) shard_output_np = shard_output.numpy() np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6) @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM") def test_data_parallel_resnet_train_step(self): - import sys, pathlib - sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix()) - from resnet import ResNet18 + from extra.models.resnet import ResNet18 from tinygrad.nn.optim import LARS fake_image = Tensor.rand((2, 3, 224//8, 224//8)) @@ -386,12 +382,35 @@ class TestMultiTensor(unittest.TestCase): GlobalCounters.reset() optimizer.zero_grad() shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1) - assert shard_output.lazydata.axis is None shard_output.backward() shard_grad = m.conv1.weight.grad.numpy() # sometimes there is zeros in these grads... why? np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5) + def test_assign_kv_cache_multi(self): + bsz, max_context = 2, 8 + + class Attn: + @TinyJit + def __call__(self, xk:Tensor, start_pos:UOp): + seqlen = xk.shape[1] + if not hasattr(self, "cache_k"): + self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize() + keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk + self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize() + + attn = Attn() + xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous() + attn(xk, 0) + for i in range(3,6): + # copied from LLaMA + start_pos = Variable("start_pos", 1, max_context).bind(i) + xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous() + attn(xk, start_pos) + + out = attn.cache_k.flatten().numpy() + np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.]) + def test_multi_tensor_jit_param(self): @TinyJit def jf(a, b) -> Tensor: @@ -532,13 +551,13 @@ class TestMultiTensor(unittest.TestCase): t4 = t2.reshape((26, 105,)) for t in [t0, t1, t2, t3, t4]: - assert t.lazydata.axis == 1 np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten()) + assert t.lazydata.axis == 1 # test shape-one axis t5 = t4.reshape((26, 1, 105)) - assert t5.lazydata.axis == 2 np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten()) + assert t5.lazydata.axis == 2 # test split and rejoin to the right and reshape to the left t5 = t0.reshape((2, 13, 3, 5, 7)) @@ -553,7 +572,7 @@ class TestMultiTensor(unittest.TestCase): # test no left join with self.assertRaises((AssertionError, ValueError)): - t0.reshape((26*15,7)) + t0.reshape((26*15,7)).schedule() @unittest.skip("no longer supports uneven shard") def test_reshape_on_axis_uneven(self): @@ -588,6 +607,7 @@ class TestMultiTensor(unittest.TestCase): with self.assertRaises(AssertionError): # don't allow assigns that change axes t_none.assign(t_zero) + t_none.schedule() def test_init_rand_with_multiple_devices_fail(self): # init rand with multi device is not allowed @@ -627,6 +647,16 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) + def test_rand_like_from_alu(self): + # TODO: fix this, which will also fix multi device dropout + a = Tensor.ones(4, 4).shard(devices_2, axis=0) + with self.assertRaises(AssertionError): + (a + a).rand_like() + + b = Tensor.empty(4, 4).shard(devices_2, axis=None) + with self.assertRaises(AssertionError): + (a + b).rand_like() + @unittest.skip("no longer supports uneven shard") def test_rand_like_uneven_shard(self): t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1) @@ -635,7 +665,7 @@ class TestMultiTensor(unittest.TestCase): self.assertEqual(t.device, t2.device) self.assertEqual(t.dtype, t2.dtype) self.assertEqual(t.lazydata.axis, t2.lazydata.axis) - assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs)) + assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src)) def test_rand_like_none_shard(self): t = Tensor.empty((16, 16)).shard(devices_2) @@ -718,7 +748,7 @@ class TestMultiTensor(unittest.TestCase): devices = (d0, d1, d2, d3) t = Tensor.zeros(16, 16).contiguous() t.shard_(devices, axis=0).realize() - assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs]) + assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src]) @unittest.skip("this is unreliable on OSX") def test_clone(self): @@ -774,25 +804,31 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): with self.assertRaises(AssertionError): # sharded axis shrink on non-device boundry is not allowed a = t.shrink(((0, 3), (0, 8))) + a.schedule() with self.assertRaises(AssertionError): # cannot shrink sharded and non-sharded axis at the same time a = t.shrink(((0, 2), (2, 4))) + a.schedule() a = t.shrink(((0, 2), (0, 8))) + a.schedule() assert a.shape == (2, 8) - assert a.lazydata.real == [True, False, False, False] + assert a.lazydata.real == (True, False, False, False) with self.assertRaises(AssertionError): # cannot pad sharded and non-sharded axis at the same time p = a.pad(((0, 6), (0, 1))) + p.schedule() with self.assertRaises(AssertionError): # can only pad to whole axis p = a.pad(((1, 5), (0, 0))) + p.schedule() p = a.pad(((0, 6), (0, 0))) + p.schedule() assert p.shape == (8, 8) - assert p.lazydata.real == [True, True, True, True] + assert p.lazydata.real == (True, True, True, True) @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16])) def test_ops(self, dtype): @@ -804,8 +840,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): a = t.shrink(((0+2*i,2+2*i),None)) b = Tensor(t.numpy()[0+2*i:2+2*i]) assert a.shape == b.shape == (2, 8) - assert a.lazydata.real == [i==j for j in range(4)] np.testing.assert_allclose(a.numpy(), b.numpy()) + assert a.lazydata.real == tuple(i==j for j in range(4)) # cast np.testing.assert_allclose(a.float().numpy(), b.float().numpy()) @@ -859,24 +895,27 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase): np.testing.assert_equal((a+a).numpy(), na+na) np.testing.assert_equal((b+b).numpy(), nb+nb) + @unittest.skip("why didn't this work?") def test_add_two_partitions(self): t = Tensor.arange(64).reshape(8, 8).contiguous().realize() t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0) a = t.shrink(((2, 4), None)) b = t.shrink(((6, 8), None)) - self.assertEqual(a.lazydata.real, [False, True, False, False]) - self.assertEqual(b.lazydata.real, [False, False, False, True]) na = t.numpy()[2:4] nb = t.numpy()[6:8] np.testing.assert_equal(a.numpy(), na) np.testing.assert_equal(b.numpy(), nb) + self.assertEqual(a.lazydata.real, (False, True, False, False)) + self.assertEqual(b.lazydata.real, (False, False, False, True)) with self.assertRaises(AssertionError): # cannot add directly c = a + b + c.schedule() c = a.pad(((2, 4), None)) + b.pad(((6, 0), None)) - self.assertEqual(c.lazydata.real, [True, True, True, True]) + c.realize() + self.assertEqual(c.lazydata.real, (True, True, True, True)) expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb]) np.testing.assert_equal(c.numpy(), expected) @@ -937,8 +976,9 @@ class TestBatchNorm(unittest.TestCase): def __call__(self, x:Tensor): bn_ts = [] - for bound, bn in zip(x.lazydata.bounds, self.bns): - xi = x.shrink((bound, None, None, None)) + each = x.shape[0]//len(self.bns) + for i, bn in enumerate(self.bns): + xi = x.shrink(((each*(i), each*(i+1)), None, None, None)) bni = bn(xi) bn_ts.append(bni) return bn_ts[0].cat(*bn_ts[1:]) diff --git a/test/test_ops.py b/test/test_ops.py index d342a71df4..020f168b8d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -352,12 +352,15 @@ class TestOps(unittest.TestCase): def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y) def test_cmp_ne_backwards(self): + # new grad zeroes these out + """ t1 = torch.ones(4, requires_grad=True) t2 = torch.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (t1 != t2).sum().backward) tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward) + """ tt = Tensor.randn(4, requires_grad=True) (tt*(tt != 0)).sum().backward() t = torch.tensor(tt.numpy(), requires_grad=True) @@ -365,12 +368,15 @@ class TestOps(unittest.TestCase): np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5) def test_cmp_lt_backwards(self): + # new grad zeroes these out + """ t1 = torch.ones(4, requires_grad=True) t2 = torch.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (t1 < t2).sum().backward) tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + """ tt = Tensor.randn(4, requires_grad=True) (tt*(tt < 0)).sum().backward() t = torch.tensor(tt.numpy(), requires_grad=True) diff --git a/test/test_schedule.py b/test/test_schedule.py index ccf423d0a5..f79f5bd378 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views -from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same +from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp from tinygrad.codegen.kernel import verify_ast -from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym +from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule from extra.models.llama import precompute_freqs_cis @@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2) @track_rewrites(named=True) -def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext()) +def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {}) class TestSchedule(unittest.TestCase): def test_basic_binop_fusion(self): @@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase): def test_fold_conv_batchnorm_optim(self): # this is too high - for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 15)]: + for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]: with self.subTest(optim=optim.__name__): with Tensor.train(): img = Tensor.ones(1,3,4,4) @@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase): check_schedule(out, 2) # multireduce spec + @unittest.skip("these two Tensors are the same") def test_example_matmul(self): x = Tensor.eye(64, requires_grad=True) y = Tensor.eye(64, requires_grad=True) @@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase): run_schedule(check_schedule(out, 2)) np.testing.assert_allclose(out.numpy(), np.ones((64,64))) + def test_example_matmul_contig(self): + x = Tensor.eye(64, requires_grad=True).contiguous().realize() + y = Tensor.eye(64, requires_grad=True).contiguous().realize() + z = y.matmul(x).sum() + z.backward() + out = x.grad.contiguous() + run_schedule(check_schedule(out, 2)) + np.testing.assert_allclose(out.numpy(), np.ones((64,64))) + def test_example_matmul_same(self): x = Tensor.eye(64, requires_grad=True) z = x.matmul(x).sum() @@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 13) + check_schedule(opt.schedule_step(), 14) def test_sgd_conv_fuse(self): with Tensor.train(): @@ -1060,7 +1070,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters(c1)) opt.zero_grad() c1(img).relu().sum().backward() - check_schedule(opt.schedule_step(), 5) + check_schedule(opt.schedule_step(), 3) def test_sgd_2convs_fuse(self): with Tensor.train(): @@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2])) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 8) + check_schedule(opt.schedule_step(), 7) def test_fold_2convs_sgd_nesterov_momentum_wd(self): with Tensor.train(): @@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1) opt.zero_grad() c2(c1(img).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 10) + check_schedule(opt.schedule_step(), 9) def test_sgd_4convs_fuse(self): with Tensor.train(): @@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - check_schedule(opt.schedule_step(), 18) + check_schedule(opt.schedule_step(), 17) def test_sgd_4convs_fuse_conv_bw(self): with Tensor.train(): @@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase): opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4])) opt.zero_grad() c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward() - with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15) + with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14) @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_prefer_half_buffer(self): @@ -1844,7 +1854,7 @@ def run_tensor_ast(r:Tensor): sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink() sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output]) sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right) - si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ()) + si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ()) run_schedule([si]) return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist() @@ -1940,19 +1950,19 @@ class TestSwizzle(unittest.TestCase): ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) - @unittest.expectedFailure + @unittest.skip("this swizzle can't be decided after the ADD") def test_swizzle_failure_permute(self): sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(Ops.WHERE, dtypes.float, arg=None, src=( x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( @@ -1971,13 +1981,13 @@ class TestSwizzle(unittest.TestCase): UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), x15,)), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), x10,)), UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=( UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( - UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)) ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) @@ -2269,6 +2279,54 @@ class TestCopyFolding(unittest.TestCase): a = Tensor.empty(4).lazydata check_schedule(a.clone(), 1, filter_sink=False) + # NOTE: moving copy before view might change this + def test_shrink_copy(self): + a = Tensor.arange(4) + view = a.shrink(((0, 2),)) + b = view.clone() + run_schedule(check_schedule(b, 2, filter_sink=False)) + self.assertEqual(b.lazydata.base.buffer.size, 2) + self.assertEqual(b.lazydata.size, 2) + self.assertListEqual(b.tolist(), [0, 1]) + + def test_expanded_copy(self): + a = Tensor.arange(2) + view = a.reshape(2, 1).expand(2, 2) + b = view.clone() + run_schedule(check_schedule(b, 2, filter_sink=False)) + self.assertEqual(b.lazydata.base.buffer.size, 2) + self.assertEqual(b.lazydata.size, 4) + self.assertListEqual(b.tolist(), [[0, 0], [1, 1]]) + + def test_permuted_copy(self): + a = Tensor.arange(4) + b = a.reshape(2, 2).permute(1, 0) + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + def test_permute_on_disk(self): + with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer()) + a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}") + b = a.reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + def test_permute_after_shrink(self): + a = Tensor.arange(5) + b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + + # NOTE: disk permute must come after COPY + # TODO: this is wrong because of the permute + @unittest.expectedFailure + def test_permute_after_shrink_on_disk(self): + with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer()) + a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}") + b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG") + b.realize() + self.assertListEqual(b.tolist(), [[0, 2], [1, 3]]) + class TestTensorUOpSpec(unittest.TestCase): def test_const_must_be_unmasked(self): a = Tensor.ones((4, 4)).pad((2, 2)) @@ -2377,6 +2435,17 @@ class TestContiguous(unittest.TestCase): b = a.expand((4, 4)).contiguous().contiguous() check_schedule(b, 1) + def test_view_does_not_realize(self): + a = Tensor.empty(4) + b = a.expand((4, 4)) + check_schedule(b, 0) + self.assertEqual(b.lazydata.base.buffer.size, 4) + + def test_contiguous_view_realizes(self): + a = Tensor.empty(4) + b = a.expand((4, 4)).contiguous() + check_schedule(b, 1) + self.assertEqual(b.lazydata.base.buffer.size, 16) class TestUOpBecome(unittest.TestCase): # the simplest case, if we create a new BUFFER for this UOp diff --git a/test/test_tensor.py b/test/test_tensor.py index 0e9c20d022..b1b90ea44b 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase): def test_clone(self): a = Tensor.rand(16, 16).realize() - self.assertIsNot(a.lazydata, a.clone().lazydata) - np.testing.assert_allclose(a.numpy(), a.clone().numpy()) + b = a.clone() + np.testing.assert_allclose(a.numpy(), b.numpy()) + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) a = Tensor.rand(16, 16).mul(5.0).add(5.0) - self.assertIsNot(a.lazydata, a.clone().lazydata) - np.testing.assert_allclose(a.numpy(), a.clone().numpy()) + b = a.clone() + np.testing.assert_allclose(a.numpy(), b.numpy()) + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) def test_clone_with_shrink(self): - a = Tensor.empty(16, 16) - self.assertIsNot(a.lazydata, a.clone().lazydata) + a = Tensor.rand(16, 16) + b = a.shrink(((2, 10), None)).clone() + b.realize() + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) - b = a.shrink(((2, 10), None)) - self.assertIsNot(b.lazydata, b.clone().lazydata) + def test_clone_with_shrink_realized(self): + a = Tensor.rand(16, 16).realize() + b = a.shrink(((2, 10), None)).clone() + b.realize() + self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer) def test_clone_with_grad(self): a = Tensor.rand(16, 16, requires_grad=True) @@ -763,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) def test_complex_backward(self): - x = Tensor.rand(3, requires_grad=True) - y = Tensor.rand(3, requires_grad=True) + x = Tensor.rand(3, requires_grad=True).realize() + y = Tensor.rand(3, requires_grad=True).realize() out = (x.relu() * y.sigmoid()).sum() self.assertEqual(out.lazydata.metadata.name, "sum") out.backward() diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py index a9a41eace0..f7a99982fa 100644 --- a/test/unit/test_gradient.py +++ b/test/unit/test_gradient.py @@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase): def _test_one_input_function(self, f:Callable, jf:Callable|None=None): x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float) - gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x] + gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x] gf = jax.grad(f if jf is None else jf) for val in [-5., -2.0, 0.0, 2.0, 5.]: @@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase): def _test_two_input_function(self, f:Callable, jf:Callable|None=None): x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float) y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float) - grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y]) + grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y])) gx, gy = grads[x], grads[y] gf = jax.grad(f if jf is None else jf, argnums=(0, 1)) diff --git a/test/test_viz.py b/test/unit/test_viz.py similarity index 59% rename from test/test_viz.py rename to test/unit/test_viz.py index 2723a411c4..ff84813d1e 100644 --- a/test/test_viz.py +++ b/test/unit/test_viz.py @@ -1,118 +1,104 @@ -from typing import Dict, List, Optional import unittest, decimal, json from tinygrad.dtype import dtypes -from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic -from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys -from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic +from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry -from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto +from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto -@track_rewrites(named=True) -def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs) - -def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]: - rewrite(sink, pm, **kwargs) - assert len(contexts) == 1 - assert len(contexts[0]) == 1 - k = get_metadata(keys, contexts)[0][0] - g = get_details(*k) - return g.uops[1:] +# NOTE: VIZ tests always use the tracked PatternMatcher instance +symbolic = TrackedPatternMatcher(symbolic.patterns) class TestViz(unittest.TestCase): def setUp(self): + # clear the global context contexts.clear() keys.clear() + _name_cnt.clear() self.tms = TRACK_MATCH_STATS.value TRACK_MATCH_STATS.value = 2 def tearDown(self): TRACK_MATCH_STATS.value = self.tms def test_viz_simple(self): - pm = PatternMatcher([ - (UPat.var("x")*1, lambda x:x), - ]) - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - uops = helper_test_viz(a*1, pm) - self.assertEqual(len(uops), 1) - self.assertEqual(uops[0], a) + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(sink): return graph_rewrite(sink, symbolic) + test(a*1) + ret = get_metadata(keys, contexts) + self.assertEqual(len(ret), 1) + key, val = ret[0] + self.assertEqual(key, "test_1") + self.assertEqual(val[0]["match_count"], 1) - def test_rewrite_twice(self): - pm = PatternMatcher([ - (UPat.var("x")+UPat.var("x"), lambda x:x*2), - (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))), - ]) - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - uops = helper_test_viz(a+a, pm) - self.assertEqual(len(uops), 2) - self.assertEqual(uops[0], a*2) - self.assertEqual(uops[1], graph_rewrite(a+a, pm)) + def test_track_two_rewrites(self): + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(sink): return graph_rewrite(sink, symbolic) + test((a+a)*1) + ret = get_metadata(keys, contexts) + key, val = ret[0] + self.assertEqual(len(ret), 1) # one context + self.assertEqual(len(val), 1) # one graph_rewrite call in context + self.assertEqual(key, "test_1") + self.assertEqual(val[0]["match_count"], 2) # two upats applied - def test_rewrite_with_ctx(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), ShapeTracker.from_shape((1, 1)).to_uop())) - b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), ShapeTracker.from_shape((1, 1)).to_uop())) - def store_load(ctx:Dict[UOp, None], glbl, st) -> Optional[UOp]: - if glbl in ctx: return None - ctx[glbl] = None - return UOp.store(glbl, ShapeTracker.from_shape(st.shape).to_uop()) - pm = PatternMatcher([ - (UPat.load(UPat(Ops.DEFINE_GLOBAL, name="glbl"), UPat.var("st")), store_load), - ]) - uops = helper_test_viz(a+b, pm, ctx={}) - self.assertEqual(len(uops), 2) - self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {})) + def test_track_multiple_calls_one_ctx(self): + a = UOp.variable("a", 0, 10) + @track_rewrites(named=True) + def test(a, b): + a = graph_rewrite(a, symbolic) + b = graph_rewrite(b, symbolic) + test(a*1, a*5) + ret = get_metadata(keys, contexts) + key, val = ret[0] + self.assertEqual(len(ret), 1) # one context + self.assertEqual(len(val), 2) # two graph_rewrite calls in context + self.assertEqual(key, "test_1") + self.assertEqual(val[0]["match_count"], 1) # one rewrite for a*0 + self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5 def test_track_rewrites(self): - simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites(named=True) - def do_rewrite(x:UOp): return graph_rewrite(x, simple) - ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) - do_rewrite(ld*1) - do_rewrite(ld*2) + def do_rewrite(x:UOp): return graph_rewrite(x, symbolic) + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 4) + do_rewrite(a*1) + do_rewrite(a*b) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 2) - key, _, m = ret[0][0] + key, m = ret[0] self.assertEqual(key, "do_rewrite_1") - self.assertEqual(len(m.upats), 1) - key, _, m = ret[1][0] + self.assertEqual(m[0]["match_count"], 1) + key, m = ret[1] self.assertEqual(key, "do_rewrite_2") - self.assertEqual(len(m.upats), 0) + self.assertEqual(m[0]["match_count"], 0) def test_track_rewrites_with_exception(self): - simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)]) @track_rewrites() def do_rewrite(x:UOp): - x = graph_rewrite(x, simple) # NOTE: viz tracks this + x = graph_rewrite(x, symbolic) # NOTE: viz tracks this raise Exception("test") - ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0))) - with self.assertRaises(Exception): do_rewrite(ld*1) + a = UOp.variable("a", 0, 10) + with self.assertRaises(Exception): do_rewrite(a*1) ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - def test_fold_const(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - graph = uop_to_json(a) - assert not any(v[0].startswith("CONST") for v in graph.values()) - assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 + # NOTE: CONST UOps do not get nodes in the graph + def test_dont_create_const_nodes(self): + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 4) + self.assertEqual(len(uop_to_json(a*1)), 2) + self.assertEqual(len(uop_to_json(a*b)), 3) @unittest.skip("TODO: bring this back with better testing") def test_bottom_up_rewrite(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - n1 = a.sin() - uop = n1.sin() - pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) - ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True) - self.assertEqual(len(ret), 2) - self.assertIs(ret[0], a.sin().sqrt()) # first rewrite - self.assertIs(ret[1], a.sqrt().sqrt()) # second one - - def test_top_down_rewrite(self): - a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) - n1 = a.sin() - uop = n1.sin() - pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) - # if it wasn't bottom_up, it's rewritten once - ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False) + a = UOp.variable("a", 0, 10) + b = UOp.variable("b", 0, 10) + c = UOp.variable("c", 0, 10) + UOp.substitute(a+b, {a+b:c}) + ret = get_metadata(keys, contexts) self.assertEqual(len(ret), 1) - self.assertIs(ret[0], a.sqrt().sin()) # only rewrite + _, _, vals = ret[0][0] + self.assertEqual(len(vals.upats), 1) # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ def test_rewrite_without_context(self): @@ -211,6 +197,5 @@ class TextVizProfiler(unittest.TestCase): self.assertEqual(j['traceEvents'][7]['dur'], 4) self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid']) - if __name__ == "__main__": unittest.main() diff --git a/tinygrad/device.py b/tinygrad/device.py index 2a20992f80..04182a3c33 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -3,14 +3,14 @@ from dataclasses import dataclass, replace from collections import defaultdict from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time -from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ - cpu_time_execution + cpu_time_execution, colored, Context from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes from tinygrad.renderer import Renderer # **************** Device **************** +ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"] class _Device: def __init__(self) -> None: self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] @@ -25,7 +25,7 @@ class _Device: cpn = multiprocessing.current_process().name assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}" x = ix.split(":")[0].upper() - ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \ + ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \ if (cname.lower() == x.lower() + "device")][0](ix) if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}") self._opened_devices.add(ix) @@ -33,7 +33,7 @@ class _Device: @property def default(self) -> Compiled: return self[self.DEFAULT] def get_available_devices(self) -> Iterator[str]: - for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]: + for device in ALL_DEVICES: with contextlib.suppress(Exception): yield self[device].device @functools.cached_property def DEFAULT(self) -> str: @@ -220,9 +220,10 @@ MAP_JIT = 0x0800 # CPUProgram is a jit/shellcode program that can be just mmapped and jumped to class CPUProgram: - helper_handle = ctypes.CDLL(ctypes.util.find_library('System') if OSX else 'libgcc_s.so.1') - + helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32' if sys.platform == "win32" else 'gcc_s')) def __init__(self, name:str, lib:bytes): + assert sys.platform != "win32", "clang is not supported for windows yet" + from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/ # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np) self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC) @@ -314,3 +315,18 @@ if PROFILE: from tinygrad.ops import launch_viz launch_viz("PROFILE", fn) + +if __name__ == "__main__": + for device in ALL_DEVICES: + try: + _ = Device[device].device + try: + from tinygrad import Tensor + with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist() + if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]") + result = colored("PASS", "green") + except Exception as e: + result = f"{colored('FAIL', 'yellow')} {e}" + except Exception as e: + result = f"{colored('FAIL', 'red')} {e}" + print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}") diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index f718ef311e..afb9d726a4 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap from tinygrad.device import Buffer, Compiled, Device from tinygrad.dtype import DType -from tinygrad.ops import UOp, Variable, sym_infer +from tinygrad.ops import UOp, Variable, sym_infer, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates from tinygrad.engine.memory import _internal_memory_planner @@ -194,7 +194,8 @@ def _prepare_jit_inputs(args, kwargs): input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor] names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors] if tensors: Tensor.realize(*tensors) - lbs: list[UOp] = flatten([t.lazydata.lbs for t in tensors]) + # TODO: should we be unpacking multi here? + lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors]) input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None] assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT" st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs] diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 99439c54b9..a9359ce259 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -47,4 +47,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. assigned = _internal_memory_planner([si.bufs for si in schedule], noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule] + return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule] diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 71cc12492c..c94f18302b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,10 +1,10 @@ import sys, atexit, functools, pickle from collections import defaultdict, deque from dataclasses import dataclass, field -from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views -from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map -from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap -from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar +from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers +from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views +from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap +from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY from tinygrad.dtype import DType, ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape @@ -37,7 +37,7 @@ tensor_uop_spec = PatternMatcher([ # DETACH and CONTIGUOUS change how we interpret the source UOp # CONTIGUOUS ensures the source UOp realizes - (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype), + (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype), # COPY # NOTE: the arg here specifies clone=True, which prevents folding same device copy @@ -60,7 +60,6 @@ class ScheduleItem: ast: UOp bufs: tuple[Buffer, ...] metadata: tuple[Metadata, ...] - assign_preloads: tuple[UOp, ...] @property def outputs(self) -> tuple[Buffer, ...]: """Read/write or write only buffers in the schedule.""" @@ -82,9 +81,8 @@ class ScheduleContext: realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata - contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) - becomes_map: dict[UOp, UOp] = field(default_factory=dict) + preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict)) # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. @@ -104,6 +102,7 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}") dtype = buf.dtype.base # ASSIGN already has a target buffer, otherwise we create a new one + assert isinstance(buf.device, str), f"buf device is str, not {buf.device}" buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) # track the underlying tensor uop for this buffer @@ -152,9 +151,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) -# push VIEW to stores +# push VIEW to children view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view() + # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() @@ -200,6 +199,8 @@ to_si = PatternMatcher([ (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), # once images are loaded they become the base dtype (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), + # CONST(VIEW) becomes VALID too, TODO: doesn't have to + (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)), ]) # LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel @@ -211,25 +212,27 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem: # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals)) # deal with ASSIGN - assign_preloads: list[UOp] = [] if len(ctx.assigns) != 0: + assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer] for x in list(sink.toposort)[::-1]: # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph") # PRELOAD tells the toposort this kernel should run before ASSIGN if x.op is Ops.PRELOAD: - assign_preloads.append(x.buf_uop) + assign_preloads[x.buf_uop] = None # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous: + # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine + if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass # if it has a single view and it's equal when you shrink a contig, it's fine - if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask): - raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" - +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) + elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass + # otherwise, it's not fine + else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" + +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) # capture process replay if CAPTURE_PROCESS_REPLAY: with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast)) - return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), - tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads))) + return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None))) PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {} if CAPTURE_PROCESS_REPLAY: @@ -329,7 +332,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]: # maybe fuse arange with its children for rbuf in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue + if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del ctx.realizes[tr] @@ -354,20 +357,20 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: case _: return None return reduce.const_like(ret) -def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp): - if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti) -def replace_contiguous(ctx:ScheduleContext, alu:UOp): +def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) +def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): new_src = list(alu.src) for i,s in enumerate(alu.src): - if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src + if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) sym = symbolic_simple+PatternMatcher([ # UOp with size 0 is zero (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # DETACH is a NOOP here - (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]), + # DETACH and CONTIGUOUS_BACKWARD are NOOPs here + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), # reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), @@ -389,10 +392,10 @@ sym = symbolic_simple+PatternMatcher([ # support for using a contiguous permuted view instead of the parent view if one exists (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), (UPat(GroupOp.ALU, name="alu"), replace_contiguous), - # remove CONST/BIND/BUFFER/VIEW from SINK + # remove CONST/BIND/BUFFER from SINK (UPat(Ops.SINK, name="root"), lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) - if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), + if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), ]) # ** this decides which ops get realized @@ -419,7 +422,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs) return x.view(unwrap(view.st)) def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp): - if not b.device.startswith("DISK"): return None + if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize) return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW))) @@ -440,12 +443,12 @@ do_realize = PatternMatcher([ (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer), ]) -# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp +# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp): - assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}" - ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg - return ret.valid(unwrap(bind.st)) + assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}" + ctx.var_vals[var.replace(src=())] = val.const_arg + return var def load_realized(ctx:ScheduleContext, b:UOp, st:UOp): # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN @@ -458,8 +461,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp): return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop())) break_sched = PatternMatcher([ - # CONST is always fused and generated - (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)), (UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable), # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized), @@ -473,38 +474,43 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None: if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop) for x in op.base.src: if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None - buf_uop.buffer.ref(1) create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)]) # **** movement ops -remove_movement_ops = PatternMatcher([ +remove_movement_ops = merge_views+PatternMatcher([ # NOTE: movement ops are always applied to base (UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))), # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) (UPat(Ops.VIEW, name="view"), lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), - # merge one src views. - (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)), - # merge unmasked const views - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), - lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) @track_rewrites(named=True) def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]: if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec) - tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext()) + tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={}) + # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed + becomes_map: dict[UOp, UOp] = {} + for k,v in tensor_map.items(): + # NOOP + if k.base is v.base: continue + # NOTE: only the base tensors get a BUFFER UOp + if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st)) + # otherwise if it simplified to a CONST the UOp just becomes that CONST + elif v.op is Ops.CONST: becomes_map[k] = v + + # we group the rest of UOps into ScheduleItems rev_tensor_map: dict[UOp, list[UOp]] = {} for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) # add BUFFER uops - sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={}) + sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) # add realizes sink = graph_rewrite(sink, do_realize+create_ctx, ctx) # group realizes into kernels store_groups = group_realizes(ctx) graph_rewrite(sink, break_sched, ctx) - # preschedule realize groups + # create schedule items + map buffers to realized tensors prescheduled: list[ScheduleItem] = [] for store_uops in store_groups: small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops]) @@ -512,16 +518,9 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu prescheduled.append(schedule_uop(small_sink, ctx)) # can only schedule once for buf_uop in store_uops: - for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st)) - - # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed - for k,v in tensor_map.items(): - # NOOP - if k.base is v.base: continue - # NOTE: only the base tensors get a BUFFER UOp - if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st)) - # otherwise if it simplified to a CONST the UOp just becomes that CONST - elif v.op is Ops.CONST: ctx.becomes_map[k] = v + for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + # increment refcount for this buffer + buf_uop.buffer.ref(1) # add kernel children schedule_targets = {out:si for si in prescheduled for out in si.outputs} @@ -529,7 +528,7 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu in_degree: defaultdict[ScheduleItem, int] = defaultdict(int) for si in prescheduled: # realize outputs before a parent is assigned to - parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si) + parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si) for assign in parents_assigns: graph[si].append(assign) in_degree[assign] += 1 @@ -550,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu # confirm everything was scheduled correctly if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}") if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") - return schedule, ctx.var_vals, ctx.becomes_map + return schedule, ctx.var_vals, becomes_map diff --git a/tinygrad/function.py b/tinygrad/function.py deleted file mode 100644 index 5527870711..0000000000 --- a/tinygrad/function.py +++ /dev/null @@ -1,203 +0,0 @@ -"""This is where the forwards and backwards passes live.""" -import math -from tinygrad.helpers import argsort -from tinygrad.dtype import dtypes, DType, sum_acc_dtype -from tinygrad.ops import Ops, resolve, sint, UOp -from tinygrad.tensor import Function - -class Contiguous(Function): - def forward(self, x:UOp) -> UOp: return x.contiguous() - def backward(self, grad_output:UOp) -> UOp: return grad_output - -class ContiguousBackward(Function): - def forward(self, x:UOp) -> UOp: return x - def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous() - -class Cast(Function): - def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp: - self.input_dtype, self.bitcast = x.dtype, bitcast - return x.bitcast(dtype) if self.bitcast else x.cast(dtype) - - def backward(self, grad_output:UOp) -> UOp: - if self.bitcast: raise RuntimeError("bitcast cannot backward") - return grad_output.cast(self.input_dtype) - -# ************* unary ops ************* - -class Reciprocal(Function): - def forward(self, x:UOp) -> UOp: - self.ret = x.reciprocal() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return -grad_output * self.ret * self.ret - -class Sin(Function): - def forward(self, x:UOp) -> UOp: - self.x = x - return x.sin() - - def backward(self, grad_output:UOp) -> UOp: return (math.pi/2 - self.x).sin() * grad_output - -class Relu(Function): - def forward(self, x:UOp) -> UOp: - self.ret = (x>0).where(x, 0) - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output - -class Log(Function): - def forward(self, x:UOp) -> UOp: - self.x = x - return x.log2() * math.log(2) - - def backward(self, grad_output:UOp) -> UOp: return grad_output / self.x - -class Exp(Function): - def forward(self, x:UOp) -> UOp: - self.ret = (x * (1/math.log(2))).exp2() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return self.ret * grad_output - -class Sqrt(Function): - def forward(self, x:UOp) -> UOp: - self.ret = x.sqrt() - return self.ret - - def backward(self, grad_output:UOp) -> UOp: return grad_output / (self.ret*2) - -class Sign(Function): - # NOTE: the x*0 is to match torch behavior without function.py - def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0 - # backward always return 0 to match torch - def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0) - -# ************* binary ops ************* - -class Less(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x tuple[UOp|None, UOp|None]: return None, None - -class Neq(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y) - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None - -class Xor(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x^y - -class BitwiseAnd(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x&y - -class BitwiseOr(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x|y - -class Threefry(Function): - def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed) - -class Add(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x+y - - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: - return grad_output if self.needs_input_grad[0] else None, \ - grad_output if self.needs_input_grad[1] else None - -class Mul(Function): - def forward(self, x:UOp, y:UOp) -> UOp: - self.x, self.y = x, y - return x * y - - def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: - return (self.y * grad_output) if self.needs_input_grad[0] else None, \ - (self.x * grad_output) if self.needs_input_grad[1] else None - -class IDiv(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x // y - -class Mod(Function): - def forward(self, x:UOp, y:UOp) -> UOp: return x % y - -# ************* ternary ops ************* - -class Where(Function): - def forward(self, x:UOp, y:UOp, z:UOp) -> UOp: - self.x = x - return self.x.where(y, z) - - def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]: - return None, \ - self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \ - self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None - -# ************* reduce ops ************* - -class Sum(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.input_shape = x.shape - return x.r(Ops.ADD, axis) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.expand(self.input_shape) - -class Prod(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.x, self.ret = x, x.r(Ops.MUL, axis) - return self.ret - - def backward(self, grad_output:UOp) -> UOp: - return (grad_output * self.ret).expand(self.x.shape) / self.x - -class Max(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis - return self.ret - - def backward(self, grad_output:UOp) -> UOp: - # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype) - div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape) - return (max_is_1s/div) * grad_output.expand(self.x.shape) - -# ************* movement ops ************* - -# NOTE: this is sum in reverse -class Expand(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: - self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so)) - return x.expand(shape) - - def backward(self, grad_output:UOp) -> UOp: - return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype) - -class Reshape(Function): - def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp: - self.input_shape = x.shape - return x.reshape(shape) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.reshape(self.input_shape) - -class Permute(Function): - def forward(self, x:UOp, order:tuple[int, ...]) -> UOp: - self.input_order = order - return x.permute(order) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.permute(argsort(self.input_order)) - -class Pad(Function): - def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp: - self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) - return x.pad(arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.shrink(self.narg) - -class Shrink(Function): - def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp: - self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) - return x.shrink(arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.pad(self.narg) - -class Flip(Function): - def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp: - self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))]) - return x.stride(self.arg) - - def backward(self, grad_output:UOp) -> UOp: return grad_output.stride(self.arg) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 756a9c9785..86c9f0fa63 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -1,7 +1,7 @@ from typing import cast, Iterator -import math, functools +import math, functools, dataclasses from tinygrad.dtype import dtypes, sum_acc_dtype -from tinygrad.ops import UOp, PatternMatcher, UPat, Ops +from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata from tinygrad.helpers import argsort def reduce_gradient(ctx:UOp, ret:UOp): @@ -28,6 +28,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))), (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient), (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)), + (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)), (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)), (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)), (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)), @@ -36,7 +37,7 @@ pm_gradient = PatternMatcher([ # TODO: this cast can be removed by putting the casts around the EXPAND (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret: (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)), - + (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src), # there's no gradient for bitcast (UPat(Ops.BITCAST), lambda ctx: (None,)), ]) @@ -65,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp if v is None: continue if k in grads: grads[k] = grads[k] + v else: grads[k] = v + if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True) return grads diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 090b9178cc..d1eedb5eb5 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -109,6 +109,7 @@ USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0) SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) +CACHELEVEL = ContextVar("CACHELEVEL", 2) @dataclass(frozen=True) class Metadata: @@ -165,7 +166,6 @@ class Profiling(contextlib.ContextDecorator): cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad") CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db"))) -CACHELEVEL = getenv("CACHELEVEL", 2) VERSION = 17 _db_connection = None @@ -186,7 +186,7 @@ def diskcache_clear(): cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"])) def diskcache_get(table:str, key:Union[dict, str, int]) -> Any: - if CACHELEVEL == 0: return None + if CACHELEVEL < 1: return None if isinstance(key, (str,int)): key = {"key": key} conn = db_connection() cur = conn.cursor() @@ -199,7 +199,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any: _db_tables = set() def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False): - if CACHELEVEL == 0: return val + if CACHELEVEL < 1: return val if isinstance(key, (str,int)): key = {"key": key} conn = db_connection() cur = conn.cursor() diff --git a/tinygrad/multi.py b/tinygrad/multi.py index e6e165bb25..c6741b0c98 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -1,8 +1,7 @@ from __future__ import annotations import functools, itertools, operator from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv -from tinygrad.dtype import DType -from tinygrad.ops import Ops, MathTrait, UOp, sint +from tinygrad.ops import Ops, UOp, sint def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]: assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}" @@ -40,133 +39,127 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}") return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))] -class MultiLazyBuffer(MathTrait): - def __init__(self, lbs:list[UOp], axis:int|None, real:list[bool]|None=None): - assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them" - assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}" - self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs) +# ***** multi functions ***** - @property - def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)) +from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites - @property - def size(self): return sum(x.size for x in self.real_lbs) +def alu_multi(root:UOp): + msrcs = root.src + assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}" + assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" - @property - def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r] + # NOTE: they all have to share an axis, we always choose [-1] + axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) + srcs:list[list[UOp]] = [] + not_all_real = not all(all(mlb.real) for mlb in msrcs) + new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real + for mlb in msrcs: + if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src)) + else: + assert axis is not None and bounds is not None + if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds)) + else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds)) + new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)] + new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed? + return UOp.multi(*new_lbs, axis=axis, real=new_real) - @property - def bounds(self): - if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") - return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0))) +def reduce_multi(root:UOp, multi:UOp): + op, axis = root.arg + if multi.axis is not None and multi.axis in axis: + # all-reduce on sharded axes + reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)] + # if all partitions are real, do all_reduce + if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None) + # only one partition is real, keep it + return UOp.multi(*reduced_parts, axis=None, real=multi.real) + # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct + return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real) - def __repr__(self): return f"" +def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: + return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape)) - def copy_to_device(self, device:str) -> UOp: - # if we already have a copy on the device, return that - if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device)) - # copy lbs to device, pad to final shape, and sum - llbs:list[UOp] = [] - for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds): - if not real: continue - pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape))) - llbs.append(lb.copy_to_device(device).pad(pad_arg)) - return functools.reduce(operator.add, llbs) +def reshape_multi(root:UOp, multi:UOp): + arg = root.arg + if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real) + assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)" + arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) + # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards + # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? + new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1 + assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \ + f"reshape cannot move items between shards {multi.shape} -> {root.arg=}" + lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src] + return UOp.multi(*lbs, axis=new_axis, real=multi.real) - # passthroughs - @property - def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs) - def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real) - def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real) - def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real) - def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real) - def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real) - def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real) - def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real) - @property - def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort} +def expand_multi(root:UOp, multi:UOp): + # NOTE: this assert isn't needed, sharded axis can have dim 1 + assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}" + return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real) - # elementwise is simple - def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer: - msrcs = (self,)+in_srcs - assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}" - assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}" +def pad_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}" + # pad on shard axis -> fill others with zeros and set real to all True + if multi.axis is not None and root.arg[multi.axis] != (0,0): + # pad back to whole axis, remove real mask + assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time" + dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)] + assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" + return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis) + return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - # NOTE: they all have to share an axis, we always choose [-1] - axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None) - srcs:list[list[UOp]] = [] - not_all_real = not all(all(mlb.real) for mlb in msrcs) - new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real - assert any(new_real), "output contains no real lb" - for mlb in msrcs: - if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs) - else: - assert axis is not None and bounds is not None - if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds)) - else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds)) - new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r} - # NOTE: const dtype should match real - new_dtype = next(iter(new_real_lbs.values())).dtype - return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real) +def permute_multi(root:UOp, multi:UOp): + # all permutes supported! + return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real) - def r(self, op:Ops, axis:tuple[int, ...]) -> MultiLazyBuffer: - if self.axis is not None and self.axis in axis: - # all-reduce on sharded axes - reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)] - # if all partitions are real, do all_reduce - if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None) - # only one partition is real, keep it - return MultiLazyBuffer(reduced_parts, None, self.real) - # reduce on non sharded axes, piecewise is fine. if axis is None this is also correct - return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real) +def shrink_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \ + f"shrinking not supported for {root.arg=}" + if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]): + assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \ + "cannot shrink sharded and non-sharded axis at the same time" + # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real + idx = multi.bounds.index(root.arg[multi.axis]) + # zero out other lbs to not create lb reference + return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)], + axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src)))) + return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src], + axis=multi.axis, real=multi.real) - # *** movement ops *** +def stride_multi(root:UOp, multi:UOp): + assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis" + return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real) - def _shape_to_single_shard(self, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]: - return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape)) +def copy_multi(multi:UOp, device:UOp): + # if we already have a copy on the device, return that + if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg)) + # copy lbs to device, pad to final shape, and sum + llbs:list[UOp] = [] + for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds): + if not real: continue + pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape))) + llbs.append(lb.copy_to_device(device.arg).pad(pad_arg)) + return functools.reduce(operator.add, llbs) - def reshape(self, arg:tuple[sint, ...]): - if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real) - assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)" - arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1)) - # new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards - # todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1? - new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1 - assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}" - lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs] - return MultiLazyBuffer(lbs, new_axis, self.real) +def assign_multi(dest:UOp, src:UOp): + assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}" + return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real) - def pad(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}" - # pad on shard axis -> fill others with zeros and set real to all True - if self.axis is not None and arg[self.axis] != (0,0): - # pad back to whole axis, remove real mask - assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time" - dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)] - assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis" - return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis) - return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real) +def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real) - def expand(self, arg:tuple[sint, ...]): - # NOTE: this assert isn't needed, sharded axis can have dim 1 - assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}" - return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real) +# NOTE: this is the same pattern as Ops.UNROLL +multi_pm = PatternMatcher([ + (UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi), + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi), + (UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi), + (UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi), + (UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi), + (UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi), + (UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi), + (UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi), + (UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi), + (UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi), + (UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi), +]) - def permute(self, arg:tuple[int, ...]): - # all permutes supported! - return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real) - - def shrink(self, arg:tuple[tuple[sint, sint], ...]): - assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}" - if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]): - assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time" - # NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real - idx = self.bounds.index(arg[self.axis]) - # zero out other lbs to not create lb reference - return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))]) - return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs], - self.axis, self.real) - - def stride(self, arg:tuple[int, ...]): - assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis" - return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real) +@track_rewrites(named=True) +def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v} diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index db1d84b345..b7cb9f4359 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -77,9 +77,6 @@ class LARS(Optimizer): def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]: for i, (t, g) in enumerate(zip(self.params, grads)): - # contiguous is needed since the grads can allegedly form a "diamond" - # TODO: fix this in lazy.py - g = g.contiguous() if self.tcoef != 0: r1 = t.detach().square().sum().sqrt() r2 = g.square().sum().sqrt() diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 9400dd2e82..99032cfe63 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T from tinygrad.shape.view import strides_for_shape -from tinygrad.multi import MultiLazyBuffer class TensorIO(io.RawIOBase, BinaryIO): def __init__(self, t: Tensor): @@ -152,9 +151,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr continue if v.shape != state_dict[k].shape: raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.') - if isinstance((mlb:=v.lazydata), MultiLazyBuffer): - if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize() - else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize() + if isinstance(v.device, tuple): + if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize() + else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize() else: v.replace(state_dict[k].to(v.device)).realize() if consume: del state_dict[k] diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 0225603cc4..731244918b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait): # the order of these Ops controls the order of the toposort class Ops(FastEnum): # uops that aren't rendered - SINK = auto(); CONTIGUOUS = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 + SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702 # TODO: empty continues to exist because of tensor EMPTY = auto() @@ -150,6 +150,7 @@ class Ops(FastEnum): # device DEVICE = auto() + MULTI = auto() class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} @@ -281,6 +282,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @functools.cached_property def st(self) -> ShapeTracker|None: + from tinygrad.shape.shapetracker import ShapeTracker + if self.op is Ops.MULTI: + return ShapeTracker.from_shape( + tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))) # these ops define a ShapeTracker from the arg if self.op is Ops.VIEW: return self.arg if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg) @@ -294,7 +299,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # only reduce ops are allowed to change shape, everything else derives shape from sources elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg) else: shape = src_sts[0].shape - from tinygrad.shape.shapetracker import ShapeTracker return ShapeTracker.from_shape(shape) @functools.cached_property @@ -350,7 +354,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) def const_like(self, b:ConstLike): # constants can optionally have a DEVICE source - return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) + if self._device is None: return UOp.const(self.dtype, b) + if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None) + return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b) def broadcast(self, count:int): assert self.dtype.count == 1 if count == 1: return self @@ -389,7 +395,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): new_shape = unwrap(self.st).reduce(axis) # TODO: can we split symbolic shape if the reduce axis is not symbolic? - if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \ + # TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi + if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \ prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, axis) @@ -409,6 +416,46 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) def contiguous(self): return self.alu(Ops.CONTIGUOUS) + def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD) + + # *** from MultiLazyBuffer *** + + def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None): + parents = (self,)+more + assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype" + return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents))) + + @property + def bounds(self): + if self.axis is None: raise RuntimeError("bounds is not defined when axis is None") + return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0))) + + @property + def axis(self): + assert self.op is Ops.MULTI + return self.arg[0] + + @property + def real(self): + assert self.op is Ops.MULTI + return self.arg[1] + + @property + def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r] + + def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp: + if axis is None: lbs = [self] * len(devices) + else: + if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") + sz = self.shape[axis] // len(devices) + sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] + lbs, off = [], 0 + for sz in sizes: + lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape)))) + off += sz + sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] + # NOTE: this contiguous is making it impossible for the scheduler to do late const folding + return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis) # *** from LazyBuffer *** @@ -426,7 +473,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val) # otherwise it's just a VIEW(BUFFER) return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st) - def copy_to_device(self, device:str, clone:bool=False) -> UOp: + def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp: # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device) # COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st) @@ -440,8 +487,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True) @property - def lbs(self): return [self] - @property def metadata(self): return all_metadata.get(self, None) # *** uop movement ops *** @@ -470,10 +515,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): @staticmethod def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size)) @property - def device(self) -> str: return unwrap(self._device) + def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device)) @functools.cached_property - def _device(self) -> Optional[str]: + def _device(self) -> Optional[str|tuple[str, ...]]: if self.op is Ops.DEVICE: return self.arg + if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src) return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None @property def buf_uop(self) -> UOp: @@ -489,6 +535,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}" if (cret:=buffers.get(self)) is not None: return cret from tinygrad.device import Buffer + assert isinstance(self.device, str), f"buffer not supported on multi {self.device}" buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base) return ret @property @@ -496,7 +543,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized return self.buffer if self.op is Ops.BUFFER else None @property - def is_realized(self) -> bool: return self.base.realized is not None + def is_realized(self) -> bool: + return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None # *** uop Variable stuff *** @@ -639,7 +687,7 @@ def print_uops(uops:list[UOp]): def get_location() -> tuple[str, int]: frm = sys._getframe(1) # find the real frame in the file that has the UPat, TODO: is there a better way to do this? - while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", + while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py", "lowerer.py", "cstyle.py", "linearize.py"}: frm = frm.f_back return frm.f_code.co_filename, frm.f_lineno @@ -775,9 +823,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0) match_stats:dict[UPat, list[Union[int, float]]] = dict() @dataclass(frozen=True) class TrackedGraphRewrite: - loc: tuple[str, int] # location that called graph_rewrite - sink: UOp # the sink input to graph_rewrite - matches: list[tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # before+after of all the matches + loc: tuple[str, int] # location that called graph_rewrite + sink: UOp # the sink input to graph_rewrite + matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches tracked_keys:list[Any] = [] tracked_ctxs:list[list[TrackedGraphRewrite]] = [] _name_cnt:dict[str, int] = {} @@ -808,10 +856,9 @@ class TrackedPatternMatcher(PatternMatcher): match_stats[p][0] += 1 match_stats[p][3] += (et:=time.perf_counter()-st) if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable()) - if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p, et)) + if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p)) return ret # NOTE: if it returns None, we keep trying to match match_stats[p][2] += time.perf_counter()-st - if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0 and len(tracked_ctxs[-1]) != 0: tracked_ctxs[-1][-1].matches.append((uop, None, None, 0)) return None if TRACK_MATCH_STATS: @@ -849,7 +896,7 @@ class RewriteContext: self.replace: dict[UOp, UOp] = {} def top_down_rewrite(self, n:UOp) -> UOp: if (rn := self.replace.get(n)) is not None: return rn - new_src = tuple(map(self.top_down_rewrite, n.src)) + new_src = tuple([self.top_down_rewrite(x) for x in n.src]) new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg) self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n) return ret @@ -857,7 +904,7 @@ class RewriteContext: if (rn := self.replace.get(n)) is not None: return rn new_n: UOp|None = n while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx) - new_src = tuple(map(self.bottom_up_rewrite, last_n.src)) + new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src]) self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg)) return ret @@ -1269,18 +1316,24 @@ Variable = UOp ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]] -# *** uop swizzling *** +# *** UOp merge views and swizzling *** merge_views = PatternMatcher([ - (UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)), - (UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None), + # VIEW(VIEW) merges to a single VIEW + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)), + (UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None), + # merge unmasked const views + (UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)), + lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None), ]) -# push VIEW to loads +# push VIEW to parents view_left = merge_views+PatternMatcher([ - # VIEW before elementwise ops - (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), - lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))), - # early merge VIEW buffer ops - (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), + # VIEW(CONST) becomes VALID + (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)), + # VIEW before elementwise/buffer ops + (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)), + lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))), + (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)), + lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ]) diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py index 9c20c5b298..cb08f4aa2d 100644 --- a/tinygrad/runtime/ops_amd.py +++ b/tinygrad/runtime/ops_amd.py @@ -198,12 +198,12 @@ class AMDCopyQueue(HWQueue): return self def bind(self, dev:AMDDevice): - if not dev.driverless: return + if not getenv("AMD_SDMA_BIND", 0) or not dev.driverless: return self.binded_device = dev self.hw_page = dev.allocator.alloc((qsz:=round_up(len(self._q), 8)) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True)) hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I") - for i, value in enumerate(self._q): hw_view[i] = value + for i in range(qsz): hw_view[i] = self._q[i] if i < len(self._q) else 0 self.indirect_cmd = [amd_gpu.SDMA_OP_INDIRECT | amd_gpu.SDMA_PKT_INDIRECT_HEADER_VMID(0), *data64_le(self.hw_page.va_addr), qsz, *data64_le(0)] self._q, self.cmd_sizes = hw_view, [len(self.indirect_cmd)] @@ -464,7 +464,7 @@ class PCIIface: iommu_group = HWInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1] except OSError: - if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (not inserted or no-iommu mode is not supported).") + if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (run `sudo modprobe vfio-pci`).") PCIIface.vfio = False # Init vfio for the device @@ -487,21 +487,18 @@ class PCIIface: self.pagemap = HWInterface("/proc/self/pagemap", os.O_RDONLY) self.bar_fds = {bar: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]} - self.adev = AMDev(self.pcidev, self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I')) + self.adev = AMDev(self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I')) self.doorbell_cpu_addr = mv_address(dbell) - libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND) - libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND) + libpciaccess.pci_device_cfg_read_u16(self.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND) + libpciaccess.pci_device_cfg_write_u16(self.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND) # TODO: this is for 7900xtx, the only tested card. self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32, 'array_count': 12, 'simd_arrays_per_engine': 2, 'lds_size_in_kb': 64} def _map_pci_range(self, bar, off=0, addr=0, size=None): - if PCIIface.vfio: - vfio.VFIO_DEVICE_GET_REGION_INFO(self.vfio_dev, reg:=vfio.struct_vfio_region_info(argsz=ctypes.sizeof(vfio.struct_vfio_region_info), index=bar)) - fd, sz, off = self.vfio_dev, size or reg.size, reg.offset + off - else: fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size + fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size return to_mv(fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz) def alloc(self, size:int, host=False, uncached=False, cpu_access=False): diff --git a/tinygrad/runtime/support/am/amdev.py b/tinygrad/runtime/support/am/amdev.py index 37b43d9c7b..2f541d8766 100644 --- a/tinygrad/runtime/support/am/amdev.py +++ b/tinygrad/runtime/support/am/amdev.py @@ -102,19 +102,16 @@ class AMFirmware: class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702 class AMPageTableEntry: - def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv + def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv - def set_table(self, entry_id, pte:AMPageTableEntry, valid=True): - self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0) + def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True): + assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}" - def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True): - f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \ - | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if self.lv != am.AMDGPU_VM_PTB else 0) \ + f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \ + | am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \ | ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \ | (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0) - self.view[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f - - def get_entry(self, entry_id): return self.view[entry_id] + self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f class AMPageTableTraverseContext: def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False): @@ -126,22 +123,23 @@ class AMPageTableTraverseContext: def level_down(self): pt, pte_idx, _ = self.pt_stack[-1] - if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID: - assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" - child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) - else: + if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0: assert self.create_pts, "Not allowed to create new page table" - pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1)) + pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True) + entry = pt.entries[pte_idx] + + assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}" + child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1) self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table))) return self.pt_stack[-1] def _try_free_pt(self) -> bool: pt, _, _ = self.pt_stack[-1] - if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)): + if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)): self.adev.mm.pfree(pt.paddr) parent_pt, parent_pte_idx, _ = self.pt_stack[-2] - parent_pt.set_page(parent_pte_idx, 0x0, valid=False) + parent_pt.set_entry(parent_pte_idx, 0x0, valid=False) return True return False @@ -156,7 +154,7 @@ class AMPageTableTraverseContext: if self.create_pts: while pte_covers > size: pt, pte_idx, pte_covers = self.level_down() else: - while pt.lv!=am.AMDGPU_VM_PTB and (pt.get_entry(pte_idx)&am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down() + while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down() entries = min(size // pte_covers, 512 - pte_idx) assert entries > 0, "Invalid entries" @@ -173,7 +171,7 @@ class AMMemoryManager: self.adev, self.vram_size = adev, vram_size self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device - self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1) + self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1) def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping: assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}" @@ -181,10 +179,10 @@ class AMMemoryManager: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True) for paddr, psize in paddrs: for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize): - frag = 0 if pte_covers == 0x1000 else 0x9 for pte_off in range(pte_cnt): - assert pt.get_entry(pte_idx + pte_off) & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.get_entry(pte_idx + pte_off):#x}" - pt.set_page(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, frag=frag, valid=True) + assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}" + pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, + uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True) # Invalidate TLB after mappings. self.adev.gmc.flush_tlb(ip='GC', vmid=0) @@ -197,8 +195,8 @@ class AMMemoryManager: ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True) for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size): for pte_id in range(pte_idx, pte_idx + pte_cnt): - assert pt.get_entry(pte_id) & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.get_entry(pte_id):#x}" - pt.set_page(pte_id, paddr=0x0, valid=False) + assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}" + pt.set_entry(pte_id, paddr=0x0, valid=False) @staticmethod def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align)) @@ -233,8 +231,8 @@ class AMMemoryManager: def pfree(self, paddr:int): self.pa_allocator.free(paddr) class AMDev: - def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): - self.pcidev, self.devfmt = pcidev, devfmt + def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview): + self.devfmt = devfmt self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions @@ -258,8 +256,8 @@ class AMDev: # To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for # all blocks that are initialized only during the initial AM boot. # To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag. - self.is_booting = True # During boot only boot memory can be allocated. This flag is to validate this. - self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1) + self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this. + self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1) # Memory manager & firmware self.mm = AMMemoryManager(self, self.vram_size) diff --git a/tinygrad/runtime/support/am/ip.py b/tinygrad/runtime/support/am/ip.py index 2b08798825..b27b30579b 100644 --- a/tinygrad/runtime/support/am/ip.py +++ b/tinygrad/runtime/support/am/ip.py @@ -25,6 +25,9 @@ class AM_GMC(AM_IP): self.vm_base = self.adev.mm.va_allocator.base self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1 + # GFX11 has 44-bit address space + self.address_space_mask = (1 << 44) - 1 + self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True) self.hub_initted = {"MM": False, "GC": False} @@ -99,7 +102,13 @@ class AM_GMC(AM_IP): if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}") class AM_SMU(AM_IP): + def __init__(self, adev): + super().__init__(adev) + self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True) + def init(self): + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True) self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True) for clck in [0x00000C94, 0x000204E1, 0x000105DC, 0x00050B76, 0x00070B76, 0x00040898, 0x00060898, 0x000308FD]: @@ -115,6 +124,11 @@ class AM_SMU(AM_IP): self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True) time.sleep(0.5) # 500ms + def read_table(self, table_t, cmd): + self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True) + return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t))) + def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS) + def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout) def _smu_cmn_send_msg(self, msg, param=0): self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg @@ -237,8 +251,9 @@ class AM_IH(AM_IP): def __init__(self, adev): super().__init__(adev) self.ring_size = 512 << 10 - self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0), - (self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)] + def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True), + self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)) + self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)] def interrupt_handler(self): _, rwptr_vm, suf, _ = self.rings[0] @@ -315,6 +330,9 @@ class AM_PSP(AM_IP): self.ring_size = 0x10000 self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True) + self.max_tmr_size = 0x1300000 + self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True) + def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0 def init(self): sos_components_load_order = [ @@ -359,7 +377,7 @@ class AM_PSP(AM_IP): # Load TOC and calculate TMR size self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC]) self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size - self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True) + assert self.tmr_size <= self.max_tmr_size def _ring_create(self): # If the ring is already created, destroy it diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 4fff017555..6897a6cc68 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -115,8 +115,9 @@ class ShapeTracker: def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1] def axis_is_masked(self, axis:int) -> bool: - _, valid = self.to_indexed_uops() - return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] + with Context(TRACK_MATCH_STATS=0): + _, valid = self.to_indexed_uops() + return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 45a9b23749..1b81c74601 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,15 +1,15 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref +import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref from contextlib import ContextDecorator -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex +from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup -from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap -from tinygrad.multi import MultiLazyBuffer +from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap +from tinygrad.multi import get_multi_map from tinygrad.gradient import compute_gradient from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element -from tinygrad.device import Device, Buffer, BufferSpec +from tinygrad.device import Device, BufferSpec from tinygrad.engine.realize import run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars @@ -30,45 +30,23 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None: # link the found UOps back to Tensors. exit early if there's no Tensors to realize # NOTE: this uses all_tensors, but it's fast - fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)] + fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops] if len(fixed_tensors): # potentially rewrite all the discovered Tensors - sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors]) + sink = UOp.sink(*[t.lazydata for t in fixed_tensors]) new_sink = sink.substitute(applied_map) # set the relevant lazydata to the realized UOps for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src): if s is ns: continue - if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src) - else: t.lazydata = ns + t.lazydata = ns -# **** start with two base classes, Tensor and Function **** - -class Function: - def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None): - self.device = device - self.needs_input_grad = [t.requires_grad for t in tensors] - self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False - if self.requires_grad: self.parents = tensors - self.metadata = metadata - - def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}") - def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}") - - @classmethod - def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: - ctx = fxn(x[0].device, *x, metadata=_METADATA.get()) - ret = Tensor.__new__(Tensor) - ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None - ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine - return ret - -import tinygrad.function as F +# **** Tensor helper functions **** def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None): if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg) - return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None) + return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None) def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821 import numpy as np @@ -148,8 +126,7 @@ class Tensor(SimpleMathTrait): np.set_printoptions(precision=4) ``` """ - __slots__ = "lazydata", "requires_grad", "grad", "_ctx" - __deletable__ = ('_ctx',) + __slots__ = "lazydata", "requires_grad", "grad" training: ClassVar[bool] = False no_grad: ClassVar[bool] = False @@ -159,7 +136,7 @@ class Tensor(SimpleMathTrait): return instance def __del__(self): all_tensors.discard(weakref.ref(self)) - def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 + def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None): if dtype is not None: dtype = to_dtype(dtype) if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None @@ -172,11 +149,8 @@ class Tensor(SimpleMathTrait): # None (the default) will be updated to True if it's put in an optimizer self.requires_grad: Optional[bool] = requires_grad - # internal variable used for autograd graph construction - self._ctx: Optional[Function] = None - # create a LazyBuffer from the different types of inputs - if isinstance(data, (UOp, MultiLazyBuffer)): + if isinstance(data, UOp): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" # NOTE: this is here because LazyBuffer = UOp if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data) @@ -199,12 +173,12 @@ class Tensor(SimpleMathTrait): data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a LazyBuffer - if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") + if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") # data might be on a different device - if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device) + if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device) # if device is a tuple, we should have/construct a MultiLazyBuffer - elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata + elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata else: assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" self.lazydata = data @@ -224,8 +198,8 @@ class Tensor(SimpleMathTrait): def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev def __repr__(self): - if isinstance(ld:=self.lazydata, MultiLazyBuffer): ld_repr = f"{ld!r}" - else: ld_repr = f"" + ld = self.lazydata + ld_repr = f"" return f"" # Python has a non moving GC, so this should be okay @@ -246,6 +220,17 @@ class Tensor(SimpleMathTrait): @property def dtype(self) -> DType: return self.lazydata.dtype + def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor: + ret = Tensor.__new__(Tensor) + needs_input_grad = [t.requires_grad for t in (self,)+x] + ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None + ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs) + return ret + + def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor: + lhs,rhs = self._broadcasted(x, reverse) + return lhs._apply_uop(fxn, rhs) + # ***** data handlers **** def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]: @@ -254,7 +239,14 @@ class Tensor(SimpleMathTrait): NOTE: A Tensor can only be scheduled once. """ - big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst])) + big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst]) + + # TODO: move this to scheduler tensor_map pass + if any(x.op is Ops.MULTI for x in big_sink.toposort): + # multi fixup + _apply_map_to_tensors(get_multi_map(big_sink)) + big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst])) + schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink) _apply_map_to_tensors(becomes_map) return memory_planner(schedule), var_vals @@ -275,7 +267,6 @@ class Tensor(SimpleMathTrait): Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match. """ # used for replacing a Tensor with a new version of it (potentially with a different device and dtype) - assert getattr(self, '_ctx', None) is None assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}" self.lazydata = x.lazydata return self @@ -287,13 +278,11 @@ class Tensor(SimpleMathTrait): self.contiguous().realize().lazydata.base.realized.copyin(x._data()) return self if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) - if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}") if self.lazydata is x.lazydata: return self # a self assign is a NOOP # NOTE: we allow cross device assign assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}" assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" - assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer" assert not x.requires_grad # self requires_grad is okay? if not self.lazydata.is_realized: return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) @@ -309,7 +298,8 @@ class Tensor(SimpleMathTrait): if 0 in self.shape: return memoryview(bytearray(0)) # NOTE: this realizes on the object from as_buffer being a Python object cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize() - buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized) + buf = cast(UOp, cpu.lazydata).base.realized + assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized" if self.device != "CLANG": buf.options = BufferSpec(nolru=True) return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False) @@ -373,7 +363,6 @@ class Tensor(SimpleMathTrait): """ ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.clone() - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor: @@ -385,7 +374,6 @@ class Tensor(SimpleMathTrait): if not isinstance(device, str): return self.shard(device) ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) if self.grad is not None: ret.grad = self.grad.to(device) - if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret def to_(self, device:Optional[Union[str, tuple[str, ...]]]): @@ -405,18 +393,9 @@ class Tensor(SimpleMathTrait): print(t.shard((t.device, t.device), axis=1).lazydata) ``` """ - assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer" + assert isinstance(self.device, str), "can't shard a MultiLazyBuffer" devices = tuple(Device.canonicalize(x) for x in devices) - if axis is None: lbs = [self.lazydata] * len(devices) - else: - axis = self._resolve_dim(axis) - if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}") - sz = self.shape[axis] // len(devices) - sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))] - lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)] - sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)] - # NOTE: this contiguous is making it impossible for the scheduler to do late const folding - mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis) + mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None) return Tensor(mlb, device=devices, requires_grad=self.requires_grad) def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None): @@ -439,7 +418,7 @@ class Tensor(SimpleMathTrait): def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs): dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float if isinstance(device, tuple): - return Tensor(MultiLazyBuffer([UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None), + return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None), device, dtype, **kwargs) return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs) @@ -510,7 +489,7 @@ class Tensor(SimpleMathTrait): @staticmethod def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor): x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) - x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) + x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) return counts0.cat(counts1) @@ -750,12 +729,12 @@ class Tensor(SimpleMathTrait): ``` """ dtype = kwargs.pop("dtype", self.dtype) - if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer): + if isinstance(self.device, tuple): if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor") if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device) contiguous = kwargs.pop("contiguous", True) - rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs] - return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) + rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src] + return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs) return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs) # ***** rng hlops ***** @@ -904,7 +883,7 @@ class Tensor(SimpleMathTrait): # ***** toposort and backward pass ***** - def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]: + def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]: """ Compute the gradient of the targets with respect to self. @@ -921,29 +900,17 @@ class Tensor(SimpleMathTrait): assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) rets = [] - for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)): - target_uops = [x.lazydata.lbs[i] for x in targets] - grads = compute_gradient(uop, grad, set(target_uops)) - ret = [] - for x in target_uops: - if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}") - ret.append(y) - rets.append(ret) + target_uops = [x.lazydata for x in targets] + grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops)) + ret = [] + for x in target_uops: + if (y:=grads.get(x)) is None: + if materialize_grads: y = x.const_like(0) + else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}") + ret.append(y) + rets.append(ret) # create returned Tensors - if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] - return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real), - device=t.device) for t,u in zip(targets, zip(*rets))] - - def _deepwalk(self) -> list[Tensor]: - def _walk(node:Tensor, visited:set[Tensor]): - visited.add(node) - # if tensor is not leaf, reset grad - if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None - if ctx: - for i in cast(Function, node._ctx).parents: - if i not in visited: yield from _walk(i, visited) - yield node - return list(_walk(self, set())) + return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])] def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor: """ @@ -956,31 +923,13 @@ class Tensor(SimpleMathTrait): print(t.grad.numpy()) ``` """ - toposorted = self._deepwalk() - if gradient is None: - assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor" - # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous - # this is "implicit gradient creation" - gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False) - - toposort_uop = self.lazydata.toposort - assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}" - self.grad = gradient - for t0 in reversed(toposorted): - if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad") - ctx = cast(Function, t0._ctx) - token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None) - grads = ctx.backward(t0.grad.lazydata) - _METADATA.reset(token) - grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None - for g in ([grads] if len(ctx.parents) == 1 else grads)] - for t, g in zip(ctx.parents, grads): - if g is not None and t.requires_grad: - assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" - assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \ - f"grad uop must have a path from self\ngrad uop: {t.lazydata}" - t.grad = g if t.grad is None else (t.grad + g) - if not retain_graph: del t0._ctx + all_uops = self.lazydata.toposort + tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \ + t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad] + # clear contexts + for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)): + assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}" + t.grad = g if t.grad is None else (t.grad + g) return self # ***** movement low level ops ***** @@ -1004,7 +953,7 @@ class Tensor(SimpleMathTrait): # resolve -1 if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) - return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self + return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self def expand(self, shape, *args) -> Tensor: """ @@ -1037,7 +986,7 @@ class Tensor(SimpleMathTrait): """ order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") - return F.Permute.apply(self, order=order_arg) + return self._apply_uop(UOp.permute, arg=order_arg) def flip(self, axis, *args) -> Tensor: """ @@ -1057,7 +1006,7 @@ class Tensor(SimpleMathTrait): """ axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") - return F.Flip.apply(self, axis=axis_arg) + return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))])) def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor: """ @@ -1077,7 +1026,7 @@ class Tensor(SimpleMathTrait): ``` """ if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self - return F.Shrink.apply(self, arg=tuple(shrink_arg)) + return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg)) def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor: """ @@ -1121,7 +1070,8 @@ class Tensor(SimpleMathTrait): if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}") X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX) if mode == "constant": - def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v) + def _constant(x:Tensor,px,v): + return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v)) return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \ _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value) assert all_int(self.shape), f"does not support symbolic shape {self.shape}" @@ -1278,7 +1228,7 @@ class Tensor(SimpleMathTrait): self._getitem(indices).assign(v) return # NOTE: check that setitem target is valid first - if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous") + if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous") if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported") @@ -1611,10 +1561,10 @@ class Tensor(SimpleMathTrait): # ***** reduce ops ***** - def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: + def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) if self.ndim == 0: axis = () - ret = fxn.apply(self, axis=axis) + ret = self._apply_uop(UOp.r, op=op, axis=axis) return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): @@ -1641,7 +1591,7 @@ class Tensor(SimpleMathTrait): print(t.sum(axis=1).numpy()) ``` """ - ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim) + ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim) return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): @@ -1668,7 +1618,7 @@ class Tensor(SimpleMathTrait): print(t.prod(axis=1).numpy()) ``` """ - return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim) + return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim) def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ @@ -1691,7 +1641,7 @@ class Tensor(SimpleMathTrait): print(t.max(axis=1, keepdim=True).numpy()) ``` """ - return self._reduce(F.Max, axis, keepdim) + return self._reduce(Ops.MAX, axis, keepdim) def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not() @@ -2528,7 +2478,7 @@ class Tensor(SimpleMathTrait): print(Tensor([False, True]).logical_not().numpy()) ``` """ - return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True)) + return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True) def neg(self): """ Negates the tensor element-wise. @@ -2542,12 +2492,12 @@ class Tensor(SimpleMathTrait): """ Returns a contiguous tensor. """ - return F.Contiguous.apply(self) + return self._apply_uop(UOp.contiguous) def contiguous_backward(self): """ Inserts a contiguous operation in the backward pass. """ - return F.ContiguousBackward.apply(self) + return self._apply_uop(UOp.contiguous_backward) def log(self): """ Computes the natural logarithm element-wise. @@ -2558,7 +2508,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 4., 8.]).log().numpy()) ``` """ - return F.Log.apply(self.cast(least_upper_float(self.dtype))) + return self.log2()*math.log(2) def log2(self): """ Computes the base-2 logarithm element-wise. @@ -2569,7 +2519,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 4., 8.]).log2().numpy()) ``` """ - return self.log()/math.log(2) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2) def exp(self): """ Computes the exponential function element-wise. @@ -2580,7 +2530,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., 1., 2., 3.]).exp().numpy()) ``` """ - return F.Exp.apply(self.cast(least_upper_float(self.dtype))) + return self.mul(1/math.log(2)).exp2() def exp2(self): """ Computes the base-2 exponential function element-wise. @@ -2591,8 +2541,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., 1., 2., 3.]).exp2().numpy()) ``` """ - return F.Exp.apply(self*math.log(2)) - + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2) def relu(self): """ Applies the Rectified Linear Unit (ReLU) function element-wise. @@ -2603,7 +2552,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy()) ``` """ - return F.Relu.apply(self) + return (self>0).where(self, 0) def sigmoid(self): """ @@ -2639,7 +2588,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 3., 4.]).sqrt().numpy()) ``` """ - return F.Sqrt.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt) def rsqrt(self): """ Computes the reciprocal of the square root of the tensor element-wise. @@ -2657,7 +2606,7 @@ class Tensor(SimpleMathTrait): print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy()) ``` """ - return F.Sin.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin) def cos(self): """ Computes the cosine of the tensor element-wise. @@ -2816,7 +2765,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy()) ``` """ - return F.Sign.apply(self) + return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0 def abs(self): """ Computes the absolute value of the tensor element-wise. @@ -2834,7 +2783,7 @@ class Tensor(SimpleMathTrait): print(Tensor([1., 2., 3., 4.]).reciprocal().numpy()) ``` """ - return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype))) + return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal) # ***** activation functions ***** @@ -3112,7 +3061,7 @@ class Tensor(SimpleMathTrait): # for each dimension, check either dim is 1, or it does not change if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)): raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}") - return F.Expand.apply(self.reshape(shape), shape=new_shape) + return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape) def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]: x: Tensor = self @@ -3156,7 +3105,7 @@ class Tensor(SimpleMathTrait): print(t.add(Tensor([[2.0], [3.5]])).numpy()) ``` """ - return F.Add.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.add, x, reverse) def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3197,7 +3146,7 @@ class Tensor(SimpleMathTrait): print(t.mul(Tensor([[-1.0], [2.0]])).numpy()) ``` """ - return F.Mul.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.mul, x, reverse) def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3210,7 +3159,7 @@ class Tensor(SimpleMathTrait): print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy()) ``` """ - return F.IDiv.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.idiv, x, reverse) def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3245,7 +3194,7 @@ class Tensor(SimpleMathTrait): ``` """ a, b = self._broadcasted(x, reverse) - return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0))) + return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0))) def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3261,7 +3210,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.Xor.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.xor, x, reverse) def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3276,7 +3225,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.BitwiseAnd.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse) def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ @@ -3291,7 +3240,7 @@ class Tensor(SimpleMathTrait): ``` """ if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported") - return F.BitwiseOr.apply(*self._broadcasted(x, reverse)) + return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse) def bitwise_not(self) -> Tensor: """ @@ -3422,7 +3371,7 @@ class Tensor(SimpleMathTrait): elif isinstance(y, Tensor): y, x = y._broadcasted(x) cond, x = self._broadcasted(x, match_dtype=False) cond, y = cond._broadcasted(y, match_dtype=False) - return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y)) + return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y)) def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self) @@ -3452,9 +3401,9 @@ class Tensor(SimpleMathTrait): def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x)) def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x)) - def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False)) - def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True)) - def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x)) + def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False) + def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True) + def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False) def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override] @@ -3800,8 +3749,8 @@ class Tensor(SimpleMathTrait): """ if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype): # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around - return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt) - return self if self.dtype == dt else F.Cast.apply(self, dtype=dt) + return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt) + return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt) def bitcast(self, dtype:DTypeLike) -> Tensor: """ @@ -3826,7 +3775,7 @@ class Tensor(SimpleMathTrait): tmp = self.bitcast(old_uint) if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype) return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype) - return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self + return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self def float(self) -> Tensor: """ @@ -4000,5 +3949,5 @@ def _metadata_wrapper(fn): if TRACEMETA >= 1: for name, fn in inspect.getmembers(Tensor, inspect.isfunction): - if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue + if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn))) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 0f94ebe655..08b7f4d00a 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -301,17 +301,17 @@ const kernelListParent = document.querySelector(".container.kernel-list-parent"); const kernelList = document.querySelector(".container.kernel-list"); kernelList.innerHTML = ""; - kernels.forEach((k, i) => { + kernels.forEach(([key, items], i) => { const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "", style: "overflow-x: auto; cursor: initial;" }); if (i === currentKernel) { requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" })); } - const p = Object.assign(document.createElement("p"), { id: `kernel-${k[0].kernel_name}`, innerText: k[0].kernel_name ?? "UNPARENTED", + const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerText: key ?? "UNPARENTED", style: "cursor: pointer;"}); kernelUl.appendChild(p) - k.forEach((u, j) => { - const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.upats.length}`, key: `uop-rewrite-${j}`, + items.forEach((u, j) => { + const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`, className: (j === currentUOp && i == currentKernel) ? "active" : "" }) if (j === currentUOp) { requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" })); @@ -460,7 +460,7 @@ event.preventDefault() currentUOp = 0; currentRewrite = 0; - currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1) + currentKernel = Math.min(kernels.length-1, currentKernel+1); return main() } } @@ -486,7 +486,7 @@ if (event.key == "ArrowDown") { event.preventDefault() currentRewrite = 0; - const totalUOps = kernels[currentKernel].length-1; + const totalUOps = kernels[currentKernel][1].length-1; currentUOp = Math.min(totalUOps, currentUOp+1) main() } diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index c95fb5f800..02dbee0a5e 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -2,8 +2,7 @@ import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal from http.server import HTTPServer, BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse -from dataclasses import asdict, dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable, TypedDict from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp from tinygrad.codegen.kernel import Kernel @@ -13,54 +12,30 @@ from tinygrad.dtype import dtypes uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff", **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80", Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"} -# ** API spec +# VIZ API -@dataclass -class GraphRewriteMetadata: - """Overview of a tracked rewrite to viz the sidebar""" - loc: tuple[str, int] - """File_path, Lineno""" - code_line: str - """The Python line calling graph_rewrite""" - kernel_name: str - """The kernel calling graph_rewrite""" - upats: list[tuple[tuple[str, int], str, float]] - """List of all the applied UPats""" +class GraphRewriteMetadata(TypedDict): + loc: tuple[str, int] # [path, lineno] calling graph_rewrite + match_count: int # total match count in this context -@dataclass class GraphRewriteDetails(GraphRewriteMetadata): - """Full details about a single call to graph_rewrite""" - uops: list[UOp] - graphs: list[dict] - """Sink at every step of graph_rewrite + the json serialized version""" - diffs: list[list[str]] - """.diff style before and after of the rewritten UOp child""" - changed_nodes: list[list[int]] - """Nodes that changed at every step of graph_rewrite""" - kernel_code: Optional[str] - """The program after all rewrites""" - -# ** API functions + graphs: list[dict] # JSON serialized UOp at every rewrite step + uops: list[str] # strigified UOp at every rewrite step + diffs: list[list[str]] # string diff of the single UOp that changed + changed_nodes: list[list[int]] # the changed UOp id + all its parents ids + code_line: str # source code calling graph_rewrite + kernel_code: str|None # optionally render the final kernel code + upats: list[tuple[tuple[str, int], str]] # NOTE: if any extra rendering in VIZ fails, we don't crash def pcall(fxn:Callable[..., str], *args, **kwargs) -> str: try: return fxn(*args, **kwargs) except Exception as e: return f"ERROR: {e}" -def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]: - kernels: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {} - for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"): - name = to_function_name(k.name) if isinstance(k, Kernel) else str(k) - for ctx in ctxs: - if ctx.sink.op is Ops.CONST: continue - upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None] - kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats))) - return list(kernels.values()) - def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]: assert isinstance(x, UOp) graph: dict[int, tuple[str, str, list[int], str, str]] = {} @@ -80,27 +55,27 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]: else: label += f"\n{x.op.name}{idx} {x.arg}" graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x not in excluded], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph + +def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]: + return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k), + [{"loc": v.loc, "match_count": len(v.matches)} for v in vals]) for k,vals in zip(keys, contexts)] + @functools.lru_cache(None) def _prg(k:Kernel): return k.to_program().src def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails: - g = GraphRewriteDetails(**asdict(metadata), uops=[ctx.sink], diffs=[], changed_nodes=[], - kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[]) + ret:GraphRewriteDetails = {"uops":[pcall(str, sink:=ctx.sink)], "graphs":[uop_to_json(sink)], "code_line":lines(ctx.loc[0])[ctx.loc[1]-1].strip(), + "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "diffs":[], "upats":[], "changed_nodes":[], **metadata} replaces: dict[UOp, UOp] = {} - g.graphs.append(uop_to_json(sink:=g.uops[0])) - for i,(u0,u1,upat,_) in enumerate(tqdm(ctx.matches)): - # if the match didn't result in a rewrite we move forward - if u1 is None: continue + for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches)): replaces[u0] = u1 - # first, rewrite this UOp with the current rewrite + all the matches in replaces new_sink = sink.substitute(replaces) - # sanity check - if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}") - # update ret data - g.graphs.append(new_sink_js:=uop_to_json(new_sink)) - g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js]) - g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) - g.uops.append(sink:=new_sink) - return g + ret["graphs"].append(new_sink_js:=uop_to_json(new_sink)) + ret["changed_nodes"].append([id(x) for x in u1.toposort if id(x) in new_sink_js]) + ret["diffs"].append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines()))) + ret["upats"].append((upat.location, upat.printable())) + # TODO: this is O(n^2)! + ret["uops"].append(str(sink:=new_sink)) + return ret # Profiler API devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {} @@ -150,13 +125,9 @@ class Handler(BaseHTTPRequestHandler): elif url.path == "/kernels": query = parse_qs(url.query) if (qkernel:=query.get("kernel")) is not None: - g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])]) - # TODO: this is O(n^2)! - uops_strs = [pcall(str,x) for x in tqdm(g.uops)] - # NOTE: don't use asdict because it's reserializing the uops - jret: Any = {"loc": g.loc, "code_line": g.code_line, "kernel_name": g.kernel_name, "upats": g.upats, - "uops": uops_strs, "graphs": g.graphs, "diffs": g.diffs, "changed_nodes": g.changed_nodes, "kernel_code": g.kernel_code} - else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels] + kidx, ridx = int(qkernel[0]), int(query["idx"][0]) + jret:Any = get_details(contexts[0][kidx], contexts[1][kidx][ridx], kernels[int(qkernel[0])][1][int(query["idx"][0])]) + else: jret = kernels ret, content_type = json.dumps(jret).encode(), "application/json" elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json" else: status_code = 404 @@ -198,10 +169,11 @@ if __name__ == "__main__": contexts, profile = load_pickle(args.kernels), load_pickle(args.profile) + # NOTE: this context is a tuple of list[keys] and list[values] kernels = get_metadata(*contexts) if contexts is not None else [] if getenv("FUZZ_VIZ"): - ret = [get_details(*args) for v in tqdm(kernels) for args in v] + ret = [get_details(contexts[0][i], contexts[1][i][j], args) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])] print(f"fuzzed {len(ret)} rewrite details") perfetto_profile = to_perfetto(profile) if profile is not None else None