mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00

Merge branch 'master' into retinanet_mlperf
14 .github/workflows/benchmark.yml vendored
@@ -299,6 +299,10 @@ jobs:
      run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
    - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
      run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
    - name: Run 10 MLPerf Bert training steps (6 gpu)
      # TODO: remove DISABLE_DROPOUT once dropout is fixed
      # TODO: remove BERT_LAYERS once scheduler is fast
      run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
    - uses: actions/upload-artifact@v4
      with:
        name: Speed (NVIDIA Training)
@@ -309,9 +313,10 @@ jobs:
          train_cifar_bf16.txt
          train_cifar_wino.txt
          train_cifar_one_gpu.txt
          train_cifar_six_gpu.txt
          train_resnet.txt
          train_resnet_one_gpu.txt
          train_cifar_six_gpu.txt
          train_bert.txt
    - name: Run process replay tests
      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

@@ -492,6 +497,10 @@ jobs:
      run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
    - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
      run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
    - name: Run 10 MLPerf Bert training steps (6 gpu)
      # TODO: remove DISABLE_DROPOUT once dropout is fixed
      # TODO: remove BERT_LAYERS once scheduler is fast
      run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
    - uses: actions/upload-artifact@v4
      with:
        name: Speed (AMD Training)
@@ -502,9 +511,10 @@ jobs:
          train_cifar_bf16.txt
          train_cifar_wino.txt
          train_cifar_one_gpu.txt
          train_cifar_six_gpu.txt
          train_resnet.txt
          train_resnet_one_gpu.txt
          train_cifar_six_gpu.txt
          train_bert.txt
    - name: Run process replay tests
      run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

38 .github/workflows/test.yml vendored
@@ -619,6 +619,44 @@ jobs:
      if: matrix.backend=='amd'
      run: python -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20

  wintests:
    strategy:
      fail-fast: false
      matrix:
        backend: [llvm]

    name: Tests on Windows (${{ matrix.backend }})
    runs-on: windows-latest
    timeout-minutes: 45
    steps:
    - name: Checkout Code
      uses: actions/checkout@v4
      with:
        fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR
    - name: Set up Python 3.12
      uses: actions/setup-python@v5
      with:
        python-version: 3.12
    - name: Cache python packages
      uses: actions/cache@v4
      with:
        path: ${{ env.Python3_ROOT_DIR }}\Lib\site-packages
        key: windows-${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
    - name: Install dependencies
      run: pip install --user -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
    - name: Check Device.DEFAULT and print some source
      env:
        DEBUG: 5
        LLVM: 1
        PYTHONPATH: ${{ github.workspace }}
      run: |
        python3 test/test_ops.py TestOps.test_add
    - name: Run pytest
      env:
        DEBUG: 5
        LLVM: 1
      run: python -m pytest -n=auto test/test_tiny.py --durations=20

  #testunicorn:
  #  name: ARM64 unicorn Test
  #  runs-on: ubuntu-latest

@@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not

## Frontend

Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md).
Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).

The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There are two types of UOps, base and view: a base computes into a contiguous buffer, and a view is a view of one (specified by a ShapeTracker). Inputs to a base can be either base or view; inputs to a view can only be a single base.
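To make the base/view split concrete, here is a minimal sketch (an illustrative snippet, not part of this diff; it assumes Tensor.lazydata exposes the underlying UOp, as it does at this point in the codebase):

from tinygrad import Tensor
a, b = Tensor.ones(4, 4), Tensor.ones(4, 4)
c = (a + b).permute(1, 0)  # the add is a base UOp: compute into a contiguous buffer;
                           # the permute is a view UOp (a ShapeTracker) whose only input is that base
print(c.lazydata)          # inspect the UOp graph; nothing has been realized yet
c.realize()                # only now is the base actually scheduled and computed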
@@ -1,33 +0,0 @@
::: tinygrad.function
    options:
        members: [
            "Contiguous",
            "ContiguousBackward",
            "Cast",
            "Neg",
            "Reciprocal",
            "Sin",
            "Relu",
            "Log",
            "Exp",
            "Sqrt",
            "Sigmoid",
            "Sign",
            "Less",
            "Eq",
            "Xor",
            "Add",
            "Sub",
            "Mul",
            "Div",
            "Where",
            "Sum",
            "Max",
            "Expand",
            "Reshape",
            "Permute",
            "Pad",
            "Shrink",
            "Flip",
        ]
        show_source: false
@@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from tinygrad.multi import MultiLazyBuffer

cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
@@ -35,8 +34,6 @@ class UnsyncedBatchNorm:
    self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)

  def __call__(self, x:Tensor):
    if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices

    xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
    batch_mean, batch_invstd = self.calc_stats(xr)
    ret = xr.batchnorm(

@@ -247,11 +247,11 @@ if __name__ == "__main__":
      fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
      args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")
    elif args.size == "70B":
      subdir = "Llama-3.1-Nemotron-70B-Instruct-HF"
      args.model = fetch("https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir)
      subdir = "DeepSeek-R1-Distill-Llama-70B"
      args.model = fetch("https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir)
      fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir)
      for i in range(30):
        fetch(f"https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model-{i+1:05d}-of-00030.safetensors?download=true", f"model-{i+1:05d}-of-00030.safetensors", subdir=subdir)
      for i in range(17):
        fetch(f"https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model-{i+1:05d}-of-000017.safetensors?download=true", f"model-{i+1:05d}-of-000017.safetensors", subdir=subdir)

  assert args.model is not None, "please provide --model option"

@@ -204,7 +204,7 @@ def get_mlperf_bert_config():
    "intermediate_size": 4096,
    "max_position_embeddings": 512,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "num_hidden_layers": getenv("BERT_LAYERS", 24),
    "type_vocab_size": 2,
    "vocab_size": 30522
  }

@@ -786,7 +786,8 @@ def train_rnnt():
  pass

@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
                    masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  optimizer.zero_grad()

  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -802,18 +803,15 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te

  optimizer.step()
  scheduler.step()
  return loss.realize()
  return loss.realize(), global_norm.realize()

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
                   masked_lm_weights:Tensor, next_sentence_labels:Tensor):
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return {
    "masked_lm_accuracy": masked_lm_accuracy.realize(),
    "next_sentence_accuracy": seq_relationship_accuracy.realize(),
    "masked_lm_loss": masked_lm_loss.realize(),
    "next_sentence_loss": next_sentence_loss.realize()
  }
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
    model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()

def train_bert():
  # NOTE: pip install tensorflow, wandb required
@@ -949,7 +947,7 @@ def train_bert():
  previous_step = None
  if ckpt:=getenv("RESUME", ""):
    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
    start_step = int(scheduler_wd.epoch_counter.numpy().item())
    start_step = int(scheduler_wd.epoch_counter.item())
    print(f"resuming from {ckpt} at step {start_step}")

  if RUNMLPERF:
@@ -975,7 +973,7 @@ def train_bert():
      BEAM.value = TRAIN_BEAM
      st = time.perf_counter()
      GlobalCounters.reset()
      loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
      loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
        train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
        train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])

@@ -992,7 +990,7 @@ def train_bert():
      dt = time.perf_counter()

      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
      loss = loss.numpy().item()
      loss = loss.item()

      cl = time.perf_counter()
      if BENCHMARK: step_times.append(cl - st)
@@ -1002,7 +1000,7 @@ def train_bert():
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
        wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
          "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
          "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})

@@ -1037,12 +1035,10 @@ def train_bert():
        GlobalCounters.reset()
        st = time.time()

        eval_result: dict[str, Tensor] = eval_step_bert(model,
        lm_acc, clsf_acc, lm_loss, clsf_loss = eval_step_bert(model,
          eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
          eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])

        lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
        lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()
        lm_acc, clsf_acc, lm_loss, clsf_loss = lm_acc.item(), clsf_acc.item(), lm_loss.item(), clsf_loss.item()

        eval_lm_losses.append(lm_loss)
        eval_clsf_losses.append(clsf_loss)
@@ -1059,7 +1055,7 @@ def train_bert():
          return

      if getenv("RESET_STEP", 1): eval_step_bert.reset()
      del eval_data, eval_result
      del eval_data
      avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
      avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
      avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

@@ -5,7 +5,7 @@ export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

@@ -5,7 +5,7 @@ export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

167 extra/amdpci/am_smi.py Normal file
@@ -0,0 +1,167 @@
import time, mmap, sys, shutil, os, glob
from tinygrad.helpers import to_mv, DEBUG, colored, ansilen
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.autogen.am import smu_v13_0_0
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA

AM_VERSION = 0xA0000002

def bold(s): return f"\033[1m{s}\033[0m"

def color_temp(temp):
  if temp >= 87: return colored(f"{temp:>4}", "red")
  elif temp >= 80: return colored(f"{temp:>4}", "yellow")
  return f"{temp:>4}"

def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan")

def draw_bar(percentage, width=40, fill='█', empty='░'):
  filled_width = int(width * percentage)
  bar = fill * filled_width + empty * (width - filled_width)
  return f'[{bar}] {percentage*100:5.1f}%'
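# e.g. (illustrative, not part of this diff): draw_bar(0.5, width=10) -> "[█████░░░░░]  50.0%"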

def same_line(strs:list[list[str]], split=8) -> list[str]:
  ret = []
  max_width_in_block = [max(ansilen(line) for line in block) for block in strs]
  max_height = max(len(block) for block in strs)
  for i in range(max_height):
    line = []
    for bid, block in enumerate(strs):
      if i < len(block): line.append(block[i] + ' ' * (split + max_width_in_block[bid] - ansilen(block[i])))
      else: line.append(' ' * (split + max_width_in_block[bid]))
    ret.append(' '.join(line))
  return ret

def get_bar0_size(pcibus):
  resource_file = f"/sys/bus/pci/devices/{pcibus}/resource"
  if not os.path.exists(resource_file): raise FileNotFoundError(f"Resource file not found: {resource_file}")

  with open(resource_file, "r") as f: lines = f.readlines()
  bar0_info = lines[0].split()
  if len(bar0_info) < 3: raise ValueError("Unexpected resource file format for BAR0.")

  start_hex, end_hex, _flags = bar0_info
  return int(end_hex, 16) - int(start_hex, 16) + 1

class AMSMI(AMDev):
  def __init__(self, pcibus, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
    self.pcibus = pcibus
    self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar

    self._run_discovery()
    self._build_regs()

    if self.reg("regSCRATCH_REG7").read() != AM_VERSION:
      raise Exception(f"Unsupported AM version: {self.reg('regSCRATCH_REG7').read():x}")

    self.is_booting, self.smi_dev = True, True
    self.partial_boot = True # do not init anything
    self.mm = AMMemoryManager(self, self.vram_size)

    # Initialize IP blocks
    self.soc21:AM_SOC21 = AM_SOC21(self)
    self.gmc:AM_GMC = AM_GMC(self)
    self.ih:AM_IH = AM_IH(self)
    self.psp:AM_PSP = AM_PSP(self)
    self.smu:AM_SMU = AM_SMU(self)

class SMICtx:
  def __init__(self):
    self.devs = []
    self.opened_pcidevs = []
    self.opened_pci_resources = {}
    self.prev_lines_cnt = 0

  def _open_am_device(self, pcibus):
    if pcibus not in self.opened_pci_resources:
      bar_fds = {bar: os.open(f"/sys/bus/pci/devices/{pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]}
      bar_size = {0: get_bar0_size(pcibus), 2: os.fstat(bar_fds[2]).st_size, 5: os.fstat(bar_fds[5]).st_size}

      def map_pci_range(bar):
        return to_mv(libc.mmap(0, bar_size[bar], mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, bar_fds[bar], 0), bar_size[bar])
      self.opened_pci_resources[pcibus] = (map_pci_range(0), None, map_pci_range(5).cast('I'))

    try:
      self.devs.append(AMSMI(pcibus, *self.opened_pci_resources[pcibus]))
    except Exception as e:
      if DEBUG >= 2: print(f"Failed to open AM device {pcibus}: {e}")
      return

    self.opened_pcidevs.append(pcibus)
    if DEBUG >= 2: print(f"Opened AM device {pcibus}")

  def rescan_devs(self):
    pattern = os.path.join('/tmp', 'am_*.lock')
    for d in [f[8:-5] for f in glob.glob(pattern)]:
      if d not in self.opened_pcidevs:
        self._open_am_device(d)

    for d in self.devs:
      if d.reg("regSCRATCH_REG7").read() != AM_VERSION:
        self.devs.remove(d)
        self.opened_pcidevs.remove(d.pcibus)
        os.system('clear')
        if DEBUG >= 2: print(f"Removed AM device {d.pcibus}")

  def collect(self): return {d: d.smu.read_metrics() for d in self.devs}

  def draw(self):
    terminal_width, _ = shutil.get_terminal_size()

    dev_metrics = self.collect()
    dev_content = []
    for dev, metrics in dev_metrics.items():
      device_line = [f"PCIe device: {bold(dev.pcibus)}"] + [""]
      activity_line = [f"GFX Activity {draw_bar(metrics.SmuMetrics.AverageGfxActivity / 100, 50)}"] \
                    + [f"UCLK Activity {draw_bar(metrics.SmuMetrics.AverageUclkActivity / 100, 50)}"] + [""]

      # draw_metrics_table(metrics, dev)
      temps_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_TEMP_e__enumvalues.items()
                    if k < smu_v13_0_0.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0]
      temps_table = ["=== Temps (C) ==="] + [f"{name:<15}: {color_temp(metrics.SmuMetrics.AvgTemperature[k])}" for k, name in temps_keys]

      voltage_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_SVI_PLANE_e__enumvalues.items() if k < smu_v13_0_0.SVI_PLANE_COUNT]
      power_table = ["=== Power ==="] \
                  + [f"Fan Speed: {metrics.SmuMetrics.AvgFanRpm} RPM"] \
                  + [f"Fan Power: {metrics.SmuMetrics.AvgFanPwm} %"] \
                  + [f"Power: {metrics.SmuMetrics.AverageSocketPower}W " +
                     draw_bar(metrics.SmuMetrics.AverageSocketPower / metrics.SmuMetrics.dGPU_W_MAX, 16)] \
                  + ["", "=== Voltages ==="] + [f"{name:<24}: {color_voltage(metrics.SmuMetrics.AvgVoltage[k])}" for k, name in voltage_keys]

      frequency_table = ["=== Frequencies ===",
                         f"GFXCLK Target : {metrics.SmuMetrics.AverageGfxclkFrequencyTarget} MHz",
                         f"GFXCLK PreDs  : {metrics.SmuMetrics.AverageGfxclkFrequencyPreDs} MHz",
                         f"GFXCLK PostDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPostDs} MHz",
                         f"FCLK PreDs    : {metrics.SmuMetrics.AverageFclkFrequencyPreDs} MHz",
                         f"FCLK PostDs   : {metrics.SmuMetrics.AverageFclkFrequencyPostDs} MHz",
                         f"MCLK PreDs    : {metrics.SmuMetrics.AverageMemclkFrequencyPreDs} MHz",
                         f"MCLK PostDs   : {metrics.SmuMetrics.AverageMemclkFrequencyPostDs} MHz",
                         f"VCLK0         : {metrics.SmuMetrics.AverageVclk0Frequency} MHz",
                         f"DCLK0         : {metrics.SmuMetrics.AverageDclk0Frequency} MHz",
                         f"VCLK1         : {metrics.SmuMetrics.AverageVclk1Frequency} MHz",
                         f"DCLK1         : {metrics.SmuMetrics.AverageDclk1Frequency} MHz"]

      dev_content.append(device_line + activity_line + same_line([temps_table, power_table, frequency_table]))

    raw_text = 'AM Monitor'.center(terminal_width) + "\n" + "=" * terminal_width + "\n\n"
    for i in range(0, len(dev_content), 2):
      if i + 1 < len(dev_content): raw_text += '\n'.join(same_line([dev_content[i], dev_content[i+1]]))
      else: raw_text += '\n'.join(dev_content[i])
      if i + 2 < len(dev_content): raw_text += "\n" + "=" * terminal_width + "\n\n"

    sys.stdout.write(f'\033[{self.prev_lines_cnt}A')
    sys.stdout.flush()
    print(raw_text)

    self.prev_lines_cnt = len(raw_text.splitlines()) + 2

if __name__ == "__main__":
  try:
    os.system('clear')
    smi_ctx = SMICtx()
    while True:
      smi_ctx.rescan_devs()
      smi_ctx.draw()
      time.sleep(1)
  except KeyboardInterrupt: print("Exiting...")
3 extra/amdpci/setup_python_cap.sh Executable file
@@ -0,0 +1,3 @@
#!/bin/bash
PYTHON_PATH=$(readlink -f $(which python3))
sudo setcap 'cap_dac_override,cap_sys_rawio,cap_sys_admin=ep' $PYTHON_PATH
2 extra/amdpci/setup_vfio.sh Executable file
@@ -0,0 +1,2 @@
#!/bin/bash
sudo modprobe vfio-pci disable_idle_d3=1
@@ -63,15 +63,15 @@ class BertForPretraining:

  def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    valid = masked_lm_ids != 0
    masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1)
    masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
    masked_lm_predictions = prediction_logits.argmax(-1)
    masked_lm_correct = (masked_lm_predictions == masked_lm_ids) * valid
    masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)

    seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1)
    seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
    seq_relationship_predictions = seq_relationship_logits.argmax(-1)
    seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)

    return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss
    return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss
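Dropping the log_softmax before argmax here is safe because log_softmax is monotonic per row, so the argmax index is unchanged. A quick standalone check (illustrative only, not part of this commit):

from tinygrad import Tensor, dtypes
import numpy as np
logits = Tensor([[2.0, -1.0, 0.5], [0.1, 3.0, -2.0]])
assert np.array_equal(logits.argmax(-1).numpy(), logits.log_softmax(dtype=dtypes.float).argmax(-1).numpy())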

  def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info

@@ -1,41 +1,15 @@
import functools, io, math
from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.dtype import ImageDType, dtypes, DType
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, _cached_to_python_const
import numpy as np

# **************** Free Ops ****************

# ***** Property/Graph Ops *****
def Identity(x:Tensor): return x
# TODO: fix buffer_parse
def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x:Tensor|int, other:Tensor): return x - other # some test has input as int
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

# **************** Simple Ops ****************

# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype)

def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None,
             value_floats:list[float]|None=None, value_int:int|None=None, value_ints:list[int]|None=None,
             value_string:str|None=None, value_strings:list[str]|None=None):
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
             value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
  if value is not None: return value
  if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
  if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
@@ -44,21 +18,79 @@ def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:
  if value_string is not None or value_strings is not None and sparse_value is not None:
    raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')

def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)

def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
  if pixel_format == "RGB": return Tensor(np.array(img))
  if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")

def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))

def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
  if value is None: value = Tensor(0, dtype=dtypes.float32)
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)

def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)

# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
  return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)

# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def FastGelu(x:Tensor, bias:Tensor|None=None):
  # this is tanh approximated
  return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
  slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
  return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002
  return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
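# note (illustrative, not part of this diff): for booleans, when x == y both operands agree, so either one is
# the answer; when they differ exactly one is True, so And falls back to False and Or falls back to True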
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x

# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)

# ***** Reduce Ops *****
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
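# illustrative behavior of _axes (not part of this diff):
#   _axes(None, noop_with_empty_axes=0) -> None  (reduce over all axes)
#   _axes(None, noop_with_empty_axes=1) -> []    (reduce over no axes, i.e. a no-op per the ONNX spec)
#   _axes([1], noop_with_empty_axes=0)  -> [1]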
@@ -80,27 +112,27 @@ def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_wit
  return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
  return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
  return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])

def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data:Tensor): return data.numel()
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
  return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Not(x:Tensor): return x.logical_not()
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)

def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
  axes = axes or list(range(data.ndim))
  steps = steps or [1]*data.ndim
@@ -113,83 +145,9 @@ def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0)
  if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
  return data.split(split, axis)

# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
  return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)

def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()

def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
  if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
  return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)

def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)

def ConstantOfShape(shape:list[int], value:Tensor|None=None):
  if value is None: value = Tensor(0, dtype=dtypes.float32)
  return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)

# **************** Complex Ops ****************

def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)

def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
  axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
  if reverse: X = X.flip(axis)
  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                     .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)

# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon).rsqrt()
  return X.batchnorm(scale, B, input_mean, invstd)

def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  assert stash_type == 1, "only float32 is supported"
  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axes, keepdim=True)
  return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()

def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])]))

AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
  if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
  return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]

def _onnx_pads_to_tiny_pads(pads):
  # (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
  return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
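# worked example for the reordering above (illustrative, not part of this diff):
# ONNX 2D pads are (top, left, bottom, right); _onnx_pads_to_tiny_pads([1, 2, 3, 4]) -> (2, 4, 1, 3),
# i.e. tinygrad's (left, right, top, bottom)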
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
        mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
  value = constant_value or value
@@ -198,6 +156,21 @@ def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[
  for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
  return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)

def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

# ***** Processing Ops *****
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
  # (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
  if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
  return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
  i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
  if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
@@ -206,25 +179,22 @@ def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):

def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
                dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
  pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
  return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
  return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
                      ceil_mode=ceil_mode, count_include_pad=count_include_pad)

def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
            storage_order:int=0, strides:list[int]|int=1):
  pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
  ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode)
  ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
  # tests expect indices with int64 dtype
  # TODO: if there are repeated values, this is wrong
  indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
  return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices

def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads))
         kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
  return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
                  padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))

# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
                  kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
                  strides:list[int]|int=1):
@@ -247,80 +217,28 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shap
  if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
  return ret.pad(_onnx_pads_to_tiny_pads(pads))

def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
  return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)

# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
  mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
  ret = alpha * (A.transpose(transA) @ B.transpose(transB))
  if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
  return ret

def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)

def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
  axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
  if reverse: X = X.flip(axis)
  if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
                     .shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
  return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
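# e.g. (illustrative, not part of this diff): CumSum(Tensor([1., 2., 3.]), axis=0, exclusive=1) -> [0., 1., 3.];
# the pad/shrink above shifts the input right by one along `axis`, so each output excludes its own element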
|
||||
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
|
||||
return x.nll_loss(target, weight, ignore_index, reduction)
|
||||
|
||||
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
|
||||
log_probs = scores.log_softmax(1)
|
||||
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
|
||||
|
||||
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
|
||||
|
||||
def Gather(x:Tensor, indices:Tensor, axis:int=0):
|
||||
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
|
||||
x_sh = list(x.shape)
|
||||
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
|
||||
if indices.ndim > 1: indices = indices.flatten()
|
||||
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
|
||||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
|
||||
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
|
||||
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
|
||||
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
|
||||
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
|
||||
|
||||
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
|
||||
if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
|
||||
x_shape, i_shape = x.shape, indices.shape
|
||||
b = math.prod(x.shape[dim] for dim in range(batch_dims))
|
||||
# NOTE: each batched dim of both input and indices are equal
|
||||
x = x.reshape(b, *x.shape[batch_dims:])
|
||||
indices = indices.reshape(b, *indices.shape[batch_dims:])
|
||||
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
|
||||
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
|
||||
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
|
||||
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
|
||||
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
|
||||
x = x.contiguous()
|
||||
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
|
||||
i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
|
||||
u = u.squeeze(0)
|
||||
if reduction == "none": x[i] = u
|
||||
elif reduction == "add": x[i] += u
|
||||
elif reduction == "mul": x[i] *= u
|
||||
else: raise NotImplementedError("reduction doesn't support max or min")
|
||||
return x
|
||||
|
||||
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
|
||||
indices = (indices < 0).where(x.shape[axis], 0) + indices
|
||||
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
|
||||
def GatherElements(x:Tensor, indices:Tensor, axis:int):
|
||||
indices = (indices < 0).where(x.shape[axis], 0) + indices
|
||||
return x.gather(axis, indices)
|
||||
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)

def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
           axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
           extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
  def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
    if mode == "round_prefer_floor": index = (index - 0.5).ceil()
    elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
@@ -377,116 +295,43 @@ def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, si
    X = X.gather(i, low).lerp(X.gather(i, high), perc)
  if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
  return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
  shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
  pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
  for s, x in zip(shape, axes or range(t.ndim)):
    tx = t.shape[x]
    if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
    elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
  return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
  # depth is a scalar or a rank-1 tensor containing exactly one element
  depth = int(depth[0] if isinstance(depth, list) else depth)
  indices = (indices < 0).where(indices+depth, indices)
  return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
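A quick way to see the negative-index wrap in OneHot is to drive it directly; a minimal sketch, assuming a working tinygrad install and this module importable (values packed as [off_value, on_value]):

# illustrative sketch, not part of the diff
from tinygrad import Tensor
out = OneHot(Tensor([0, 2, -1]), depth=3, values=Tensor([0.0, 1.0]))
print(out.numpy())  # index -1 wraps to depth-1 == 2 before the one-hot compare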
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
  ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
  return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))

def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)  # deprecated
# ***** Neural Network Ops *****
# TODO: try to factor out common implementations for these ops
# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
                       training_mode:int=0, spatial=1, is_test=0):
  if training_mode:
    x_detached = X.detach()
    current_mean = x_detached.mean(axis=(0,2,3))
    y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
    current_var = (y*y).mean(axis=(0,2,3))
    current_invstd = current_var.add(epsilon).rsqrt()

    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
  invstd = (input_var + epsilon).rsqrt()
  return X.batchnorm(scale, B, input_mean, invstd)
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
  return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous()

def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor):
  assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops"
  y = (y / y_scale + y_zero_point).round()
  return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype)

def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1,
                kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1):
  x = x.int() - x_zero_point
  w = w.int() - w_zero_point
  y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides)
  y_scale = y_scale / (x_scale * w_scale)
  return _quantize_linear(y, y_scale, y_zero_point)

def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                  y_zero_point:Tensor|int) -> Tensor:
  a = a.int() - a_zero_point
  b = b.int() - b_zero_point
  y = Tensor.matmul(a, b, acc_dtype=dtypes.int32)
  y_scale = y_scale / (a_scale * b_scale)
  return _quantize_linear(y, y_scale, y_zero_point)

def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None,
                auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None,
                pads: int | list[int] = 0, strides: int | list[int] = 1) -> Tensor:
  x_int = x.int() - x_zero_point
  w_int = w.int() - w_zero_point
  return Conv(x_int, w_int, B, auto_pad, dilations, group, kernel_shape, pads, strides)

def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
  A_int = A.int() - a_zero_point
  B_int = B.int() - b_zero_point
  return Tensor.matmul(A_int, B_int, acc_dtype=dtypes.int32)
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
  try: import PIL.Image
  except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
  img = PIL.Image.open(io.BytesIO(encoded_stream))
  if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
  if pixel_format == "RGB": return Tensor(np.array(img))
  if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1)  # (H, W) to (H, W, 1)
  raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
  N, _, *spatial_dims = size
  def generate_grid(steps):
    return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
  grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
  base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
  base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
  return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
  axis = tuple(range(2, x.ndim))
  mean = x.mean(axis=axis, keepdim=True)
  invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
  return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))

def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
  assert stash_type == 1, "only float32 is supported"
  axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
  mean = x.mean(axis=axes, keepdim=True)
  return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()

def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
  return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)

def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
  return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
# **************** com.microsoft Ops ****************
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
  x = x + skip + bias
  return x.layernorm(eps=epsilon) * gamma + beta, None, None, x

def FastGelu(x:Tensor, bias:Tensor|None=None):
  # this is the tanh approximation
  return (x + bias).gelu() if bias is not None else x.gelu()
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
                            segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
                            position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
@@ -513,6 +358,45 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embeddin
  out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
  return out, None, embedding_sum
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
  return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)

def SpaceToDepth(X:Tensor, blocksize:int):
  return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
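The two rearrange strings above encode ONNX's DCR vs CRD channel orderings, and SpaceToDepth is the exact inverse of the DCR path; a minimal round-trip sketch, assuming these functions and tinygrad are importable:

# illustrative sketch, not part of the diff
from tinygrad import Tensor
x = Tensor.arange(1*8*2*2).reshape(1, 8, 2, 2)
y = DepthToSpace(x, blocksize=2, mode="DCR")       # shape (1, 2, 4, 4)
assert SpaceToDepth(y, blocksize=2).shape == (1, 8, 2, 2)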
# Reimplemented here because you need legacy RNG for passing the ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
  if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool)  # if the mask is requested as an output it will contain all True's
  mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
  return data * mask * (1/(1.0 - ratio)), mask
# opset 6 with 'is_test' is needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
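Because the mask comes from numpy's legacy RandomState rather than tinygrad's RNG, a fixed seed yields a reproducible mask, which is what the ONNX tests rely on; a small sketch, assuming numpy and this module:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor
_, m1 = Dropout_7(Tensor.ones(4, 4), ratio=0.5, training_mode=True, seed=0)
_, m2 = Dropout_7(Tensor.ones(4, 4), ratio=0.5, training_mode=True, seed=0)
np.testing.assert_equal(m1.numpy(), m2.numpy())  # same seed, same mask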
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
  pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
  return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
  log_probs = scores.log_softmax(1)
  return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
              relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
              mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
@@ -552,37 +436,132 @@ def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:
  out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
  return out, present if past is not None else out
# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]

def Gather(x:Tensor, indices:Tensor, axis:int=0):
  if indices.numel() < 9:  # NOTE: fewer kernels for small indices, but the kernel count grows with the size of indices
    x_sh = list(x.shape)
    ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
    if indices.ndim > 1: indices = indices.flatten()
    indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
    args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]  # type: ignore
    return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
  # NOTE: the fancy-indexing gather below is faster and uses a fixed number of kernels, but exceeds the kernel limit for openpilot
  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs)  # deprecated
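The numel() < 9 split trades kernel count against kernel size: the shrink+cat path emits one kernel per index, while the fancy-indexing path is a fixed number of larger kernels. Both agree on results; a small check, assuming numpy and a tinygrad install:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor
x, idx = Tensor.arange(12).reshape(3, 4), Tensor([2, 0])
np.testing.assert_equal(Gather(x, idx, axis=0).numpy(), x.numpy()[[2, 0]])  # takes the shrink+cat path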
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
  if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  x_shape, i_shape = x.shape, indices.shape
  b = math.prod(x.shape[dim] for dim in range(batch_dims))
  # NOTE: each batch dim of input and indices must be equal
  x = x.reshape(b, *x.shape[batch_dims:])
  indices = indices.reshape(b, *indices.shape[batch_dims:])
  b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
  ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
  return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
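With batch_dims == 0 the last indices axis is unpacked into a multi-dimensional index into x; a minimal sketch, assuming numpy and a tinygrad install:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor
out = GatherND(Tensor.arange(4).reshape(2, 2), Tensor([[0, 0], [1, 1]]))
np.testing.assert_equal(out.numpy(), np.array([0, 3]))  # picks x[0,0] and x[1,1]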
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
  assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
  x = x.contiguous()
  for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
    i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
    u = u.squeeze(0)
    if reduction == "none": x[i] = u
    elif reduction == "add": x[i] += u
    elif reduction == "mul": x[i] *= u
    else: raise NotImplementedError("reduction doesn't support max or min")
  return x
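ScatterND applies the index rows sequentially, so reduction="add" accumulates duplicate indices instead of overwriting them; a small sketch, assuming numpy and a tinygrad install:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor
out = ScatterND(Tensor.zeros(4).contiguous(), Tensor([[1], [1]]), Tensor([1.0, 2.0]), reduction="add")
np.testing.assert_equal(out.numpy(), np.array([0., 3., 0., 0.]))  # both updates land on index 1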
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
  indices = (indices < 0).where(x.shape[axis], 0) + indices
  return x.gather(axis, indices)
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
  if axis is None:
    inp = inp.flatten()
    axis = 0
  if axis < 0: axis += inp.ndim
  con = Tensor(np.arange(len(condition))[condition])  # no boolean indexing in Tensor
  return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
# ***** Quantization Ops *****
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)

def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
  if axis < 0: axis += x.ndim
  if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
  if block_size == 0:
    shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
    return scale.reshape(shape), zero_point.reshape(shape)
  return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
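For per-axis quantization the reshape in _prepare_quantize aligns a 1-D scale with the requested axis so it broadcasts cleanly; a shape-only sketch, assuming a tinygrad install:

# illustrative sketch, not part of the diff
from tinygrad import Tensor
x, scale = Tensor.ones(2, 3, 4), Tensor([1.0, 2.0, 4.0])
s, zp = _prepare_quantize(x, scale, 0, axis=1)
assert s.shape == (1, 3, 1) and zp.shape == (1, 3, 1)  # broadcasts against (2, 3, 4)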
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
  out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
  y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
  return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()

def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
  x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
  return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
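Quantize-then-dequantize should reproduce values that already sit on the uint8 grid; a round-trip sketch, assuming numpy and a tinygrad install:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor
x, scale = Tensor([0.0, 0.5, 1.0]), Tensor(0.5)
q = QuantizeLinear(x, scale)  # default zero_point 0 -> uint8 [0, 1, 2]
np.testing.assert_allclose(DequantizeLinear(q, scale).numpy(), x.numpy())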
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
  adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
  return op(*adjusted_inputs, **opts)

def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in quantized int
  out = _op_integer(op, inputs, zero_points, **opts)
  assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
  out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
  return _clamp_cast(out_quantized, out_zero_point.dtype)

def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
  # op execution is done in float32
  dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
  out = op(*dequantized_inputs, **opts)
  assert dtypes.is_float(out.dtype), "op should've done math in float"
  out_quantized = (out / out_scale).round() + out_zero_point
  return _clamp_cast(out_quantized, out_zero_point.dtype)
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
                y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
  return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})

def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
                  y_zero_point:Tensor|int) -> Tensor:
  return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)
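Both wrappers defer to _qlinearop_quantized, so the requantization math can be checked end to end with identity scales and zero offsets; a sketch, assuming numpy and a tinygrad install:

# illustrative sketch, not part of the diff
import numpy as np
from tinygrad import Tensor, dtypes
a, b = Tensor([[1, 2]], dtype=dtypes.uint8), Tensor([[1], [1]], dtype=dtypes.uint8)
one, zero = Tensor(1.0), Tensor(0, dtype=dtypes.uint8)
y = QLinearMatMul(a, one, zero, b, one, zero, one, zero)
np.testing.assert_equal(y.numpy(), np.array([[3]], dtype=np.uint8))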
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
  return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
  assert channels_last == 0, "unsure what this does"
  return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
  return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})

def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
  return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])

# **************** ai.onnx.preview.training Ops ****************
# ***** Training Ops *****
# NOTE: onnx test coverage only covers `T==0` cases, so for all `T>0` this isn't tested
# NOTE: onnx training ops actually don't need the state for optim, all the ops work in a functional way, but we can still reuse optim.py code
from tinygrad.nn.optim import Adam as TinyAdam
from tinygrad.nn.optim import SGD

def onnx_training(input_group_size):
  def _decorator(func):
    def __wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
def _onnx_training(input_group_size):
  def __decorator(func):
    def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
      R = R.detach()
      groups = len(inputs) // input_group_size
      ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
      return tuple(flatten(zip(*ret)))
    return __wrapper
  return _decorator
    return ___wrapper
  return __decorator
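The decorator expects all tensor groups flattened into one varargs list and re-slices them so the wrapped op sees a single (X, G, ...) group per call, re-interleaving the outputs afterwards; the slicing itself can be sanity-checked in pure Python:

# illustrative sketch, not part of the diff
inputs = ("x1", "x2", "g1", "g2", "h1", "h2")  # two groups of (X, G, H)
groups = len(inputs) // 3
assert [inputs[i::groups] for i in range(groups)] == [("x1", "g1", "h1"), ("x2", "g2", "h2")]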
@onnx_training(3)
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
  X, G, H = (i.detach() for i in inputs)
  grad = norm_coefficient * X + G
@@ -592,9 +571,10 @@ def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:flo
  X.assign(X.detach() - r * up)
  return [X, H]
@onnx_training(4)
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
         norm_coefficient_post:float=0.0):
  from tinygrad.nn.optim import Adam as TinyAdam
  X, G, V, H = inputs
  G, V, H = G.detach(), V.detach(), H.detach()  # TODO: we shouldn't need these detaches
  X.grad = norm_coefficient * X.detach() + G
@@ -610,8 +590,9 @@ def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, eps
  X = (1 - norm_coefficient_post) * X
  return [X, V, H]
@onnx_training(3)
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
  from tinygrad.nn.optim import SGD
  X, G, V = inputs
  G, V = G.detach(), V.detach()
  X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
@@ -620,6 +601,6 @@ def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str,
  opt.step()
  return [X, V]
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **__):
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
  intermediate_tensors[y].backward()
  return tuple([t.grad for t in inputs])
@@ -22,7 +22,6 @@ nav:
  - Runtime: runtime.md
  - Developer:
    - Intro: developer/developer.md
    - Function (autodiff): developer/function.md
    - UOp: developer/uop.md
    - Runtime:
      - developer/runtime.md
@@ -44,8 +44,8 @@ def main():
  else:
    sz = getenv("SZ", 1000) * 10**6  # size of data on each gpu
  print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
  (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True)
  (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False)
  (ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus)
  (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus)
  print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
  print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
6
test/external/external_fuzz_ampt.py
vendored
@@ -25,7 +25,7 @@ class AMPTFuzzer:
    _vaddr = va.va_addr + _offset

    for i in range(_n_ptes):
      pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i))
      pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
      self.d.vram[pte['paddr']] = pattern  # Mark this page
      assert pte['valid'] == 1

@@ -41,7 +41,7 @@ class AMPTFuzzer:
    frags_l = list(ctx.next(contig_range))
    for f_offset, f_pt, f_pte_idx, f_n_ptes, f_pte_covers in frags_l:
      for j in range(f_n_ptes):
        f_pte = helper_read_entry_components(f_pt.get_entry(f_pte_idx + j))
        f_pte = helper_read_entry_components(f_pt.entries[f_pte_idx + j])
        assert f_pte['valid'] == 1
        assert f_pte['paddr'] == start_paddr+f_offset+j*f_pte_covers, f"paddr {f_pte['paddr']:#x} not {start_paddr+f_offset+j*f_pte_covers:#x}"

@@ -53,7 +53,7 @@ class AMPTFuzzer:
  def verify_memory(self, pages, pattern: int) -> bool:
    for _offset, _pt, _pte_idx, _n_ptes, _pte_covers in pages:
      for i in range(_n_ptes):
        pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i))
        pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
        if self.d.vram[pte['paddr']] != pattern: return False
        if pte['valid'] == 0: return False
12
test/external/external_test_am.py
vendored
@@ -3,7 +3,9 @@ from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraver
from tinygrad.helpers import mv_address

class FakeGMC:
  def __init__(self): self.vm_base = 0x0
  def __init__(self):
    self.vm_base = 0x0
    self.address_space_mask = (1 << 44) - 1
  def flush_tlb(self, *args, **kwargs): pass

class FakePCIRegion:
@@ -14,7 +16,7 @@ class FakePCIDev:

class FakeAM:
  def __init__(self):
    self.is_booting = True
    self.is_booting, self.smi_dev = True, False
    self.pcidev = FakePCIDev()
    self.vram = memoryview(bytearray(4 << 30))
    self.gmc = FakeGMC()
@@ -72,7 +74,7 @@ class TestAMPageTable(unittest.TestCase):
    for tup in results:
      _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
      for i in range(_n_ptes):
        pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i))
        pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
        assert pte['paddr'] == va + _offset + i * _pte_covers, f"Expected paddr {pte['paddr']:#x} to be {va + _offset + i * _pte_covers:#x}"
        assert pte['valid'] == 1

@@ -81,7 +83,7 @@ class TestAMPageTable(unittest.TestCase):
    for tup in results:
      _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
      for i in range(_n_ptes):
        pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i))
        pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
        assert pte['paddr'] == 0
        assert pte['valid'] == 0

@@ -113,7 +115,7 @@ class TestAMPageTable(unittest.TestCase):
    for tup in ctx.next(0x100000):
      _offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
      for i in range(_n_ptes):
        pte = helper_read_entry_components(_pt.get_entry(_pte_idx + i))
        pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
        assert pte['paddr'] == 0xdead0000 + _offset + i * _pte_covers, f"paddr {pte['paddr']:#x} not {0xdead0000 + _offset + i * _pte_covers:#x}"
        assert pte['valid'] == 1
10
test/external/process_replay/process_replay.py
vendored
@@ -2,7 +2,7 @@
# compare kernels created by HEAD against master
from collections import defaultdict
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, List, Tuple, Union, cast
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, schedule_uop
from tinygrad.codegen.kernel import Kernel, Opt
@@ -33,15 +33,15 @@ class ProcessReplayWarning(Warning): pass
def recreate_sched(ast:UOp) -> UOp:
  # NOTE: process replay isn't meant to actually schedule anything
  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
  k = Kernel(ast, opts=opts)
  for opt in applied_opts: k.apply_opt(opt)
  # NOTE: replay with the captured renderer, not the one in master
  return k.opts.render(name, cast(List,k.to_program().uops))
  return k.opts.render(name, cast(list,k.to_program().uops))

# *** diff a "good" recreation against the generated version

def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
def diff(offset:int, name:str, fxn:Callable) -> tuple[int, int]|bool:
  if early_stop.is_set(): return True
  conn = db_connection()
  cur = conn.cursor()
@@ -95,7 +95,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
  cur.close()
  with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
    inputs = list(range(0, row_count, PAGE_SIZE))
    ret: List[Union[bool, Tuple[int, int]]] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
    ret: list[tuple[int, int]|bool] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
    pool.close()
    pool.join()
    pool.terminate()
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
      loss.backward()
      optimizer.step()

    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)

  @unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
  def test_train_cifar(self):

@@ -40,6 +40,7 @@ class TestTrain(unittest.TestCase):
    check_gc()

  @unittest.skipIf(CI, "slow")
  @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
  def test_efficientnet(self):
    model = EfficientNet(0)
    X = np.zeros((BS,3,224,224), dtype=np.float32)
@@ -56,6 +57,7 @@ class TestTrain(unittest.TestCase):
    train_one_step(model,X,Y)
    check_gc()

  @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
  def test_transformer(self):
    # this should be small GPT-2, but the param count is wrong
    # (real ff_dim is 768*4)
@@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase):
    # llama3 is 128256
    vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
    emb = nn.Embedding(vocab_size, embed_size)
    emb_w = emb.weight.numpy()
    # TODO: why is a new realize needed here
    emb_w = emb.weight.realize().numpy()
    x = Tensor([1,2,3,4])
    with Context(NOOPT=noopt, FUSE_ARANGE=1):
      GlobalCounters.reset()
      z = emb(x).realize()
      self.assertLessEqual(GlobalCounters.global_ops, op_limit)
      self.assertEqual(GlobalCounters.kernel_count, 3)
      self.assertEqual(GlobalCounters.kernel_count, 2)
    if getenv("CHECK", 1):
      import torch
      with torch.no_grad():
@@ -8,9 +8,11 @@ from tinygrad.ops import UOp
from tinygrad.tensor import Tensor

def tensors_allocated():
  gc.collect()
  return sum([isinstance(x, Tensor) for x in gc.get_objects()])

def bufs_allocated():
  gc.collect()
  return sum([isinstance(x, Buffer) for x in gc.get_objects()])

class TestGC(unittest.TestCase):
@@ -31,7 +33,7 @@ class TestGC(unittest.TestCase):
    base = tensors_allocated()
    a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
    b = Tensor.rand(4, 4, requires_grad=True)
    assert (tensors_allocated()-base == 5)
    assert (tensors_allocated()-base == 4)
    (a*b).mean().backward()
    assert (tensors_allocated()-base == 6)
    del b
@@ -134,5 +134,39 @@ class TestImageDType(unittest.TestCase):
    print(lst)
    assert not np.any(np.isnan(lst))

@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
class TestImageRealization(unittest.TestCase):
  def test_image_dtype_expand(self):
    data = Tensor.randn(9*27*4).realize()
    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
    self.assertEqual(it_expanded.dtype, dtypes.float32)

  def test_image_dtype_expand_and_back(self):
    data = Tensor.randn(9*27*4).realize()
    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
    it2 = it_expanded.sum(3).realize()
    self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))

  def test_image_alu_children(self):
    data = Tensor.randn(9*27*4).realize()
    it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
    self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
    it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
    alu1 = it_expanded+1
    alu2 = it_expanded.sum(3)
    it_expanded.realize()
    # NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
    self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
    alu1.realize()
    self.assertEqual(alu1.dtype, dtypes.float32)
    # alu2 is back in image because it fits the dtype again
    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
    alu2.realize()
    self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))

if __name__ == '__main__':
  unittest.main()
@@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase):
    # should upcast the two Tensor.stacks
    assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2

  @unittest.expectedFailure  # requires contiguous folding
  def test_masked_upcast_wino_full(self):
    with Context(WINO=1):
      x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
@@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase):
          UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
          UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
    opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"])
    helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)

  # llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
  @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")
@@ -1,11 +1,11 @@
import unittest, functools, random
from typing import List
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.ops import Ops
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
from tinygrad.multi import all_reduce, MultiLazyBuffer
from tinygrad.multi import all_reduce
import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.device import is_dtype_supported
@@ -30,7 +30,7 @@ N = 128
def _test_allreduce(t:Tensor):
  aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
  ts = t.shard(devices_4, 0).realize()
  b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, ts.lazydata.lbs), 0))
  b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0))
  b.realize()
  return aa, b

@@ -39,7 +39,7 @@ class TestMultiTensor(unittest.TestCase):
  def test_to(self):
    X = Tensor.ones(256).contiguous().realize()
    X.to_(devices_2)
    for lb in X.lazydata.lbs:
    for lb in X.lazydata.src:
      assert lb.shape == (256,)
    (X + X).realize()

@@ -52,7 +52,7 @@ class TestMultiTensor(unittest.TestCase):
  def test_shard(self):
    X = Tensor.ones(256).contiguous().realize()
    X.shard_(devices_2, 0)
    for lb in X.lazydata.lbs:
    for lb in X.lazydata.src:
      assert lb.shape == (128,)
    (X + X).realize()

@@ -218,9 +218,9 @@ class TestMultiTensor(unittest.TestCase):
      shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
      t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
      with Context(RING=0):
        a = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
        a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
      with Context(RING=2):
        b = Tensor(MultiLazyBuffer(all_reduce(Ops.ADD, t.lazydata.lbs), 0))
        b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
      diff = a - b
      mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
      max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
@@ -344,9 +344,7 @@ class TestMultiTensor(unittest.TestCase):
  # NOTE: this is failing on LLVM CI, no idea why. Works locally.
  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow")
  def test_data_parallel_resnet(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18
    from extra.models.resnet import ResNet18

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
    fake_image_sharded = fake_image.shard(devices_2, axis=0)
@@ -356,16 +354,14 @@ class TestMultiTensor(unittest.TestCase):
    for p in get_parameters(m): p.shard_(devices_2).realize()
    GlobalCounters.reset()
    shard_output = m(fake_image_sharded).log_softmax().realize()
    assert shard_output.lazydata.lbs[0].shape == (1, 1000)
    assert shard_output.lazydata.lbs[1].shape == (1, 1000)
    assert shard_output.lazydata.src[0].shape == (1, 1000)
    assert shard_output.lazydata.src[1].shape == (1, 1000)
    shard_output_np = shard_output.numpy()
    np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)

  @unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
  def test_data_parallel_resnet_train_step(self):
    import sys, pathlib
    sys.path.append((pathlib.Path(__file__).parent.parent / "extra" / "models").as_posix())
    from resnet import ResNet18
    from extra.models.resnet import ResNet18
    from tinygrad.nn.optim import LARS

    fake_image = Tensor.rand((2, 3, 224//8, 224//8))
@@ -386,12 +382,35 @@ class TestMultiTensor(unittest.TestCase):
    GlobalCounters.reset()
    optimizer.zero_grad()
    shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
    assert shard_output.lazydata.axis is None
    shard_output.backward()
    shard_grad = m.conv1.weight.grad.numpy()
    # sometimes there is zeros in these grads... why?
    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

  def test_assign_kv_cache_multi(self):
    bsz, max_context = 2, 8

    class Attn:
      @TinyJit
      def __call__(self, xk:Tensor, start_pos:UOp):
        seqlen = xk.shape[1]
        if not hasattr(self, "cache_k"):
          self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
        keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
        self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()

    attn = Attn()
    xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
    attn(xk, 0)
    for i in range(3,6):
      # copied from LLaMA
      start_pos = Variable("start_pos", 1, max_context).bind(i)
      xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
      attn(xk, start_pos)

    out = attn.cache_k.flatten().numpy()
    np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
  def test_multi_tensor_jit_param(self):
    @TinyJit
    def jf(a, b) -> Tensor:
@@ -532,13 +551,13 @@ class TestMultiTensor(unittest.TestCase):
    t4 = t2.reshape((26, 105,))

    for t in [t0, t1, t2, t3, t4]:
      assert t.lazydata.axis == 1
      np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
      assert t.lazydata.axis == 1

    # test shape-one axis
    t5 = t4.reshape((26, 1, 105))
    assert t5.lazydata.axis == 2
    np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
    assert t5.lazydata.axis == 2

    # test split and rejoin to the right and reshape to the left
    t5 = t0.reshape((2, 13, 3, 5, 7))
@@ -553,7 +572,7 @@ class TestMultiTensor(unittest.TestCase):

    # test no left join
    with self.assertRaises((AssertionError, ValueError)):
      t0.reshape((26*15,7))
      t0.reshape((26*15,7)).schedule()

  @unittest.skip("no longer supports uneven shard")
  def test_reshape_on_axis_uneven(self):
@@ -588,6 +607,7 @@ class TestMultiTensor(unittest.TestCase):
    with self.assertRaises(AssertionError):
      # don't allow assigns that change axes
      t_none.assign(t_zero)
      t_none.schedule()

  def test_init_rand_with_multiple_devices_fail(self):
    # init rand with multi device is not allowed
@@ -627,6 +647,16 @@ class TestMultiTensor(unittest.TestCase):
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
  def test_rand_like_from_alu(self):
    # TODO: fix this, which will also fix multi device dropout
    a = Tensor.ones(4, 4).shard(devices_2, axis=0)
    with self.assertRaises(AssertionError):
      (a + a).rand_like()

    b = Tensor.empty(4, 4).shard(devices_2, axis=None)
    with self.assertRaises(AssertionError):
      (a + b).rand_like()

  @unittest.skip("no longer supports uneven shard")
  def test_rand_like_uneven_shard(self):
    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
@@ -635,7 +665,7 @@ class TestMultiTensor(unittest.TestCase):
    self.assertEqual(t.device, t2.device)
    self.assertEqual(t.dtype, t2.dtype)
    self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
    assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.lbs, t2.lazydata.lbs))
    assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))

  def test_rand_like_none_shard(self):
    t = Tensor.empty((16, 16)).shard(devices_2)
@@ -718,7 +748,7 @@ class TestMultiTensor(unittest.TestCase):
    devices = (d0, d1, d2, d3)
    t = Tensor.zeros(16, 16).contiguous()
    t.shard_(devices, axis=0).realize()
    assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs])
    assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])

  @unittest.skip("this is unreliable on OSX")
  def test_clone(self):
@@ -774,25 +804,31 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
    with self.assertRaises(AssertionError):
      # sharded axis shrink on a non-device boundary is not allowed
      a = t.shrink(((0, 3), (0, 8)))
      a.schedule()
    with self.assertRaises(AssertionError):
      # cannot shrink sharded and non-sharded axis at the same time
      a = t.shrink(((0, 2), (2, 4)))
      a.schedule()

    a = t.shrink(((0, 2), (0, 8)))
    a.schedule()
    assert a.shape == (2, 8)
    assert a.lazydata.real == [True, False, False, False]
    assert a.lazydata.real == (True, False, False, False)

    with self.assertRaises(AssertionError):
      # cannot pad sharded and non-sharded axis at the same time
      p = a.pad(((0, 6), (0, 1)))
      p.schedule()

    with self.assertRaises(AssertionError):
      # can only pad to whole axis
      p = a.pad(((1, 5), (0, 0)))
      p.schedule()

    p = a.pad(((0, 6), (0, 0)))
    p.schedule()
    assert p.shape == (8, 8)
    assert p.lazydata.real == [True, True, True, True]
    assert p.lazydata.real == (True, True, True, True)

  @given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
  def test_ops(self, dtype):
@@ -804,8 +840,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
      a = t.shrink(((0+2*i,2+2*i),None))
      b = Tensor(t.numpy()[0+2*i:2+2*i])
      assert a.shape == b.shape == (2, 8)
      assert a.lazydata.real == [i==j for j in range(4)]
      np.testing.assert_allclose(a.numpy(), b.numpy())
      assert a.lazydata.real == tuple(i==j for j in range(4))
      # cast
      np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
@@ -859,24 +895,27 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
    np.testing.assert_equal((a+a).numpy(), na+na)
    np.testing.assert_equal((b+b).numpy(), nb+nb)

  @unittest.skip("why didn't this work?")
  def test_add_two_partitions(self):
    t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
    t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)

    a = t.shrink(((2, 4), None))
    b = t.shrink(((6, 8), None))
    self.assertEqual(a.lazydata.real, [False, True, False, False])
    self.assertEqual(b.lazydata.real, [False, False, False, True])
    na = t.numpy()[2:4]
    nb = t.numpy()[6:8]
    np.testing.assert_equal(a.numpy(), na)
    np.testing.assert_equal(b.numpy(), nb)
    self.assertEqual(a.lazydata.real, (False, True, False, False))
    self.assertEqual(b.lazydata.real, (False, False, False, True))
    with self.assertRaises(AssertionError):
      # cannot add directly
      c = a + b
      c.schedule()

    c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
    self.assertEqual(c.lazydata.real, [True, True, True, True])
    c.realize()
    self.assertEqual(c.lazydata.real, (True, True, True, True))
    expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
    np.testing.assert_equal(c.numpy(), expected)
@@ -937,8 +976,9 @@ class TestBatchNorm(unittest.TestCase):

      def __call__(self, x:Tensor):
        bn_ts = []
        for bound, bn in zip(x.lazydata.bounds, self.bns):
          xi = x.shrink((bound, None, None, None))
        each = x.shape[0]//len(self.bns)
        for i, bn in enumerate(self.bns):
          xi = x.shrink(((each*(i), each*(i+1)), None, None, None))
          bni = bn(xi)
          bn_ts.append(bni)
        return bn_ts[0].cat(*bn_ts[1:])
@@ -352,12 +352,15 @@ class TestOps(unittest.TestCase):
  def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y)

  def test_cmp_ne_backwards(self):
    # new grad zeroes these out
    """
    t1 = torch.ones(4, requires_grad=True)
    t2 = torch.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
    tt1 = Tensor.ones(4, requires_grad=True)
    tt2 = Tensor.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
    """
    tt = Tensor.randn(4, requires_grad=True)
    (tt*(tt != 0)).sum().backward()
    t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -365,12 +368,15 @@ class TestOps(unittest.TestCase):
    np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5)

  def test_cmp_lt_backwards(self):
    # new grad zeroes these out
    """
    t1 = torch.ones(4, requires_grad=True)
    t2 = torch.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (t1 < t2).sum().backward)
    tt1 = Tensor.ones(4, requires_grad=True)
    tt2 = Tensor.ones(4, requires_grad=True)
    self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
    """
    tt = Tensor.randn(4, requires_grad=True)
    (tt*(tt < 0)).sum().backward()
    t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, ScheduleContext, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis

@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
  np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, ScheduleContext())
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
class TestSchedule(unittest.TestCase):
  def test_basic_binop_fusion(self):
@@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase):

  def test_fold_conv_batchnorm_optim(self):
    # this is too high
    for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 15)]:
    for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
      with self.subTest(optim=optim.__name__):
        with Tensor.train():
          img = Tensor.ones(1,3,4,4)
@@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase):
    check_schedule(out, 2)

  # multireduce spec
  @unittest.skip("these two Tensors are the same")
  def test_example_matmul(self):
    x = Tensor.eye(64, requires_grad=True)
    y = Tensor.eye(64, requires_grad=True)
@@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase):
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.ones((64,64)))

  def test_example_matmul_contig(self):
    x = Tensor.eye(64, requires_grad=True).contiguous().realize()
    y = Tensor.eye(64, requires_grad=True).contiguous().realize()
    z = y.matmul(x).sum()
    z.backward()
    out = x.grad.contiguous()
    run_schedule(check_schedule(out, 2))
    np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
  def test_example_matmul_same(self):
    x = Tensor.eye(64, requires_grad=True)
    z = x.matmul(x).sum()
@@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 13)
      check_schedule(opt.schedule_step(), 14)

  def test_sgd_conv_fuse(self):
    with Tensor.train():
@@ -1060,7 +1070,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters(c1))
      opt.zero_grad()
      c1(img).relu().sum().backward()
      check_schedule(opt.schedule_step(), 5)
      check_schedule(opt.schedule_step(), 3)

  def test_sgd_2convs_fuse(self):
    with Tensor.train():
@@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 8)
      check_schedule(opt.schedule_step(), 7)

  def test_fold_2convs_sgd_nesterov_momentum_wd(self):
    with Tensor.train():
@@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
      opt.zero_grad()
      c2(c1(img).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 10)
      check_schedule(opt.schedule_step(), 9)

  def test_sgd_4convs_fuse(self):
    with Tensor.train():
@@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase):
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
      check_schedule(opt.schedule_step(), 18)
      check_schedule(opt.schedule_step(), 17)
def test_sgd_4convs_fuse_conv_bw(self):
|
||||
with Tensor.train():
|
||||
@@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
|
||||
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
|
||||
opt.zero_grad()
|
||||
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
|
||||
with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)
|
||||
with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
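
(How the FUSE_CONV_BW toggle above works, as a minimal sketch: FUSE_CONV_BW is a ContextVar from tinygrad.helpers, and Context temporarily overrides its value for the enclosed scheduling calls, restoring it on exit.)

from tinygrad.helpers import Context, FUSE_CONV_BW
assert FUSE_CONV_BW.value == 0    # default is off, per the ContextVar("FUSE_CONV_BW", 0) definition later in this diff
with Context(FUSE_CONV_BW=1):
  assert FUSE_CONV_BW.value == 1  # the scheduler takes the fused conv-backward path in here
assert FUSE_CONV_BW.value == 0    # restored afterwards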

  @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
  def test_prefer_half_buffer(self):
@@ -1844,7 +1854,7 @@ def run_tensor_ast(r:Tensor):
  sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
  sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
  si = ScheduleItem(sink, tuple(x.buffer for x in bufs), (), ())
  si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ())
  run_schedule([si])
  return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()

@@ -1940,19 +1950,19 @@ class TestSwizzle(unittest.TestCase):
    ret = swizzle_rewrite(sink)
    self.assertEqual(swizzle_cnt(ret), 0)

  @unittest.expectedFailure
  @unittest.skip("this swizzle can't be decided after the ADD")
  def test_swizzle_failure_permute(self):
    sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
      UOp(Ops.STORE, dtypes.void, arg=None, src=(
        UOp(Ops.BUFFER, dtypes.float, arg=(20, ('METAL', 65, dtypes.float)), src=()),
        UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)),
        UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
        UOp(Ops.ADD, dtypes.float, arg=None, src=(
          UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
            UOp(Ops.ADD, dtypes.float, arg=None, src=(
              x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
                UOp(Ops.ADD, dtypes.float, arg=None, src=(
                  UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
                    UOp(Ops.BUFFER, dtypes.float, arg=(8, ('METAL', 2925, dtypes.float)), src=()),
                  UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                    UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
                    x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
                  UOp(Ops.WHERE, dtypes.float, arg=None, src=(
                    x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
@@ -1971,13 +1981,13 @@ class TestSwizzle(unittest.TestCase):
                UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
                x15,)),
              UOp(Ops.MUL, dtypes.float, arg=None, src=(
                UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
                  UOp(Ops.BUFFER, dtypes.float, arg=(2, ('METAL', 2925, dtypes.float)), src=()),
                UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                  UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
                  x10,)),
                UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
                  UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
                    UOp(Ops.LOAD, dtypes.float, arg=None, src=(
                      UOp(Ops.BUFFER, dtypes.float, arg=(4, ('METAL', 2925, dtypes.float)), src=()),
                      UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
                      UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
    ret = swizzle_rewrite(sink)
    self.assertEqual(swizzle_cnt(ret), 0)
@@ -2269,6 +2279,54 @@ class TestCopyFolding(unittest.TestCase):
    a = Tensor.empty(4).lazydata
    check_schedule(a.clone(), 1, filter_sink=False)

  # NOTE: moving copy before view might change this
  def test_shrink_copy(self):
    a = Tensor.arange(4)
    view = a.shrink(((0, 2),))
    b = view.clone()
    run_schedule(check_schedule(b, 2, filter_sink=False))
    self.assertEqual(b.lazydata.base.buffer.size, 2)
    self.assertEqual(b.lazydata.size, 2)
    self.assertListEqual(b.tolist(), [0, 1])

  def test_expanded_copy(self):
    a = Tensor.arange(2)
    view = a.reshape(2, 1).expand(2, 2)
    b = view.clone()
    run_schedule(check_schedule(b, 2, filter_sink=False))
    self.assertEqual(b.lazydata.base.buffer.size, 2)
    self.assertEqual(b.lazydata.size, 4)
    self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])

  def test_permuted_copy(self):
    a = Tensor.arange(4)
    b = a.reshape(2, 2).permute(1, 0)
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  def test_permute_on_disk(self):
    with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
    a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
    b = a.reshape(2, 2).permute(1, 0).to("CLANG")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  def test_permute_after_shrink(self):
    a = Tensor.arange(5)
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

  # NOTE: disk permute must come after COPY
  # TODO: this is wrong because of the permute
  @unittest.expectedFailure
  def test_permute_after_shrink_on_disk(self):
    with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
    a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
    b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
    b.realize()
    self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])

class TestTensorUOpSpec(unittest.TestCase):
  def test_const_must_be_unmasked(self):
    a = Tensor.ones((4, 4)).pad((2, 2))
@@ -2377,6 +2435,17 @@ class TestContiguous(unittest.TestCase):
    b = a.expand((4, 4)).contiguous().contiguous()
    check_schedule(b, 1)

  def test_view_does_not_realize(self):
    a = Tensor.empty(4)
    b = a.expand((4, 4))
    check_schedule(b, 0)
    self.assertEqual(b.lazydata.base.buffer.size, 4)

  def test_contiguous_view_realizes(self):
    a = Tensor.empty(4)
    b = a.expand((4, 4)).contiguous()
    check_schedule(b, 1)
    self.assertEqual(b.lazydata.base.buffer.size, 16)

class TestUOpBecome(unittest.TestCase):
  # the simplest case, if we create a new BUFFER for this UOp

@@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase):

  def test_clone(self):
    a = Tensor.rand(16, 16).realize()
    self.assertIsNot(a.lazydata, a.clone().lazydata)
    np.testing.assert_allclose(a.numpy(), a.clone().numpy())
    b = a.clone()
    np.testing.assert_allclose(a.numpy(), b.numpy())
    self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)

    a = Tensor.rand(16, 16).mul(5.0).add(5.0)
    self.assertIsNot(a.lazydata, a.clone().lazydata)
    np.testing.assert_allclose(a.numpy(), a.clone().numpy())
    b = a.clone()
    np.testing.assert_allclose(a.numpy(), b.numpy())
    self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)

  def test_clone_with_shrink(self):
    a = Tensor.empty(16, 16)
    self.assertIsNot(a.lazydata, a.clone().lazydata)
    a = Tensor.rand(16, 16)
    b = a.shrink(((2, 10), None)).clone()
    b.realize()
    self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)

    b = a.shrink(((2, 10), None))
    self.assertIsNot(b.lazydata, b.clone().lazydata)
  def test_clone_with_shrink_realized(self):
    a = Tensor.rand(16, 16).realize()
    b = a.shrink(((2, 10), None)).clone()
    b.realize()
    self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)

  def test_clone_with_grad(self):
    a = Tensor.rand(16, 16, requires_grad=True)
@@ -763,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase):
    self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})

  def test_complex_backward(self):
    x = Tensor.rand(3, requires_grad=True)
    y = Tensor.rand(3, requires_grad=True)
    x = Tensor.rand(3, requires_grad=True).realize()
    y = Tensor.rand(3, requires_grad=True).realize()
    out = (x.relu() * y.sigmoid()).sum()
    self.assertEqual(out.lazydata.metadata.name, "sum")
    out.backward()

@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):

  def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
    x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), [x])[x]
    gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
    gf = jax.grad(f if jf is None else jf)

    for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase):
  def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
    x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
    y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), [x, y])
    grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
    gx, gy = grads[x], grads[y]
    gf = jax.grad(f if jf is None else jf, argnums=(0, 1))


@@ -1,118 +1,104 @@
from typing import Dict, List, Optional
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto

@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)

def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
  rewrite(sink, pm, **kwargs)
  assert len(contexts) == 1
  assert len(contexts[0]) == 1
  k = get_metadata(keys, contexts)[0][0]
  g = get_details(*k)
  return g.uops[1:]
# NOTE: VIZ tests always use the tracked PatternMatcher instance
symbolic = TrackedPatternMatcher(symbolic.patterns)

class TestViz(unittest.TestCase):
  def setUp(self):
    # clear the global context
    contexts.clear()
    keys.clear()
    _name_cnt.clear()
    self.tms = TRACK_MATCH_STATS.value
    TRACK_MATCH_STATS.value = 2
  def tearDown(self): TRACK_MATCH_STATS.value = self.tms

  def test_viz_simple(self):
    pm = PatternMatcher([
      (UPat.var("x")*1, lambda x:x),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a*1, pm)
    self.assertEqual(len(uops), 1)
    self.assertEqual(uops[0], a)
    a = UOp.variable("a", 0, 10)
    @track_rewrites(named=True)
    def test(sink): return graph_rewrite(sink, symbolic)
    test(a*1)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)
    key, val = ret[0]
    self.assertEqual(key, "test_1")
    self.assertEqual(val[0]["match_count"], 1)

  def test_rewrite_twice(self):
    pm = PatternMatcher([
      (UPat.var("x")+UPat.var("x"), lambda x:x*2),
      (UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
    ])
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    uops = helper_test_viz(a+a, pm)
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[0], a*2)
    self.assertEqual(uops[1], graph_rewrite(a+a, pm))
  def test_track_two_rewrites(self):
    a = UOp.variable("a", 0, 10)
    @track_rewrites(named=True)
    def test(sink): return graph_rewrite(sink, symbolic)
    test((a+a)*1)
    ret = get_metadata(keys, contexts)
    key, val = ret[0]
    self.assertEqual(len(ret), 1) # one context
    self.assertEqual(len(val), 1) # one graph_rewrite call in context
    self.assertEqual(key, "test_1")
    self.assertEqual(val[0]["match_count"], 2) # two upats applied

  def test_rewrite_with_ctx(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), ShapeTracker.from_shape((1, 1)).to_uop()))
    b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), ShapeTracker.from_shape((1, 1)).to_uop()))
    def store_load(ctx:Dict[UOp, None], glbl, st) -> Optional[UOp]:
      if glbl in ctx: return None
      ctx[glbl] = None
      return UOp.store(glbl, ShapeTracker.from_shape(st.shape).to_uop())
    pm = PatternMatcher([
      (UPat.load(UPat(Ops.DEFINE_GLOBAL, name="glbl"), UPat.var("st")), store_load),
    ])
    uops = helper_test_viz(a+b, pm, ctx={})
    self.assertEqual(len(uops), 2)
    self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
  def test_track_multiple_calls_one_ctx(self):
    a = UOp.variable("a", 0, 10)
    @track_rewrites(named=True)
    def test(a, b):
      a = graph_rewrite(a, symbolic)
      b = graph_rewrite(b, symbolic)
    test(a*1, a*5)
    ret = get_metadata(keys, contexts)
    key, val = ret[0]
    self.assertEqual(len(ret), 1) # one context
    self.assertEqual(len(val), 2) # two graph_rewrite calls in context
    self.assertEqual(key, "test_1")
    self.assertEqual(val[0]["match_count"], 1) # one rewrite for a*1
    self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5

  def test_track_rewrites(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites(named=True)
    def do_rewrite(x:UOp): return graph_rewrite(x, simple)
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    do_rewrite(ld*1)
    do_rewrite(ld*2)
    def do_rewrite(x:UOp): return graph_rewrite(x, symbolic)
    a = UOp.variable("a", 0, 10)
    b = UOp.variable("b", 0, 4)
    do_rewrite(a*1)
    do_rewrite(a*b)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 2)
    key, _, m = ret[0][0]
    key, m = ret[0]
    self.assertEqual(key, "do_rewrite_1")
    self.assertEqual(len(m.upats), 1)
    key, _, m = ret[1][0]
    self.assertEqual(m[0]["match_count"], 1)
    key, m = ret[1]
    self.assertEqual(key, "do_rewrite_2")
    self.assertEqual(len(m.upats), 0)
    self.assertEqual(m[0]["match_count"], 0)

  def test_track_rewrites_with_exception(self):
    simple = PatternMatcher([(UPat.var("x")*1, lambda x:x)])
    @track_rewrites()
    def do_rewrite(x:UOp):
      x = graph_rewrite(x, simple) # NOTE: viz tracks this
      x = graph_rewrite(x, symbolic) # NOTE: viz tracks this
      raise Exception("test")
    ld = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp.const(dtypes.int, 0)))
    with self.assertRaises(Exception): do_rewrite(ld*1)
    a = UOp.variable("a", 0, 10)
    with self.assertRaises(Exception): do_rewrite(a*1)
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)

  def test_fold_const(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    graph = uop_to_json(a)
    assert not any(v[0].startswith("CONST") for v in graph.values())
    assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
  # NOTE: CONST UOps do not get nodes in the graph
  def test_dont_create_const_nodes(self):
    a = UOp.variable("a", 0, 10)
    b = UOp.variable("b", 0, 4)
    self.assertEqual(len(uop_to_json(a*1)), 2)
    self.assertEqual(len(uop_to_json(a*b)), 3)

  @unittest.skip("TODO: bring this back with better testing")
  def test_bottom_up_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
    self.assertEqual(len(ret), 2)
    self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
    self.assertIs(ret[1], a.sqrt().sqrt()) # second one

  def test_top_down_rewrite(self):
    a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
    n1 = a.sin()
    uop = n1.sin()
    pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
    # if it wasn't bottom_up, it's rewritten once
    ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
    a = UOp.variable("a", 0, 10)
    b = UOp.variable("b", 0, 10)
    c = UOp.variable("c", 0, 10)
    UOp.substitute(a+b, {a+b:c})
    ret = get_metadata(keys, contexts)
    self.assertEqual(len(ret), 1)
    self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
    _, _, vals = ret[0][0]
    self.assertEqual(len(vals.upats), 1)

  # NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
  def test_rewrite_without_context(self):
@@ -211,6 +197,5 @@ class TextVizProfiler(unittest.TestCase):
    self.assertEqual(j['traceEvents'][7]['dur'], 4)
    self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])


if __name__ == "__main__":
  unittest.main()
@@ -3,14 +3,14 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
  cpu_time_execution
  cpu_time_execution, colored, Context
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer

# **************** Device ****************

ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"]
class _Device:
  def __init__(self) -> None:
    self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@@ -25,7 +25,7 @@ class _Device:
    cpn = multiprocessing.current_process().name
    assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
    x = ix.split(":")[0].upper()
    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
      if (cname.lower() == x.lower() + "device")][0](ix)
    if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
    self._opened_devices.add(ix)
@@ -33,7 +33,7 @@ class _Device:
  @property
  def default(self) -> Compiled: return self[self.DEFAULT]
  def get_available_devices(self) -> Iterator[str]:
    for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
    for device in ALL_DEVICES:
      with contextlib.suppress(Exception): yield self[device].device
  @functools.cached_property
  def DEFAULT(self) -> str:
@@ -220,9 +220,10 @@ MAP_JIT = 0x0800

# CPUProgram is a jit/shellcode program that can be just mmapped and jumped to
class CPUProgram:
  helper_handle = ctypes.CDLL(ctypes.util.find_library('System') if OSX else 'libgcc_s.so.1')

  helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32' if sys.platform == "win32" else 'gcc_s'))
  def __init__(self, name:str, lib:bytes):
    assert sys.platform != "win32", "clang is not supported for windows yet"
    from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
    # On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
    # MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
    self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
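
(A rough sketch of the pattern this constructor uses, assuming an arm64 host; the pthread_jit_write_protect_np and instruction-cache-flush calls the real code makes through helper_handle are elided here.)

from mmap import mmap, MAP_ANON, MAP_PRIVATE, PROT_READ, PROT_WRITE, PROT_EXEC
import ctypes, sys
MAP_JIT = 0x0800                      # macOS-only mmap flag, harmless to omit elsewhere
code = b"\xc0\x03\x5f\xd6"            # arm64 `ret`; purely illustrative shellcode
mem = mmap(-1, len(code), MAP_ANON | MAP_PRIVATE | (MAP_JIT if sys.platform == "darwin" else 0),
           PROT_READ | PROT_WRITE | PROT_EXEC)
mem.write(code)                       # page is writable while the shellcode is copied in
fxn = ctypes.CFUNCTYPE(None)(ctypes.addressof(ctypes.c_char.from_buffer(mem)))
fxn()                                 # then jumped to, in the spirit of CPUProgram
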
@@ -314,3 +315,18 @@ if PROFILE:

  from tinygrad.ops import launch_viz
  launch_viz("PROFILE", fn)

if __name__ == "__main__":
  for device in ALL_DEVICES:
    try:
      _ = Device[device].device
      try:
        from tinygrad import Tensor
        with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
        if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
        result = colored("PASS", "green")
      except Exception as e:
        result = f"{colored('FAIL', 'yellow')} {e}"
    except Exception as e:
      result = f"{colored('FAIL', 'red')} {e}"
    print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}")

@@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
@@ -194,7 +194,8 @@ def _prepare_jit_inputs(args, kwargs):
  input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
  names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
  if tensors: Tensor.realize(*tensors)
  lbs: list[UOp] = flatten([t.lazydata.lbs for t in tensors])
  # TODO: should we be unpacking multi here?
  lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
  input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
  assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
  st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]

@@ -47,4 +47,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g. transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([si.bufs for si in schedule],
    noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]

@@ -1,10 +1,10 @@
import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
from tinygrad.dtype import DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -37,7 +37,7 @@ tensor_uop_spec = PatternMatcher([

  # DETACH and CONTIGUOUS change how we interpret the source UOp
  # CONTIGUOUS ensures the source UOp realizes
  (UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
  (UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),

  # COPY
  # NOTE: the arg here specifies clone=True, which prevents folding same device copy
@@ -60,7 +60,6 @@ class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...]
  assign_preloads: tuple[UOp, ...]
  @property
  def outputs(self) -> tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule."""
@@ -82,9 +81,8 @@ class ScheduleContext:
  realizes: dict[UOp, UOp] = field(default_factory=dict)           # this holds all the BUFFER uops we mutate in this schedule
  allbufs: dict[UOp, UOp] = field(default_factory=dict)            # this maps BUFFER uops to the actual op
  ops_metadata: dict[UOp, Metadata] = field(default_factory=dict)  # this maps fused ops to Metadata
  contiguous: dict[UOp, UOp] = field(default_factory=dict)         # this maps roots to places they are made contiguous
  children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
  becomes_map: dict[UOp, UOp] = field(default_factory=dict)
  preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))

# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
@@ -104,6 +102,7 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
    if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
    dtype = buf.dtype.base
  # ASSIGN already has a target buffer, otherwise we create a new one
  assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
  buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
  op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
  # track the underlying tensor uop for this buffer
@@ -152,9 +151,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))

# push VIEW to stores
# push VIEW to children
view_right = merge_views+PatternMatcher([
  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view()
  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
@@ -200,6 +199,8 @@ to_si = PatternMatcher([
  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
  # once images are loaded they become the base dtype
  (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
  # CONST(VIEW) becomes VALID too, TODO: doesn't have to
  (UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)),
])

# LOAD(BUFFER) -> the STORE value if we're doing the STORE in the same kernel
@@ -211,25 +212,27 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
  # deal with ASSIGN
  assign_preloads: list[UOp] = []
  if len(ctx.assigns) != 0:
    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
    for x in list(sink.toposort)[::-1]:
      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
      # PRELOAD tells the toposort this kernel should run before ASSIGN
      if x.op is Ops.PRELOAD:
        assign_preloads.append(x.buf_uop)
        assign_preloads[x.buf_uop] = None
        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
        if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
          # if it has a single view and it's equal when you shrink a contig, it's fine
          if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
            raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                               +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
          # otherwise, it's not fine
          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
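
(A minimal reproduction of the rule this error enforces, assuming the public Tensor API:)

from tinygrad import Tensor
a = Tensor.rand(4, 4).realize()
# a += a.T would be rejected at schedule time: the self operand aliases a
# non-contiguous (permuted) view of the buffer being assigned
a += a.T.contiguous()   # materializing the transpose first breaks the alias
a.realize()
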
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
                      tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))
  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))

PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -329,7 +332,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
  # maybe fuse arange with its children
  for rbuf in reduce_of_const:
    group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
    if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue
    if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue
    kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
    if len(kernel_children) == 0: continue
    for tr in group: del ctx.realizes[tr]
@@ -354,20 +357,20 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
    case _: return None
  return reduce.const_like(ret)

def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  new_src = list(alu.src)
  for i,s in enumerate(alu.src):
    if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))

sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH is a NOOP here
  (UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
@@ -389,10 +392,10 @@ sym = symbolic_simple+PatternMatcher([
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
  # remove CONST/BIND/BUFFER/VIEW from SINK
  # remove CONST/BIND/BUFFER from SINK
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
    if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])

# ** this decides which ops get realized
@@ -419,7 +422,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
  return x.view(unwrap(view.st))

def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
  if not b.device.startswith("DISK"): return None
  if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
  buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
  return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))

@@ -440,12 +443,12 @@ do_realize = PatternMatcher([
  (UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])

# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp

def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp):
  assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}"
  ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg
  return ret.valid(unwrap(bind.st))
  assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}"
  ctx.var_vals[var.replace(src=())] = val.const_arg
  return var

def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
  # NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
@@ -458,8 +461,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
  return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))

break_sched = PatternMatcher([
  # CONST is always fused and generated
  (UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
  (UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
  # VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
@@ -473,38 +474,43 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
  if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
  for x in op.base.src:
    if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
  buf_uop.buffer.ref(1)
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])

# **** movement ops

remove_movement_ops = PatternMatcher([
remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))),
  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
  (UPat(Ops.VIEW, name="view"),
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
  # merge one src views.
  (UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
  # merge unmasked const views
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
   lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
])
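
(For orientation, a self-contained sketch of how a PatternMatcher table like the ones above is applied; only names that appear elsewhere in this diff are used.)

from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite
x = UOp.variable("x", 0, 10)
pm = PatternMatcher([(UPat.var("v") * 1, lambda v: v)])  # fold v*1 -> v, same shape as the rules above
out = graph_rewrite(x * 1, pm)
assert out is x  # UOps are interned, so the fold returns the original node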

@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
  if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
  tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
  # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
  becomes_map: dict[UOp, UOp] = {}
  for k,v in tensor_map.items():
    # NOOP
    if k.base is v.base: continue
    # NOTE: only the base tensors get a BUFFER UOp
    if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
    # otherwise if it simplified to a CONST the UOp just becomes that CONST
    elif v.op is Ops.CONST: becomes_map[k] = v

  # we group the rest of UOps into ScheduleItems
  rev_tensor_map: dict[UOp, list[UOp]] = {}
  for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
  # add BUFFER uops
  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
  sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
  # add realizes
  sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
  # group realizes into kernels
  store_groups = group_realizes(ctx)
  graph_rewrite(sink, break_sched, ctx)
  # preschedule realize groups
  # create schedule items + map buffers to realized tensors
  prescheduled: list[ScheduleItem] = []
  for store_uops in store_groups:
    small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
@@ -512,16 +518,9 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
    prescheduled.append(schedule_uop(small_sink, ctx))
    # can only schedule once
    for buf_uop in store_uops:
      for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))

  # tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
  for k,v in tensor_map.items():
    # NOOP
    if k.base is v.base: continue
    # NOTE: only the base tensors get a BUFFER UOp
    if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
    # otherwise if it simplified to a CONST the UOp just becomes that CONST
    elif v.op is Ops.CONST: ctx.becomes_map[k] = v
      for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
      # increment refcount for this buffer
      buf_uop.buffer.ref(1)

  # add kernel children
  schedule_targets = {out:si for si in prescheduled for out in si.outputs}
@@ -529,7 +528,7 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
  in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
  for si in prescheduled:
    # realize outputs before a parent is assigned to
    parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
    parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
    for assign in parents_assigns:
      graph[si].append(assign)
      in_degree[assign] += 1
@@ -550,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
  # confirm everything was scheduled correctly
  if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
  if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
  return schedule, ctx.var_vals, ctx.becomes_map
  return schedule, ctx.var_vals, becomes_map

@@ -1,203 +0,0 @@
"""This is where the forwards and backwards passes live."""
import math
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import Ops, resolve, sint, UOp
from tinygrad.tensor import Function

class Contiguous(Function):
  def forward(self, x:UOp) -> UOp: return x.contiguous()
  def backward(self, grad_output:UOp) -> UOp: return grad_output

class ContiguousBackward(Function):
  def forward(self, x:UOp) -> UOp: return x
  def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous()

class Cast(Function):
  def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp:
    self.input_dtype, self.bitcast = x.dtype, bitcast
    return x.bitcast(dtype) if self.bitcast else x.cast(dtype)

  def backward(self, grad_output:UOp) -> UOp:
    if self.bitcast: raise RuntimeError("bitcast cannot backward")
    return grad_output.cast(self.input_dtype)

# ************* unary ops *************

class Reciprocal(Function):
  def forward(self, x:UOp) -> UOp:
    self.ret = x.reciprocal()
    return self.ret

  def backward(self, grad_output:UOp) -> UOp: return -grad_output * self.ret * self.ret

class Sin(Function):
  def forward(self, x:UOp) -> UOp:
    self.x = x
    return x.sin()

  def backward(self, grad_output:UOp) -> UOp: return (math.pi/2 - self.x).sin() * grad_output
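
(The identity behind this backward, stated once: d/dx sin(x) = cos(x) = sin(pi/2 - x), so a backend only needs a sin primitive.)

import math
assert abs(math.cos(0.7) - math.sin(math.pi/2 - 0.7)) < 1e-12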

class Relu(Function):
  def forward(self, x:UOp) -> UOp:
    self.ret = (x>0).where(x, 0)
    return self.ret

  def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output

class Log(Function):
  def forward(self, x:UOp) -> UOp:
    self.x = x
    return x.log2() * math.log(2)

  def backward(self, grad_output:UOp) -> UOp: return grad_output / self.x

class Exp(Function):
  def forward(self, x:UOp) -> UOp:
    self.ret = (x * (1/math.log(2))).exp2()
    return self.ret

  def backward(self, grad_output:UOp) -> UOp: return self.ret * grad_output

class Sqrt(Function):
  def forward(self, x:UOp) -> UOp:
    self.ret = x.sqrt()
    return self.ret

  def backward(self, grad_output:UOp) -> UOp: return grad_output / (self.ret*2)

class Sign(Function):
  # NOTE: the x*0 is to match torch behavior without function.py
  def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
  # backward always return 0 to match torch
  def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0)

# ************* binary ops *************

class Less(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x<y
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None

class Neq(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None

class Xor(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x^y

class BitwiseAnd(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x&y

class BitwiseOr(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x|y

class Threefry(Function):
  def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed)

class Add(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x+y

  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    return grad_output if self.needs_input_grad[0] else None, \
           grad_output if self.needs_input_grad[1] else None

class Mul(Function):
  def forward(self, x:UOp, y:UOp) -> UOp:
    self.x, self.y = x, y
    return x * y

  def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
    return (self.y * grad_output) if self.needs_input_grad[0] else None, \
           (self.x * grad_output) if self.needs_input_grad[1] else None

class IDiv(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x // y

class Mod(Function):
  def forward(self, x:UOp, y:UOp) -> UOp: return x % y

# ************* ternary ops *************

class Where(Function):
  def forward(self, x:UOp, y:UOp, z:UOp) -> UOp:
    self.x = x
    return self.x.where(y, z)

  def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]:
    return None, \
           self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
           self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None

# ************* reduce ops *************

class Sum(Function):
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.input_shape = x.shape
    return x.r(Ops.ADD, axis)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.expand(self.input_shape)

class Prod(Function):
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.x, self.ret = x, x.r(Ops.MUL, axis)
    return self.ret

  def backward(self, grad_output:UOp) -> UOp:
    return (grad_output * self.ret).expand(self.x.shape) / self.x

class Max(Function):
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
    return self.ret

  def backward(self, grad_output:UOp) -> UOp:
    # 1s in locations where the max was chosen (can be two locations)
    max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
    div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
    return (max_is_1s/div) * grad_output.expand(self.x.shape)
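
(A worked instance of the tie-splitting above, sketched with numpy: the mask marks every argmax location and the gradient is divided evenly among them.)

import numpy as np
x = np.array([3.0, 1.0, 3.0])             # two tied maxima
max_is_1s = (x == x.max()).astype(float)  # [1., 0., 1.]
div = max_is_1s.sum()                     # 2.0, the number of ties
assert ((max_is_1s / div) * 1.0).tolist() == [0.5, 0.0, 0.5]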

# ************* movement ops *************

# NOTE: this is sum in reverse
class Expand(Function):
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
    return x.expand(shape)

  def backward(self, grad_output:UOp) -> UOp:
    return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
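
(Since expand copies values, its adjoint sums the copies back up; the same reduction in a small numpy sketch:)

import numpy as np
grad_out = np.ones((3, 4))                    # forward was (3,1) -> (3,4), so expanded_axis == (1,)
grad_x = grad_out.sum(axis=1, keepdims=True)  # backward: sum over the expanded axis
assert grad_x.shape == (3, 1) and (grad_x == 4).all()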

class Reshape(Function):
  def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
    self.input_shape = x.shape
    return x.reshape(shape)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.reshape(self.input_shape)

class Permute(Function):
  def forward(self, x:UOp, order:tuple[int, ...]) -> UOp:
    self.input_order = order
    return x.permute(order)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.permute(argsort(self.input_order))

class Pad(Function):
  def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp:
    self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
    return x.pad(arg)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.shrink(self.narg)

class Shrink(Function):
  def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp:
    self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
    return x.shrink(arg)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.pad(self.narg)

class Flip(Function):
  def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
    self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
    return x.stride(self.arg)

  def backward(self, grad_output:UOp) -> UOp: return grad_output.stride(self.arg)
@@ -1,7 +1,7 @@
from typing import cast, Iterator
import math, functools
import math, functools, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort

def reduce_gradient(ctx:UOp, ret:UOp):
@@ -28,6 +28,7 @@ pm_gradient = PatternMatcher([
  (UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
  (UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
  (UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
  (UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
  (UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
  (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
  (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
@@ -36,7 +37,7 @@ pm_gradient = PatternMatcher([
  # TODO: this cast can be removed by putting the casts around the EXPAND
  (UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
   (ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),

  (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
  # there's no gradient for bitcast
  (UPat(Ops.BITCAST), lambda ctx: (None,)),
])
@@ -65,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
    if v is None: continue
    if k in grads: grads[k] = grads[k] + v
    else: grads[k] = v
    if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
  return grads

@@ -109,6 +109,7 @@ USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT",
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL = ContextVar("CACHELEVEL", 2)

@dataclass(frozen=True)
class Metadata:
@@ -165,7 +166,6 @@ class Profiling(contextlib.ContextDecorator):

cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)

VERSION = 17
_db_connection = None
@@ -186,7 +186,7 @@ def diskcache_clear():
  cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))

def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
  if CACHELEVEL == 0: return None
  if CACHELEVEL < 1: return None
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
@@ -199,7 +199,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:

_db_tables = set()
def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
  if CACHELEVEL == 0: return val
  if CACHELEVEL < 1: return val
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
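
(With CACHELEVEL now a ContextVar, the `< 1` checks compare against its current value and the disk cache can be switched off per block; a sketch of the intended use, with a hypothetical table name:)

from tinygrad.helpers import Context, diskcache_put, diskcache_get
with Context(CACHELEVEL=0):
  diskcache_put("demo_table", "k", "v")            # no-op while the cache is off
  assert diskcache_get("demo_table", "k") is None  # and lookups miss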

@@ -1,8 +1,7 @@
from __future__ import annotations
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType
from tinygrad.ops import Ops, MathTrait, UOp, sint
from tinygrad.ops import Ops, UOp, sint

def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
  assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
@@ -40,133 +39,127 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->
  if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
  return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
|
||||
|
||||
class MultiLazyBuffer(MathTrait):
|
||||
def __init__(self, lbs:list[UOp], axis:int|None, real:list[bool]|None=None):
|
||||
assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
|
||||
assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
|
||||
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
|
||||
# ***** multi functions *****
|
||||
|
||||
@property
|
||||
def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
|
||||
from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
|
||||
|
||||
@property
|
||||
def size(self): return sum(x.size for x in self.real_lbs)
|
||||
def alu_multi(root:UOp):
|
||||
msrcs = root.src
|
||||
assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
|
||||
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
|
||||
|
||||
@property
|
||||
def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
|
||||
srcs:list[list[UOp]] = []
|
||||
not_all_real = not all(all(mlb.real) for mlb in msrcs)
|
||||
new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
|
||||
for mlb in msrcs:
|
||||
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
|
||||
else:
|
||||
assert axis is not None and bounds is not None
|
||||
if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
|
||||
new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
|
||||
new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
|
||||
return UOp.multi(*new_lbs, axis=axis, real=new_real)
|
||||
|
||||
@property
|
||||
def bounds(self):
|
||||
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
|
||||
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
|
||||
def reduce_multi(root:UOp, multi:UOp):
|
||||
op, axis = root.arg
|
||||
if multi.axis is not None and multi.axis in axis:
|
||||
# all-reduce on sharded axes
|
||||
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
|
||||
# if all partitions are real, do all_reduce
|
||||
if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
|
||||
# only one partition is real, keep it
|
||||
return UOp.multi(*reduced_parts, axis=None, real=multi.real)
|
||||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
|
||||
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
|
||||
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
|
||||
|
||||
def copy_to_device(self, device:str) -> UOp:
|
||||
# if we already have a copy on the device, return that
|
||||
if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
|
||||
# copy lbs to device, pad to final shape, and sum
|
||||
llbs:list[UOp] = []
|
||||
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
|
||||
if not real: continue
|
||||
pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
|
||||
llbs.append(lb.copy_to_device(device).pad(pad_arg))
|
||||
return functools.reduce(operator.add, llbs)
|
||||
def reshape_multi(root:UOp, multi:UOp):
|
||||
arg = root.arg
|
||||
if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real)
|
||||
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
|
||||
arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
|
||||
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
|
||||
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
|
||||
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1
|
||||
assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
|
||||
f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
|
||||
lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
|
||||
return UOp.multi(*lbs, axis=new_axis, real=multi.real)
|
||||
|
||||
# passthroughs
|
||||
@property
|
||||
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
|
||||
def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real)
|
||||
def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real)
|
||||
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
|
||||
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
|
||||
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
|
||||
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
|
||||
def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real)
|
||||
@property
|
||||
def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort}
|
||||
def expand_multi(root:UOp, multi:UOp):
|
||||
# NOTE: this assert isn't needed, sharded axis can have dim 1
|
||||
assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
|
||||
return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
# elementwise is simple
|
||||
def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
|
||||
msrcs = (self,)+in_srcs
|
||||
assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
|
||||
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
|
||||
def pad_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
|
||||
# pad on shard axis -> fill others with zeros and set real to all True
|
||||
if multi.axis is not None and root.arg[multi.axis] != (0,0):
|
||||
# pad back to whole axis, remove real mask
|
||||
assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
|
||||
dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
|
||||
assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
|
||||
return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
|
||||
return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
# NOTE: they all have to share an axis, we always choose [-1]
|
||||
axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
|
||||
srcs:list[list[UOp]] = []
|
||||
not_all_real = not all(all(mlb.real) for mlb in msrcs)
|
||||
new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
|
||||
assert any(new_real), "output contains no real lb"
|
||||
for mlb in msrcs:
|
||||
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
|
||||
else:
|
||||
assert axis is not None and bounds is not None
|
||||
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
|
||||
# NOTE: const dtype should match real
|
||||
new_dtype = next(iter(new_real_lbs.values())).dtype
|
||||
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
|
||||
def permute_multi(root:UOp, multi:UOp):
|
||||
# all permutes supported!
|
||||
return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real)
|
||||
|
||||
def r(self, op:Ops, axis:tuple[int, ...]) -> MultiLazyBuffer:
|
||||
if self.axis is not None and self.axis in axis:
|
||||
# all-reduce on sharded axes
|
||||
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
|
||||
# if all partitions are real, do all_reduce
|
||||
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
|
||||
# only one partition is real, keep it
|
||||
return MultiLazyBuffer(reduced_parts, None, self.real)
|
||||
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
|
||||
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
|
||||
def shrink_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
|
||||
f"shrinking not supported for {root.arg=}"
|
||||
if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
|
||||
assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
|
||||
"cannot shrink sharded and non-sharded axis at the same time"
|
||||
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
|
||||
idx = multi.bounds.index(root.arg[multi.axis])
|
||||
# zero out other lbs to not create lb reference
|
||||
return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
|
||||
axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
|
||||
return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
|
||||
axis=multi.axis, real=multi.real)
|
||||
|
||||
# *** movement ops ***
|
||||
def stride_multi(root:UOp, multi:UOp):
|
||||
assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis"
|
||||
return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
def _shape_to_single_shard(self, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
|
||||
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
|
||||
def copy_multi(multi:UOp, device:UOp):
|
||||
# if we already have a copy on the device, return that
|
||||
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
|
||||
# copy lbs to device, pad to final shape, and sum
|
||||
llbs:list[UOp] = []
|
||||
for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
|
||||
if not real: continue
|
||||
pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
|
||||
llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
|
||||
return functools.reduce(operator.add, llbs)
|
||||
|
||||
def reshape(self, arg:tuple[sint, ...]):
|
||||
if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
|
||||
assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
|
||||
arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
|
||||
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
|
||||
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
|
||||
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
|
||||
assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
|
||||
lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
|
||||
return MultiLazyBuffer(lbs, new_axis, self.real)
|
||||
def assign_multi(dest:UOp, src:UOp):
|
||||
assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
|
||||
return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
|
||||
|
||||
def pad(self, arg:tuple[tuple[sint, sint], ...]):
|
||||
assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
|
||||
# pad on shard axis -> fill others with zeros and set real to all True
|
||||
if self.axis is not None and arg[self.axis] != (0,0):
|
||||
# pad back to whole axis, remove real mask
|
||||
assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
|
||||
dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
|
||||
assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
|
||||
return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
|
||||
return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
|
||||
def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
|
||||
|
||||
def expand(self, arg:tuple[sint, ...]):
|
||||
# NOTE: this assert isn't needed, sharded axis can have dim 1
|
||||
assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
|
||||
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
|
||||
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
|
||||
(UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
|
||||
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
|
||||
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
|
||||
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
|
||||
(UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi),
|
||||
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
|
||||
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
|
||||
])
|
||||
|
||||
def permute(self, arg:tuple[int, ...]):
|
||||
# all permutes supported!
|
||||
return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
|
||||
|
||||
def shrink(self, arg:tuple[tuple[sint, sint], ...]):
|
||||
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
|
||||
if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
|
||||
assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
|
||||
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
|
||||
idx = self.bounds.index(arg[self.axis])
|
||||
# zero out other lbs to not create lb reference
|
||||
return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
|
||||
return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
|
||||
self.axis, self.real)
|
||||
|
||||
def stride(self, arg:tuple[int, ...]):
|
||||
assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
|
||||
return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
|
||||
@track_rewrites(named=True)
|
||||
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
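Each `*_multi` function above follows one template: map the op over `multi.src` and re-wrap with `UOp.multi`, with the only real decisions happening on the sharded axis. Those decisions are driven by the `bounds` property; here is the same computation on plain integers:

```python
import itertools

# Shard bounds exactly as the MULTI bounds property computes them: a running sum
# over per-shard sizes, then consecutive pairs taken as (start, end) intervals.
def shard_bounds(shard_sizes: list[int]) -> tuple:
    return tuple(itertools.pairwise(itertools.accumulate(shard_sizes, initial=0)))

# a dim of 256 split over 4 devices
print(shard_bounds([64, 64, 64, 64]))  # ((0, 64), (64, 128), (128, 192), (192, 256))
```

A shrink whose window equals one of these bounds can be answered by a single device without moving data, which is exactly the single-partition case `shrink_multi` encodes through the `real` mask.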
@@ -77,9 +77,6 @@ class LARS(Optimizer):

def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
for i, (t, g) in enumerate(zip(self.params, grads)):
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = g.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
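`r1` and `r2` are the parameter and gradient L2 norms that feed the LARS trust ratio. A scalar sketch of the usual formulation (the tcoef and wd values below are illustrative defaults, not taken from this diff):

```python
import math

# LARS layerwise trust ratio: scale the step by tcoef * ||w|| / (||g|| + wd * ||w||),
# falling back to 1.0 when either norm is zero.
def lars_ratio(w: list, g: list, tcoef: float = 0.001, wd: float = 5e-5) -> float:
    r1 = math.sqrt(sum(x * x for x in w))  # parameter norm, r1 in the diff
    r2 = math.sqrt(sum(x * x for x in g))  # gradient norm, r2 in the diff
    return tcoef * r1 / (r2 + wd * r1) if r1 > 0 and r2 > 0 else 1.0

print(lars_ratio([0.5] * 100, [0.01] * 100))  # ~0.05: big weights, small grads -> larger step
```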
@@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape
from tinygrad.multi import MultiLazyBuffer

class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
@@ -152,9 +151,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
continue
if v.shape != state_dict[k].shape:
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
if isinstance(v.device, tuple):
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
else: v.replace(state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]
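The sharding test moves from `isinstance(lazydata, MultiLazyBuffer)` to "is the device a tuple", so callers don't change. A hedged usage sketch (device names, model, and checkpoint file are all hypothetical):

```python
from tinygrad import Tensor
from tinygrad.nn.state import load_state_dict, safe_load

GPUS = ("NV:0", "NV:1")  # illustrative device names
class TinyNet:
    def __init__(self): self.w = Tensor.ones(128, 128).shard(GPUS, axis=0)

model = TinyNet()
# an unsharded checkpoint still loads: each weight is re-sharded to match the
# model tensor's device tuple and axis ("checkpoint.safetensors" is illustrative)
load_state_dict(model, safe_load("checkpoint.safetensors"))
```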
107
tinygrad/ops.py
@@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto(); CONTIGUOUS = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702

# TODO: empty continues to exist because of tensor
EMPTY = auto()
@@ -150,6 +150,7 @@ class Ops(FastEnum):

# device
DEVICE = auto()
MULTI = auto()

class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -281,6 +282,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):

@functools.cached_property
def st(self) -> ShapeTracker|None:
from tinygrad.shape.shapetracker import ShapeTracker
if self.op is Ops.MULTI:
return ShapeTracker.from_shape(
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
# these ops define a ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
@@ -294,7 +299,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# only reduce ops are allowed to change shape, everything else derives shape from sources
elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
else: shape = src_sts[0].shape
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(shape)

@functools.cached_property
@@ -350,7 +354,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
if self._device is None: return UOp.const(self.dtype, b)
if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -389,7 +395,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
new_shape = unwrap(self.st).reduce(axis)

# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
# TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis)

@@ -409,6 +416,46 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)

# *** from MultiLazyBuffer ***

def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
parents = (self,)+more
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))

@property
def bounds(self):
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))

@property
def axis(self):
assert self.op is Ops.MULTI
return self.arg[0]

@property
def real(self):
assert self.op is Ops.MULTI
return self.arg[1]

@property
def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]

def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
if axis is None: lbs = [self] * len(devices)
else:
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
sz = self.shape[axis] // len(devices)
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
lbs, off = [], 0
for sz in sizes:
lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
off += sz
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
# NOTE: this contiguous is making it impossible for the scheduler to do late const folding
return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)

# *** from LazyBuffer ***

@@ -426,7 +473,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
@@ -440,8 +487,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
@property
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)

# *** uop movement ops ***
@@ -470,10 +515,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@staticmethod
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
@property
def device(self) -> str: return unwrap(self._device)
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@functools.cached_property
def _device(self) -> Optional[str]:
def _device(self) -> Optional[str|tuple[str, ...]]:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
@@ -489,6 +535,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
if (cret:=buffers.get(self)) is not None: return cret
from tinygrad.device import Buffer
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
return ret
@property
@@ -496,7 +543,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized
return self.buffer if self.op is Ops.BUFFER else None
@property
def is_realized(self) -> bool: return self.base.realized is not None
def is_realized(self) -> bool:
return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None

# *** uop Variable stuff ***

@@ -639,7 +687,7 @@ def print_uops(uops:list[UOp]):
def get_location() -> tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py",
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
"lowerer.py", "cstyle.py", "linearize.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@@ -775,9 +823,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:dict[UPat, list[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedGraphRewrite:
loc: tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink input to graph_rewrite
matches: list[tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # before+after of all the matches
loc: tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink input to graph_rewrite
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, int] = {}
@@ -808,10 +856,9 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p, et))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0 and len(tracked_ctxs[-1]) != 0: tracked_ctxs[-1][-1].matches.append((uop, None, None, 0))
return None

if TRACK_MATCH_STATS:
@@ -849,7 +896,7 @@ class RewriteContext:
self.replace: dict[UOp, UOp] = {}
def top_down_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_src = tuple(map(self.top_down_rewrite, n.src))
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
return ret
@@ -857,7 +904,7 @@ class RewriteContext:
if (rn := self.replace.get(n)) is not None: return rn
new_n: UOp|None = n
while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
new_src = tuple(map(self.bottom_up_rewrite, last_n.src))
new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
return ret

@@ -1269,18 +1316,24 @@ Variable = UOp

ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]

# *** uop swizzling ***
# *** UOp merge views and swizzling ***

merge_views = PatternMatcher([
(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)),
(UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None),
# VIEW(VIEW) merges to a single VIEW
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
# merge unmasked const views
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
])

# push VIEW to loads
# push VIEW to parents
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# VIEW(CONST) becomes VALID
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)),
# VIEW before elementwise/buffer ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
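`RewriteContext.top_down_rewrite` above is a memoized fixed-point rewrite: children first, then the matcher on the rebuilt node, with every result cached so shared subgraphs are rewritten once. A toy version on hashable tuple-graphs (not tinygrad API, just the same control flow):

```python
# Memoized top-down rewrite: rewrite children, try the rule on the rebuilt node,
# recurse on the rule's output, and cache every node's final form.
def top_down(node, rewrite, cache=None):
    if cache is None: cache = {}
    if node in cache: return cache[node]
    op, *src = node
    new_node = (op, *[top_down(s, rewrite, cache) if isinstance(s, tuple) else s for s in src])
    rewritten = rewrite(new_node)
    cache[node] = ret = new_node if rewritten is None else top_down(rewritten, rewrite, cache)
    return ret

# fold add-with-zero on a toy tuple graph: ("add", ("add", ("x",), 0), 0) -> ("x",)
rule = lambda n: n[1] if n[0] == "add" and n[2] == 0 else None
print(top_down(("add", ("add", ("x",), 0), 0), rule))  # ('x',)
```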
@@ -198,12 +198,12 @@ class AMDCopyQueue(HWQueue):
return self

def bind(self, dev:AMDDevice):
if not dev.driverless: return
if not getenv("AMD_SDMA_BIND", 0) or not dev.driverless: return

self.binded_device = dev
self.hw_page = dev.allocator.alloc((qsz:=round_up(len(self._q), 8)) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True))
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self._q): hw_view[i] = value
for i in range(qsz): hw_view[i] = self._q[i] if i < len(self._q) else 0

self.indirect_cmd = [amd_gpu.SDMA_OP_INDIRECT | amd_gpu.SDMA_PKT_INDIRECT_HEADER_VMID(0), *data64_le(self.hw_page.va_addr), qsz, *data64_le(0)]
self._q, self.cmd_sizes = hw_view, [len(self.indirect_cmd)]
@@ -464,7 +464,7 @@ class PCIIface:

iommu_group = HWInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1]
except OSError:
if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (not inserted or no-iommu mode is not supported).")
if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (run `sudo modprobe vfio-pci`).")
PCIIface.vfio = False

# Init vfio for the device
@@ -487,21 +487,18 @@ class PCIIface:
self.pagemap = HWInterface("/proc/self/pagemap", os.O_RDONLY)
self.bar_fds = {bar: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]}

self.adev = AMDev(self.pcidev, self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
self.adev = AMDev(self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
self.doorbell_cpu_addr = mv_address(dbell)

libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_read_u16(self.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_write_u16(self.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)

# TODO: this is for 7900xtx, the only tested card.
self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32,
'array_count': 12, 'simd_arrays_per_engine': 2, 'lds_size_in_kb': 64}

def _map_pci_range(self, bar, off=0, addr=0, size=None):
if PCIIface.vfio:
vfio.VFIO_DEVICE_GET_REGION_INFO(self.vfio_dev, reg:=vfio.struct_vfio_region_info(argsz=ctypes.sizeof(vfio.struct_vfio_region_info), index=bar))
fd, sz, off = self.vfio_dev, size or reg.size, reg.offset + off
else: fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size
fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size
return to_mv(fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz)

def alloc(self, size:int, host=False, uncached=False, cpu_access=False):
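The `bind` change above writes `qsz` entries instead of `len(self._q)`: the queue is rounded up to a multiple of 8 dwords and the tail is zero-filled, so the SDMA engine never reads stale page contents. The same packing on plain Python:

```python
# Pad a command queue to an 8-dword boundary, zero-filling the tail,
# matching "hw_view[i] = self._q[i] if i < len(self._q) else 0".
def pack_queue(q: list, align: int = 8) -> list:
    qsz = (len(q) + align - 1) // align * align  # round_up(len(q), 8)
    return [q[i] if i < len(q) else 0 for i in range(qsz)]

print(pack_queue([0xDE, 0xAD, 0xBE]))  # 3 commands padded to 8 dwords
```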
@@ -102,19 +102,16 @@ class AMFirmware:
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702

class AMPageTableEntry:
def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv

def set_table(self, entry_id, pte:AMPageTableEntry, valid=True):
self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"

def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True):
f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \
| am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if self.lv != am.AMDGPU_VM_PTB else 0) \
f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \
| am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \
| ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \
| (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0)
self.view[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f

def get_entry(self, entry_id): return self.view[entry_id]
self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f

class AMPageTableTraverseContext:
def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
@@ -126,22 +123,23 @@ class AMPageTableTraverseContext:

def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID:
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
else:
if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
assert self.create_pts, "Not allowed to create new page table"
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))
pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
entry = pt.entries[pte_idx]

assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)

self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]

def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
self.adev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_page(parent_pte_idx, 0x0, valid=False)
parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
return True
return False

@@ -156,7 +154,7 @@ class AMPageTableTraverseContext:
if self.create_pts:
while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
else:
while pt.lv!=am.AMDGPU_VM_PTB and (pt.get_entry(pte_idx)&am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()

entries = min(size // pte_covers, 512 - pte_idx)
assert entries > 0, "Invalid entries"
@@ -173,7 +171,7 @@ class AMMemoryManager:
self.adev, self.vram_size = adev, vram_size
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)

def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
@@ -181,10 +179,10 @@ class AMMemoryManager:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
for paddr, psize in paddrs:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
frag = 0 if pte_covers == 0x1000 else 0x9
for pte_off in range(pte_cnt):
assert pt.get_entry(pte_idx + pte_off) & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.get_entry(pte_idx + pte_off):#x}"
pt.set_page(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, frag=frag, valid=True)
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)

# Invalidate TLB after mappings.
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
@@ -197,8 +195,8 @@ class AMMemoryManager:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
for pte_id in range(pte_idx, pte_idx + pte_cnt):
assert pt.get_entry(pte_id) & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.get_entry(pte_id):#x}"
pt.set_page(pte_id, paddr=0x0, valid=False)
assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
pt.set_entry(pte_id, paddr=0x0, valid=False)

@staticmethod
def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
@@ -233,8 +231,8 @@ class AMMemoryManager:
def pfree(self, paddr:int): self.pa_allocator.free(paddr)

class AMDev:
def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
self.pcidev, self.devfmt = pcidev, devfmt
def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
self.devfmt = devfmt
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar

os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
@@ -258,8 +256,8 @@ class AMDev:
# To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
# all blocks that are initialized only during the initial AM boot.
# To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
self.is_booting = True # During boot only boot memory can be allocated. This flag is to validate this.
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1)
self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)

# Memory manager & firmware
self.mm = AMMemoryManager(self, self.vram_size)
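The unified `set_entry` builds one 64-bit GPU page-table entry: flag bits in the low word, the physical page in bits 12-47, and `table=True` dropping the RWX and PDE_PTE bits so the entry points at a child table instead of mapping a page. A sketch with made-up bit values standing in for the `am.AMDGPU_*` constants (the real values come from AMD's register headers):

```python
# Hypothetical flag bits, for illustration only.
PTE_VALID, PTE_READ, PTE_WRITE, PTE_EXEC, PDE_PTE = 1 << 0, 1 << 5, 1 << 6, 1 << 7, 1 << 54

def make_entry(paddr: int, table: bool, leaf_level: bool, valid: bool = True) -> int:
    f = PTE_VALID if valid else 0
    if not table: f |= PTE_READ | PTE_WRITE | PTE_EXEC  # only leaf pages get RWX
    if not table and not leaf_level: f |= PDE_PTE       # huge-page PTE sitting at a PDE level
    return (paddr & 0x0000FFFFFFFFF000) | f             # bits 12..47 hold the page address

print(hex(make_entry(0x1234000, table=True, leaf_level=False)))   # pointer to a child table
print(hex(make_entry(0x1234000, table=False, leaf_level=True)))   # a normal 4K leaf mapping
```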
@@ -25,6 +25,9 @@ class AM_GMC(AM_IP):
self.vm_base = self.adev.mm.va_allocator.base
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1

# GFX11 has 44-bit address space
self.address_space_mask = (1 << 44) - 1

self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.hub_initted = {"MM": False, "GC": False}
@@ -99,7 +102,13 @@ class AM_GMC(AM_IP):
if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}")

class AM_SMU(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)

def init(self):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)

for clck in [0x00000C94, 0x000204E1, 0x000105DC, 0x00050B76, 0x00070B76, 0x00040898, 0x00060898, 0x000308FD]:
@@ -115,6 +124,11 @@ class AM_SMU(AM_IP):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms

def read_table(self, table_t, cmd):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)

def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
@@ -237,8 +251,9 @@ class AM_IH(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.ring_size = 512 << 10
self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True),
self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True))
self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)]

def interrupt_handler(self):
_, rwptr_vm, suf, _ = self.rings[0]
@@ -315,6 +330,9 @@ class AM_PSP(AM_IP):
self.ring_size = 0x10000
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)

self.max_tmr_size = 0x1300000
self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True)

def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
def init(self):
sos_components_load_order = [
@@ -359,7 +377,7 @@ class AM_PSP(AM_IP):
# Load TOC and calculate TMR size
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
assert self.tmr_size <= self.max_tmr_size

def _ring_create(self):
# If the ring is already created, destroy it
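The SMU takes only 32-bit message arguments, which is why the 64-bit driver-table address above is sent as two messages (AddrHigh then AddrLow). A plain-Python equivalent of the `hi32`/`lo32` split used there:

```python
# Split a 64-bit address into the two 32-bit halves the SMU mailbox expects.
def hi32(x: int) -> int: return (x >> 32) & 0xFFFFFFFF
def lo32(x: int) -> int: return x & 0xFFFFFFFF

addr = 0x0000_00F1_2345_6000
assert (hi32(addr) << 32) | lo32(addr) == addr  # lossless round trip
print(hex(hi32(addr)), hex(lo32(addr)))         # 0xf1 0x23456000
```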
@@ -115,8 +115,9 @@ class ShapeTracker:
def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]

def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
with Context(TRACK_MATCH_STATS=0):
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]

def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
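`Context` temporarily overrides ContextVar values for the duration of a with-block, which is why wrapping this internal `graph_rewrite` in `Context(TRACK_MATCH_STATS=0)` keeps bookkeeping rewrites out of the VIZ match log. A tiny demo of the override/restore behavior (the DEMO variable is illustrative, not a real tinygrad flag):

```python
from tinygrad.helpers import Context, ContextVar

DEMO = ContextVar("DEMO", 1)
with Context(DEMO=0):
    assert DEMO.value == 0  # overridden inside the block
assert DEMO.value == 1      # restored on exit
```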
@@ -1,15 +1,15 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.multi import get_multi_map
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
|
||||
from tinygrad.device import Device, Buffer, BufferSpec
|
||||
from tinygrad.device import Device, BufferSpec
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
@@ -30,45 +30,23 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
|
||||
|
||||
# link the found UOps back to Tensors. exit early if there's no Tensors to realize
|
||||
# NOTE: this uses all_tensors, but it's fast
|
||||
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]
|
||||
  fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]

  if len(fixed_tensors):
    # potentially rewrite all the discovered Tensors
    sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
    sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
    new_sink = sink.substitute(applied_map)

    # set the relevant lazydata to the realized UOps
    for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
      if s is ns: continue
      if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
      else: t.lazydata = ns
      t.lazydata = ns

# **** start with two base classes, Tensor and Function ****

class Function:
  def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
    self.device = device
    self.needs_input_grad = [t.requires_grad for t in tensors]
    self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
    if self.requires_grad: self.parents = tensors
    self.metadata = metadata

  def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
  def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")

  @classmethod
  def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
    ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
    ret = Tensor.__new__(Tensor)
    ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
    ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
    return ret

import tinygrad.function as F
# **** Tensor helper functions ****

def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
  if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
  return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None)
  return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  import numpy as np
@@ -148,8 +126,7 @@ class Tensor(SimpleMathTrait):
  np.set_printoptions(precision=4)
  ```
  """
  __slots__ = "lazydata", "requires_grad", "grad", "_ctx"
  __deletable__ = ('_ctx',)
  __slots__ = "lazydata", "requires_grad", "grad"
  training: ClassVar[bool] = False
  no_grad: ClassVar[bool] = False

@@ -159,7 +136,7 @@ class Tensor(SimpleMathTrait):
    return instance
  def __del__(self): all_tensors.discard(weakref.ref(self))

  def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
  def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
               device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
    if dtype is not None: dtype = to_dtype(dtype)
    if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
@@ -172,11 +149,8 @@ class Tensor(SimpleMathTrait):
    # None (the default) will be updated to True if it's put in an optimizer
    self.requires_grad: Optional[bool] = requires_grad

    # internal variable used for autograd graph construction
    self._ctx: Optional[Function] = None

    # create a LazyBuffer from the different types of inputs
    if isinstance(data, (UOp, MultiLazyBuffer)):
    if isinstance(data, UOp):
      assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
      # NOTE: this is here because LazyBuffer = UOp
      if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
@@ -199,12 +173,12 @@ class Tensor(SimpleMathTrait):
      data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")

    # by this point, it has to be a LazyBuffer
    if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
    if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")

    # data might be on a different device
    if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
    if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
    # if device is a tuple, we should have/construct a MultiLazyBuffer
    elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata
    elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
    else:
      assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
      self.lazydata = data
@@ -224,8 +198,8 @@ class Tensor(SimpleMathTrait):
    def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev

  def __repr__(self):
    if isinstance(ld:=self.lazydata, MultiLazyBuffer): ld_repr = f"{ld!r}"
    else: ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
    ld = self.lazydata
    ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
    return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"

  # Python has a non moving GC, so this should be okay
@@ -246,6 +220,17 @@ class Tensor(SimpleMathTrait):
  @property
  def dtype(self) -> DType: return self.lazydata.dtype

  def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
    ret = Tensor.__new__(Tensor)
    needs_input_grad = [t.requires_grad for t in (self,)+x]
    ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
    ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
    return ret

  def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    lhs,rhs = self._broadcasted(x, reverse)
    return lhs._apply_uop(fxn, rhs)
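
A minimal sketch of how an elementwise op flows through these two helpers after this change (assuming a tinygrad checkout at this commit; `+` routes through `Tensor.add`):

```python
from tinygrad import Tensor

a, b = Tensor([1., 2., 3.]), Tensor([4., 5., 6.])
c = a + b         # Tensor.add -> _apply_broadcasted_uop(UOp.add, b) -> _apply_uop
print(c.numpy())  # [5. 7. 9.]
```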

  # ***** data handlers ****

  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@@ -254,7 +239,14 @@ class Tensor(SimpleMathTrait):

    NOTE: A Tensor can only be scheduled once.
    """
    big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst]))
    big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])

    # TODO: move this to scheduler tensor_map pass
    if any(x.op is Ops.MULTI for x in big_sink.toposort):
      # multi fixup
      _apply_map_to_tensors(get_multi_map(big_sink))
      big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))

    schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
    _apply_map_to_tensors(becomes_map)
    return memory_planner(schedule), var_vals
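
A rough usage sketch of the new single-sink scheduling path (a Tensor can only be scheduled once, per the NOTE above):

```python
from tinygrad import Tensor

t = Tensor.ones(4, 4).contiguous() + 1
schedule, var_vals = t.schedule_with_vars()  # builds one big_sink over (self,)+lst
print(len(schedule), var_vals)               # a few ScheduleItems, no Variables here
```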

@@ -275,7 +267,6 @@ class Tensor(SimpleMathTrait):
    Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
    """
    # used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
    assert getattr(self, '_ctx', None) is None
    assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
    self.lazydata = x.lazydata
    return self
@@ -287,13 +278,11 @@ class Tensor(SimpleMathTrait):
      self.contiguous().realize().lazydata.base.realized.copyin(x._data())
      return self
    if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
    if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
    if self.lazydata is x.lazydata: return self  # a self assign is a NOOP
    # NOTE: we allow cross device assign
    assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
    assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
    assert not x.requires_grad  # self requires_grad is okay?
    if not self.lazydata.is_realized: return self.replace(x)
    self.lazydata = self.lazydata.assign(x.lazydata)
@@ -309,7 +298,8 @@ class Tensor(SimpleMathTrait):
    if 0 in self.shape: return memoryview(bytearray(0))
    # NOTE: this realizes on the object from as_buffer being a Python object
    cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
    buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized)
    buf = cast(UOp, cpu.lazydata).base.realized
    assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
    if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
    return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)

@@ -373,7 +363,6 @@ class Tensor(SimpleMathTrait):
    """
    ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
    if self.grad is not None: ret.grad = self.grad.clone()
    if hasattr(self, '_ctx'): ret._ctx = self._ctx
    return ret

  def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
@@ -385,7 +374,6 @@ class Tensor(SimpleMathTrait):
    if not isinstance(device, str): return self.shard(device)
    ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
    if self.grad is not None: ret.grad = self.grad.to(device)
    if hasattr(self, '_ctx'): ret._ctx = self._ctx
    return ret

  def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
@@ -405,18 +393,9 @@ class Tensor(SimpleMathTrait):
    print(t.shard((t.device, t.device), axis=1).lazydata)
    ```
    """
    assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
    assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
    devices = tuple(Device.canonicalize(x) for x in devices)
    if axis is None: lbs = [self.lazydata] * len(devices)
    else:
      axis = self._resolve_dim(axis)
      if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
      sz = self.shape[axis] // len(devices)
      sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
      lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
    # NOTE: this contiguous is making it impossible for the scheduler to do late const folding
    mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis)
    mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
    return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
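
A sketch of the new shard path, reusing the same-device trick from the docstring above (it assumes the post-refactor MULTI UOp exposes an .axis attribute):

```python
from tinygrad import Tensor

t = Tensor.arange(8).reshape(4, 2)
s = t.shard((t.device, t.device), axis=0)  # lazydata.shard builds the MULTI UOp
print(s.shape, s.lazydata.axis)            # (4, 2) 0: shape unchanged, storage split
```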

  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
@@ -439,7 +418,7 @@ class Tensor(SimpleMathTrait):
  def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
    dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
    if isinstance(device, tuple):
      return Tensor(MultiLazyBuffer([UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
      return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
                    device, dtype, **kwargs)
    return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

@@ -510,7 +489,7 @@ class Tensor(SimpleMathTrait):
  @staticmethod
  def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
    x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
    x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
    x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
    counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
    return counts0.cat(counts1)

@@ -750,12 +729,12 @@ class Tensor(SimpleMathTrait):
    ```
    """
    dtype = kwargs.pop("dtype", self.dtype)
    if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
    if isinstance(self.device, tuple):
      if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
      if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
      contiguous = kwargs.pop("contiguous", True)
      rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
      return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
      rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src]
      return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
    return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)

  # ***** rng hlops *****

@@ -904,7 +883,7 @@ class Tensor(SimpleMathTrait):

  # ***** toposort and backward pass *****

  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
  def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
    """
    Compute the gradient of the targets with respect to self.

@@ -921,29 +900,17 @@ class Tensor(SimpleMathTrait):
    assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
    rets = []
    for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
      target_uops = [x.lazydata.lbs[i] for x in targets]
      grads = compute_gradient(uop, grad, set(target_uops))
      ret = []
      for x in target_uops:
        if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
        ret.append(y)
      rets.append(ret)
    target_uops = [x.lazydata for x in targets]
    grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
    ret = []
    for x in target_uops:
      if (y:=grads.get(x)) is None:
        if materialize_grads: y = x.const_like(0)
        else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
      ret.append(y)
    rets.append(ret)
    # create returned Tensors
    if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
    return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
                   device=t.device) for t,u in zip(targets, zip(*rets))]

  def _deepwalk(self) -> list[Tensor]:
    def _walk(node:Tensor, visited:set[Tensor]):
      visited.add(node)
      # if tensor is not leaf, reset grad
      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
      if ctx:
        for i in cast(Function, node._ctx).parents:
          if i not in visited: yield from _walk(i, visited)
      yield node
    return list(_walk(self, set()))
    return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]

  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
    """
@@ -956,31 +923,13 @@ class Tensor(SimpleMathTrait):
    print(t.grad.numpy())
    ```
    """
    toposorted = self._deepwalk()
    if gradient is None:
      assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
      # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
      # this is "implicit gradient creation"
      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

    toposort_uop = self.lazydata.toposort
    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
    self.grad = gradient
    for t0 in reversed(toposorted):
      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
      ctx = cast(Function, t0._ctx)
      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
      grads = ctx.backward(t0.grad.lazydata)
      _METADATA.reset(token)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
               for g in ([grads] if len(ctx.parents) == 1 else grads)]
      for t, g in zip(ctx.parents, grads):
        if g is not None and t.requires_grad:
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
            f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
          t.grad = g if t.grad is None else (t.grad + g)
      if not retain_graph: del t0._ctx
    all_uops = self.lazydata.toposort
    tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
      t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
    # clear contexts
    for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
      assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
      t.grad = g if t.grad is None else (t.grad + g)
    return self
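
The net effect of this rewrite: backward no longer replays Function contexts, it collects every live requires_grad Tensor reachable from self and delegates to gradient(). A small sketch:

```python
from tinygrad import Tensor

x = Tensor([2.0, 3.0], requires_grad=True)
(x * x).sum().backward()   # routes through gradient(..., materialize_grads=True)
print(x.grad.numpy())      # [4. 6.]

# gradient() also works standalone, without touching .grad
(gx,) = (x * x).sum().gradient(x)
print(gx.numpy())          # [4. 6.]
```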

  # ***** movement low level ops *****
@@ -1004,7 +953,7 @@ class Tensor(SimpleMathTrait):
    # resolve -1
    if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
    if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
    return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
    return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self

  def expand(self, shape, *args) -> Tensor:
    """
@@ -1037,7 +986,7 @@ class Tensor(SimpleMathTrait):
    """
    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
    return F.Permute.apply(self, order=order_arg)
    return self._apply_uop(UOp.permute, arg=order_arg)

  def flip(self, axis, *args) -> Tensor:
    """
@@ -1057,7 +1006,7 @@ class Tensor(SimpleMathTrait):
    """
    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
    return F.Flip.apply(self, axis=axis_arg)
    return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))

  def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
    """
@@ -1077,7 +1026,7 @@ class Tensor(SimpleMathTrait):
    ```
    """
    if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
    return F.Shrink.apply(self, arg=tuple(shrink_arg))
    return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))

  def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
    """
@@ -1121,7 +1070,8 @@ class Tensor(SimpleMathTrait):
    if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
    X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
    if mode == "constant":
      def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
      def _constant(x:Tensor,px,v):
        return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
      return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
        _constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
    assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
@@ -1278,7 +1228,7 @@ class Tensor(SimpleMathTrait):
      self._getitem(indices).assign(v)
      return
    # NOTE: check that setitem target is valid first
    if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
    if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
    if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
    if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
    if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
@@ -1611,10 +1561,10 @@ class Tensor(SimpleMathTrait):

  # ***** reduce ops *****

  def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
  def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
    axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
    if self.ndim == 0: axis = ()
    ret = fxn.apply(self, axis=axis)
    ret = self._apply_uop(UOp.r, op=op, axis=axis)
    return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))

  def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1641,7 +1591,7 @@ class Tensor(SimpleMathTrait):
    print(t.sum(axis=1).numpy())
    ```
    """
    ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
    ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
    return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

  def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1668,7 +1618,7 @@ class Tensor(SimpleMathTrait):
    print(t.prod(axis=1).numpy())
    ```
    """
    return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
    return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)

  def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
    """
@@ -1691,7 +1641,7 @@ class Tensor(SimpleMathTrait):
    print(t.max(axis=1, keepdim=True).numpy())
    ```
    """
    return self._reduce(F.Max, axis, keepdim)
    return self._reduce(Ops.MAX, axis, keepdim)

  def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
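
With Function gone, the reductions map straight onto reduce Ops; a quick sketch of the mapping:

```python
from tinygrad import Tensor

t = Tensor([[1., 2.], [3., 4.]])
print(t.sum(axis=1).numpy())   # [3. 7.]  -> _reduce(Ops.ADD, ...)
print(t.prod(axis=1).numpy())  # [2. 12.] -> _reduce(Ops.MUL, ...)
print(t.max(axis=1).numpy())   # [2. 4.]  -> _reduce(Ops.MAX, ...)
```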

@@ -2528,7 +2478,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([False, True]).logical_not().numpy())
    ```
    """
    return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
    return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
  def neg(self):
    """
    Negates the tensor element-wise.
@@ -2542,12 +2492,12 @@ class Tensor(SimpleMathTrait):
    """
    Returns a contiguous tensor.
    """
    return F.Contiguous.apply(self)
    return self._apply_uop(UOp.contiguous)
  def contiguous_backward(self):
    """
    Inserts a contiguous operation in the backward pass.
    """
    return F.ContiguousBackward.apply(self)
    return self._apply_uop(UOp.contiguous_backward)
  def log(self):
    """
    Computes the natural logarithm element-wise.
@@ -2558,7 +2508,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([1., 2., 4., 8.]).log().numpy())
    ```
    """
    return F.Log.apply(self.cast(least_upper_float(self.dtype)))
    return self.log2()*math.log(2)
  def log2(self):
    """
    Computes the base-2 logarithm element-wise.
@@ -2569,7 +2519,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([1., 2., 4., 8.]).log2().numpy())
    ```
    """
    return self.log()/math.log(2)
    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
  def exp(self):
    """
    Computes the exponential function element-wise.
@@ -2580,7 +2530,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([0., 1., 2., 3.]).exp().numpy())
    ```
    """
    return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
    return self.mul(1/math.log(2)).exp2()
  def exp2(self):
    """
    Computes the base-2 exponential function element-wise.
@@ -2591,8 +2541,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([0., 1., 2., 3.]).exp2().numpy())
    ```
    """
    return F.Exp.apply(self*math.log(2))

    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
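
log and exp are now derived from the base-2 primitives: log x = log2(x)*ln 2 and exp x = exp2(x/ln 2). A numeric sanity check:

```python
import math
from tinygrad import Tensor

t = Tensor([1., 2., 4., 8.])
print(t.log().numpy())                   # [0. 0.6931 1.3863 2.0794]
print((t.log2() * math.log(2)).numpy())  # identical by construction
print(t.exp().numpy())                   # e^t via exp2(t/ln 2)
```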

  def relu(self):
    """
    Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2603,7 +2552,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
    ```
    """
    return F.Relu.apply(self)
    return (self>0).where(self, 0)

  def sigmoid(self):
    """
@@ -2639,7 +2588,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
    ```
    """
    return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
  def rsqrt(self):
    """
    Computes the reciprocal of the square root of the tensor element-wise.
@@ -2657,7 +2606,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
    ```
    """
    return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
  def cos(self):
    """
    Computes the cosine of the tensor element-wise.
@@ -2816,7 +2765,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
    ```
    """
    return F.Sign.apply(self)
    return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
  def abs(self):
    """
    Computes the absolute value of the tensor element-wise.
@@ -2834,7 +2783,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
    ```
    """
    return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
    return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)

  # ***** activation functions *****

@@ -3112,7 +3061,7 @@ class Tensor(SimpleMathTrait):
    # for each dimension, check either dim is 1, or it does not change
    if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
      raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
    return F.Expand.apply(self.reshape(shape), shape=new_shape)
    return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)

  def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
    x: Tensor = self
@@ -3156,7 +3105,7 @@ class Tensor(SimpleMathTrait):
    print(t.add(Tensor([[2.0], [3.5]])).numpy())
    ```
    """
    return F.Add.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.add, x, reverse)

  def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3197,7 +3146,7 @@ class Tensor(SimpleMathTrait):
    print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
    ```
    """
    return F.Mul.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.mul, x, reverse)

  def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3210,7 +3159,7 @@ class Tensor(SimpleMathTrait):
    print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
    ```
    """
    return F.IDiv.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.idiv, x, reverse)

  def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3245,7 +3194,7 @@ class Tensor(SimpleMathTrait):
    ```
    """
    a, b = self._broadcasted(x, reverse)
    return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
    return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))

  def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3261,7 +3210,7 @@ class Tensor(SimpleMathTrait):
    ```
    """
    if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
    return F.Xor.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.xor, x, reverse)

  def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3276,7 +3225,7 @@ class Tensor(SimpleMathTrait):
    ```
    """
    if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
    return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)

  def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
    """
@@ -3291,7 +3240,7 @@ class Tensor(SimpleMathTrait):
    ```
    """
    if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
    return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
    return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)

  def bitwise_not(self) -> Tensor:
    """
@@ -3422,7 +3371,7 @@ class Tensor(SimpleMathTrait):
    elif isinstance(y, Tensor): y, x = y._broadcasted(x)
    cond, x = self._broadcasted(x, match_dtype=False)
    cond, y = cond._broadcasted(y, match_dtype=False)
    return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
    return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))

  def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)

@@ -3452,9 +3401,9 @@ class Tensor(SimpleMathTrait):
  def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
  def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))

  def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
  def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
  def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
  def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
  def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
  def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)

  def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]

@@ -3800,8 +3749,8 @@ class Tensor(SimpleMathTrait):
    """
    if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
      # NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
      return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
    return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
      return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
    return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)

  def bitcast(self, dtype:DTypeLike) -> Tensor:
    """
@@ -3826,7 +3775,7 @@ class Tensor(SimpleMathTrait):
      tmp = self.bitcast(old_uint)
      if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
      return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
    return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
    return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self

  def float(self) -> Tensor:
    """
@@ -4000,5 +3949,5 @@ def _metadata_wrapper(fn):

if TRACEMETA >= 1:
  for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
    if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
    setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))

@@ -301,17 +301,17 @@
    const kernelListParent = document.querySelector(".container.kernel-list-parent");
    const kernelList = document.querySelector(".container.kernel-list");
    kernelList.innerHTML = "";
    kernels.forEach((k, i) => {
    kernels.forEach(([key, items], i) => {
      const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "",
        style: "overflow-x: auto; cursor: initial;" });
      if (i === currentKernel) {
        requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
      }
      const p = Object.assign(document.createElement("p"), { id: `kernel-${k[0].kernel_name}`, innerText: k[0].kernel_name ?? "UNPARENTED",
      const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerText: key ?? "UNPARENTED",
        style: "cursor: pointer;"});
      kernelUl.appendChild(p)
      k.forEach((u, j) => {
        const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.upats.length}`, key: `uop-rewrite-${j}`,
      items.forEach((u, j) => {
        const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
          className: (j === currentUOp && i == currentKernel) ? "active" : "" })
        if (j === currentUOp) {
          requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
@@ -460,7 +460,7 @@
      event.preventDefault()
      currentUOp = 0;
      currentRewrite = 0;
      currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
      currentKernel = Math.min(kernels.length-1, currentKernel+1);
      return main()
    }
  }
@@ -486,7 +486,7 @@
    if (event.key == "ArrowDown") {
      event.preventDefault()
      currentRewrite = 0;
      const totalUOps = kernels[currentKernel].length-1;
      const totalUOps = kernels[currentKernel][1].length-1;
      currentUOp = Math.min(totalUOps, currentUOp+1)
      main()
    }

@@ -2,8 +2,7 @@
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, TypedDict
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
@@ -13,54 +12,30 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
               Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
               Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4",
               Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff",
               **{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
               Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}

# ** API spec
# VIZ API

@dataclass
class GraphRewriteMetadata:
  """Overview of a tracked rewrite to viz the sidebar"""
  loc: tuple[str, int]
  """File_path, Lineno"""
  code_line: str
  """The Python line calling graph_rewrite"""
  kernel_name: str
  """The kernel calling graph_rewrite"""
  upats: list[tuple[tuple[str, int], str, float]]
  """List of all the applied UPats"""
class GraphRewriteMetadata(TypedDict):
  loc: tuple[str, int] # [path, lineno] calling graph_rewrite
  match_count: int # total match count in this context

@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
  """Full details about a single call to graph_rewrite"""
  uops: list[UOp]
  graphs: list[dict]
  """Sink at every step of graph_rewrite + the json serialized version"""
  diffs: list[list[str]]
  """.diff style before and after of the rewritten UOp child"""
  changed_nodes: list[list[int]]
  """Nodes that changed at every step of graph_rewrite"""
  kernel_code: Optional[str]
  """The program after all rewrites"""

# ** API functions
  graphs: list[dict] # JSON serialized UOp at every rewrite step
  uops: list[str] # stringified UOp at every rewrite step
  diffs: list[list[str]] # string diff of the single UOp that changed
  changed_nodes: list[list[int]] # the changed UOp id + all its parents ids
  code_line: str # source code calling graph_rewrite
  kernel_code: str|None # optionally render the final kernel code
  upats: list[tuple[tuple[str, int], str]]

# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
  try: return fxn(*args, **kwargs)
  except Exception as e: return f"ERROR: {e}"

def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
  kernels: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
  for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"):
    name = to_function_name(k.name) if isinstance(k, Kernel) else str(k)
    for ctx in ctxs:
      if ctx.sink.op is Ops.CONST: continue
      upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
      kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
  return list(kernels.values())

def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
  assert isinstance(x, UOp)
  graph: dict[int, tuple[str, str, list[int], str, str]] = {}
@@ -80,27 +55,27 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
    else: label += f"\n{x.op.name}{idx} {x.arg}"
    graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x not in excluded], str(u.arg), uops_colors.get(u.op, "#ffffff"))
  return graph

def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
  return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k),
           [{"loc": v.loc, "match_count": len(v.matches)} for v in vals]) for k,vals in zip(keys, contexts)]

@functools.lru_cache(None)
def _prg(k:Kernel): return k.to_program().src
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
  g = GraphRewriteDetails(**asdict(metadata), uops=[ctx.sink], diffs=[], changed_nodes=[],
                          kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[])
  ret:GraphRewriteDetails = {"uops":[pcall(str, sink:=ctx.sink)], "graphs":[uop_to_json(sink)], "code_line":lines(ctx.loc[0])[ctx.loc[1]-1].strip(),
                             "kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "diffs":[], "upats":[], "changed_nodes":[], **metadata}
  replaces: dict[UOp, UOp] = {}
  g.graphs.append(uop_to_json(sink:=g.uops[0]))
  for i,(u0,u1,upat,_) in enumerate(tqdm(ctx.matches)):
    # if the match didn't result in a rewrite we move forward
    if u1 is None: continue
  for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches)):
    replaces[u0] = u1
    # first, rewrite this UOp with the current rewrite + all the matches in replaces
    new_sink = sink.substitute(replaces)
    # sanity check
    if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
    # update ret data
    g.graphs.append(new_sink_js:=uop_to_json(new_sink))
    g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js])
    g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
    g.uops.append(sink:=new_sink)
  return g
    ret["graphs"].append(new_sink_js:=uop_to_json(new_sink))
    ret["changed_nodes"].append([id(x) for x in u1.toposort if id(x) in new_sink_js])
    ret["diffs"].append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
    ret["upats"].append((upat.location, upat.printable()))
    # TODO: this is O(n^2)!
    ret["uops"].append(str(sink:=new_sink))
  return ret

# Profiler API
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
@@ -150,13 +125,9 @@ class Handler(BaseHTTPRequestHandler):
    elif url.path == "/kernels":
      query = parse_qs(url.query)
      if (qkernel:=query.get("kernel")) is not None:
        g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
        # TODO: this is O(n^2)!
        uops_strs = [pcall(str,x) for x in tqdm(g.uops)]
        # NOTE: don't use asdict because it's reserializing the uops
        jret: Any = {"loc": g.loc, "code_line": g.code_line, "kernel_name": g.kernel_name, "upats": g.upats,
                     "uops": uops_strs, "graphs": g.graphs, "diffs": g.diffs, "changed_nodes": g.changed_nodes, "kernel_code": g.kernel_code}
      else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
        kidx, ridx = int(qkernel[0]), int(query["idx"][0])
        jret:Any = get_details(contexts[0][kidx], contexts[1][kidx][ridx], kernels[int(qkernel[0])][1][int(query["idx"][0])])
      else: jret = kernels
      ret, content_type = json.dumps(jret).encode(), "application/json"
    elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
    else: status_code = 404
@@ -198,10 +169,11 @@ if __name__ == "__main__":

  contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)

  # NOTE: this context is a tuple of list[keys] and list[values]
  kernels = get_metadata(*contexts) if contexts is not None else []

  if getenv("FUZZ_VIZ"):
    ret = [get_details(*args) for v in tqdm(kernels) for args in v]
    ret = [get_details(contexts[0][i], contexts[1][i][j], args) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])]
    print(f"fuzzed {len(ret)} rewrite details")

  perfetto_profile = to_perfetto(profile) if profile is not None else None