Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-01-27 08:07:19 -08:00
55 changed files with 1354 additions and 1255 deletions

View File

@@ -299,6 +299,10 @@ jobs:
run: NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
- name: Run 10 MLPerf ResNet50 training steps (6 gpu)
run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
- name: Run 10 MLPerf Bert training steps (6 gpu)
# TODO: remove DISABLE_DROPOUT once dropout is fixed
# TODO: remove BERT_LAYERS once scheduler is fast
run: NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (NVIDIA Training)
@@ -309,9 +313,10 @@ jobs:
train_cifar_bf16.txt
train_cifar_wino.txt
train_cifar_one_gpu.txt
train_cifar_six_gpu.txt
train_resnet.txt
train_resnet_one_gpu.txt
train_cifar_six_gpu.txt
train_bert.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -492,6 +497,10 @@ jobs:
run: AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
- name: Run 10 MLPerf ResNet50 training steps (6 gpu)
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
- name: Run 10 MLPerf Bert training steps (6 gpu)
# TODO: remove DISABLE_DROPOUT once dropout is fixed
# TODO: remove BERT_LAYERS once scheduler is fast
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=6 DISABLE_DROPOUT=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD Training)
@@ -502,9 +511,10 @@ jobs:
train_cifar_bf16.txt
train_cifar_wino.txt
train_cifar_one_gpu.txt
train_cifar_six_gpu.txt
train_resnet.txt
train_resnet_one_gpu.txt
train_cifar_six_gpu.txt
train_bert.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View File

@@ -619,6 +619,44 @@ jobs:
if: matrix.backend=='amd'
run: python -m pytest -n=auto test/test_hcq.py test/test_tiny.py --durations=20
wintests:
strategy:
fail-fast: false
matrix:
backend: [llvm]
name: Tests on Windows (${{ matrix.backend }})
runs-on: windows-latest
timeout-minutes: 45
steps:
- name: Checkout Code
uses: actions/checkout@v4
with:
fetch-depth: 2 # NOTE: this fetches the HEAD commit of the PR
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: 3.12
- name: Cache python packages
uses: actions/cache@v4
with:
path: ${{ env.Python3_ROOT_DIR }}\Lib\site-packages
key: windows-${{ matrix.backend }}-packages-${{ hashFiles('**/setup.py') }}
- name: Install dependencies
run: pip install --user -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Check Device.DEFAULT and print some source
env:
DEBUG: 5
LLVM: 1
PYTHONPATH: ${{ github.workspace }}
run: |
python3 test/test_ops.py TestOps.test_add
- name: Run pytest
env:
DEBUG: 5
LLVM: 1
run: python -m pytest -n=auto test/test_tiny.py --durations=20
#testunicorn:
# name: ARM64 unicorn Test
# runs-on: ubuntu-latest

View File

@@ -9,7 +9,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not
## Frontend
Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There are about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of [UOps](../developer/uop.md).
Everything in [Tensor](../tensor/index.md) is syntactic sugar around constructing a graph of [UOps](../developer/uop.md).
The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not all UOps will actually become realized. There are two types of UOps, base and view: a base computes into a contiguous buffer, and a view is a view of a base (specified by a ShapeTracker). Inputs to a base can be either base or view; the input to a view can only be a single base.
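A minimal sketch of the idea, assuming a recent tinygrad (the exact printed graph varies by version):

from tinygrad import Tensor
t = Tensor([1.0, 2.0]) + 1   # builds UOps; no compute happens yet
print(t.lazydata)            # the UOp graph backing this Tensor
t.realize()                  # only now is the graph scheduled and run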

View File

@@ -1,33 +0,0 @@
::: tinygrad.function
options:
members: [
"Contiguous",
"ContiguousBackward",
"Cast",
"Neg",
"Reciprocal",
"Sin",
"Relu",
"Log",
"Exp",
"Sqrt",
"Sigmoid",
"Sign",
"Less",
"Eq",
"Xor",
"Add",
"Sub",
"Mul",
"Div",
"Where",
"Sum",
"Max",
"Expand",
"Reshape",
"Permute",
"Pad",
"Shrink",
"Flip",
]
show_source: false

View File

@@ -11,7 +11,6 @@ from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import Context, BEAM, WINO, getenv, colored, prod
from tinygrad.multi import MultiLazyBuffer
cifar_mean = [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]
cifar_std = [0.24703225141799082, 0.24348516474564, 0.26158783926049628]
@@ -35,8 +34,6 @@ class UnsyncedBatchNorm:
self.num_batches_tracked = Tensor.zeros(1, dtype=dtypes.int, requires_grad=False)
def __call__(self, x:Tensor):
if isinstance(x.lazydata, MultiLazyBuffer): assert x.lazydata.axis is None or x.lazydata.axis == 0 and len(x.lazydata.lbs) == self.num_devices
xr = x.reshape(self.num_devices, -1, *x.shape[1:]).cast(dtypes.float32)
batch_mean, batch_invstd = self.calc_stats(xr)
ret = xr.batchnorm(

View File

@@ -247,11 +247,11 @@ if __name__ == "__main__":
fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/resolve/main/model-00004-of-00004.safetensors", "model-00004-of-00004.safetensors", subdir="llama3-8b-sfr")
args.model = fetch("https://huggingface.co/TriAiExperiments/SFR-Iterative-DPO-LLaMA-3-8B-R/raw/main/model.safetensors.index.json", "model.safetensors.index.json", subdir="llama3-8b-sfr")
elif args.size == "70B":
subdir = "Llama-3.1-Nemotron-70B-Instruct-HF"
args.model = fetch("https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir)
subdir = "DeepSeek-R1-Distill-Llama-70B"
args.model = fetch("https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model.safetensors.index.json?download=true", "model.safetensors.index.json", subdir=subdir)
fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir=subdir)
for i in range(30):
fetch(f"https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct-HF/resolve/main/model-{i+1:05d}-of-00030.safetensors?download=true", f"model-{i+1:05d}-of-00030.safetensors", subdir=subdir)
for i in range(17):
fetch(f"https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/resolve/main/model-{i+1:05d}-of-000017.safetensors?download=true", f"model-{i+1:05d}-of-000017.safetensors", subdir=subdir)
assert args.model is not None, "please provide --model option"

View File

@@ -204,7 +204,7 @@ def get_mlperf_bert_config():
"intermediate_size": 4096,
"max_position_embeddings": 512,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"num_hidden_layers": getenv("BERT_LAYERS", 24),
"type_vocab_size": 2,
"vocab_size": 30522
}
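A hedged sketch of what the getenv override above enables (BERT_LAYERS comes from the environment, defaulting to the full 24 layers):

from tinygrad.helpers import getenv
# BERT_LAYERS=2 python3 examples/mlperf/model_train.py -> 2 hidden layers
# with BERT_LAYERS unset                               -> 24 hidden layers
num_hidden_layers = getenv("BERT_LAYERS", 24)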

View File

@@ -786,7 +786,8 @@ def train_rnnt():
pass
@TinyJit
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor,
masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
optimizer.zero_grad()
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
@@ -802,18 +803,15 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
optimizer.step()
scheduler.step()
return loss.realize()
return loss.realize(), global_norm.realize()
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
masked_lm_weights:Tensor, next_sentence_labels:Tensor):
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return {
"masked_lm_accuracy": masked_lm_accuracy.realize(),
"next_sentence_accuracy": seq_relationship_accuracy.realize(),
"masked_lm_loss": masked_lm_loss.realize(),
"next_sentence_loss": next_sentence_loss.realize()
}
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()
def train_bert():
# NOTE: pip install tensorflow, wandb required
@@ -949,7 +947,7 @@ def train_bert():
previous_step = None
if ckpt:=getenv("RESUME", ""):
load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
start_step = int(scheduler_wd.epoch_counter.numpy().item())
start_step = int(scheduler_wd.epoch_counter.item())
print(f"resuming from {ckpt} at step {start_step}")
if RUNMLPERF:
@@ -975,7 +973,7 @@ def train_bert():
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
GlobalCounters.reset()
loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
@@ -992,7 +990,7 @@ def train_bert():
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.numpy().item()
loss = loss.item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
@@ -1002,7 +1000,7 @@ def train_bert():
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
@@ -1037,12 +1035,10 @@ def train_bert():
GlobalCounters.reset()
st = time.time()
eval_result: dict[str, Tensor] = eval_step_bert(model,
lm_acc, clsf_acc, lm_loss, clsf_loss = eval_step_bert(model,
eval_data["input_ids"], eval_data["segment_ids"], eval_data["input_mask"], eval_data["masked_lm_positions"],
eval_data["masked_lm_ids"], eval_data["masked_lm_weights"], eval_data["next_sentence_labels"])
lm_loss, clsf_loss = eval_result["masked_lm_loss"].item(), eval_result["next_sentence_loss"].item()
lm_acc, clsf_acc = eval_result["masked_lm_accuracy"].item(), eval_result["next_sentence_accuracy"].item()
lm_acc, clsf_acc, lm_loss, clsf_loss = lm_acc.item(), clsf_acc.item(), lm_loss.item(), clsf_loss.item()
eval_lm_losses.append(lm_loss)
eval_clsf_losses.append(clsf_loss)
@@ -1059,7 +1055,7 @@ def train_bert():
return
if getenv("RESET_STEP", 1): eval_step_bert.reset()
del eval_data, eval_result
del eval_data
avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses)
avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)
avg_lm_acc = sum(eval_lm_accs) / len(eval_lm_accs)
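The pattern behind the changes above, sketched minimally: a @TinyJit function should return realized Tensors, and .item() then pulls a Python scalar without the explicit .numpy() round trip.

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor):
  loss = (x * x).mean()
  return loss.realize()   # realize outputs inside the jitted function

loss = step(Tensor([1.0, 2.0, 3.0]))
print(loss.item())        # scalar; equivalent to loss.numpy().item()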

View File

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -5,7 +5,7 @@ export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -4,7 +4,7 @@ export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"

View File

@@ -5,7 +5,7 @@ export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
export IGNORE_JIT_FIRST_BEAM=1
export BASEDIR="/raid/datasets/wiki"
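For context, these exports bound tinygrad's kernel beam search; a small sketch of how they are read (BEAM is a ContextVar in tinygrad.helpers; the *_MAX values are assumed to be plain env lookups):

from tinygrad.helpers import BEAM, getenv
print(BEAM.value)                  # 3 with the exports above
print(getenv("BEAM_UOPS_MAX", 0))  # 3000 with the exports above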

extra/amdpci/am_smi.py Normal file, 167 lines added
View File

@@ -0,0 +1,167 @@
import time, mmap, sys, shutil, os, glob
from tinygrad.helpers import to_mv, DEBUG, colored, ansilen
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.autogen.am import smu_v13_0_0
from tinygrad.runtime.support.am.amdev import AMDev, AMMemoryManager
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
AM_VERSION = 0xA0000002
def bold(s): return f"\033[1m{s}\033[0m"
def color_temp(temp):
if temp >= 87: return colored(f"{temp:>4}", "red")
elif temp >= 80: return colored(f"{temp:>4}", "yellow")
return f"{temp:>4}"
def color_voltage(voltage): return colored(f"{voltage/1000:>5.3f}V", "cyan")
def draw_bar(percentage, width=40, fill='█', empty='░'):
filled_width = int(width * percentage)
bar = fill * filled_width + empty * (width - filled_width)
return f'[{bar}] {percentage*100:5.1f}%'
def same_line(strs:list[list[str]], split=8) -> list[str]:
ret = []
max_width_in_block = [max(ansilen(line) for line in block) for block in strs]
max_height = max(len(block) for block in strs)
for i in range(max_height):
line = []
for bid, block in enumerate(strs):
if i < len(block): line.append(block[i] + ' ' * (split + max_width_in_block[bid] - ansilen(block[i])))
else: line.append(' ' * (split + max_width_in_block[bid]))
ret.append(' '.join(line))
return ret
def get_bar0_size(pcibus):
resource_file = f"/sys/bus/pci/devices/{pcibus}/resource"
if not os.path.exists(resource_file): raise FileNotFoundError(f"Resource file not found: {resource_file}")
with open(resource_file, "r") as f: lines = f.readlines()
bar0_info = lines[0].split()
if len(bar0_info) < 3: raise ValueError("Unexpected resource file format for BAR0.")
start_hex, end_hex, _flags = bar0_info
return int(end_hex, 16) - int(start_hex, 16) + 1
class AMSMI(AMDev):
def __init__(self, pcibus, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
self.pcibus = pcibus
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
self._run_discovery()
self._build_regs()
if self.reg("regSCRATCH_REG7").read() != AM_VERSION:
raise Exception(f"Unsupported AM version: {self.reg('regSCRATCH_REG7').read():x}")
self.is_booting, self.smi_dev = True, True
self.partial_boot = True # do not init anything
self.mm = AMMemoryManager(self, self.vram_size)
# Initialize IP blocks
self.soc21:AM_SOC21 = AM_SOC21(self)
self.gmc:AM_GMC = AM_GMC(self)
self.ih:AM_IH = AM_IH(self)
self.psp:AM_PSP = AM_PSP(self)
self.smu:AM_SMU = AM_SMU(self)
class SMICtx:
def __init__(self):
self.devs = []
self.opened_pcidevs = []
self.opened_pci_resources = {}
self.prev_lines_cnt = 0
def _open_am_device(self, pcibus):
if pcibus not in self.opened_pci_resources:
bar_fds = {bar: os.open(f"/sys/bus/pci/devices/{pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]}
bar_size = {0: get_bar0_size(pcibus), 2: os.fstat(bar_fds[2]).st_size, 5: os.fstat(bar_fds[5]).st_size}
def map_pci_range(bar):
return to_mv(libc.mmap(0, bar_size[bar], mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, bar_fds[bar], 0), bar_size[bar])
self.opened_pci_resources[pcibus] = (map_pci_range(0), None, map_pci_range(5).cast('I'))
try:
self.devs.append(AMSMI(pcibus, *self.opened_pci_resources[pcibus]))
except Exception as e:
if DEBUG >= 2: print(f"Failed to open AM device {pcibus}: {e}")
return
self.opened_pcidevs.append(pcibus)
if DEBUG >= 2: print(f"Opened AM device {pcibus}")
def rescan_devs(self):
pattern = os.path.join('/tmp', 'am_*.lock')
for d in [f[8:-5] for f in glob.glob(pattern)]:
if d not in self.opened_pcidevs:
self._open_am_device(d)
for d in self.devs:
if d.reg("regSCRATCH_REG7").read() != AM_VERSION:
self.devs.remove(d)
self.opened_pcidevs.remove(d.pcibus)
os.system('clear')
if DEBUG >= 2: print(f"Removed AM device {d.pcibus}")
def collect(self): return {d: d.smu.read_metrics() for d in self.devs}
def draw(self):
terminal_width, _ = shutil.get_terminal_size()
dev_metrics = self.collect()
dev_content = []
for dev, metrics in dev_metrics.items():
device_line = [f"PCIe device: {bold(dev.pcibus)}"] + [""]
activity_line = [f"GFX Activity {draw_bar(metrics.SmuMetrics.AverageGfxActivity / 100, 50)}"] \
+ [f"UCLK Activity {draw_bar(metrics.SmuMetrics.AverageUclkActivity / 100, 50)}"] + [""]
# draw_metrics_table(metrics, dev)
temps_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_TEMP_e__enumvalues.items()
if k < smu_v13_0_0.TEMP_COUNT and metrics.SmuMetrics.AvgTemperature[k] != 0]
temps_table = ["=== Temps (C) ==="] + [f"{name:<15}: {color_temp(metrics.SmuMetrics.AvgTemperature[k])}" for k, name in temps_keys]
voltage_keys = [(k, name) for k, name in smu_v13_0_0.c__EA_SVI_PLANE_e__enumvalues.items() if k < smu_v13_0_0.SVI_PLANE_COUNT]
power_table = ["=== Power ==="] \
+ [f"Fan Speed: {metrics.SmuMetrics.AvgFanRpm} RPM"] \
+ [f"Fan Power: {metrics.SmuMetrics.AvgFanPwm} %"] \
+ [f"Power: {metrics.SmuMetrics.AverageSocketPower}W " +
draw_bar(metrics.SmuMetrics.AverageSocketPower / metrics.SmuMetrics.dGPU_W_MAX, 16)] \
+ ["", "=== Voltages ==="] + [f"{name:<24}: {color_voltage(metrics.SmuMetrics.AvgVoltage[k])}" for k, name in voltage_keys]
frequency_table = ["=== Frequencies ===",
f"GFXCLK Target : {metrics.SmuMetrics.AverageGfxclkFrequencyTarget} MHz",
f"GFXCLK PreDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPreDs} MHz",
f"GFXCLK PostDs : {metrics.SmuMetrics.AverageGfxclkFrequencyPostDs} MHz",
f"FCLK PreDs : {metrics.SmuMetrics.AverageFclkFrequencyPreDs} MHz",
f"FCLK PostDs : {metrics.SmuMetrics.AverageFclkFrequencyPostDs} MHz",
f"MCLK PreDs : {metrics.SmuMetrics.AverageMemclkFrequencyPreDs} MHz",
f"MCLK PostDs : {metrics.SmuMetrics.AverageMemclkFrequencyPostDs} MHz",
f"VCLK0 : {metrics.SmuMetrics.AverageVclk0Frequency} MHz",
f"DCLK0 : {metrics.SmuMetrics.AverageDclk0Frequency} MHz",
f"VCLK1 : {metrics.SmuMetrics.AverageVclk1Frequency} MHz",
f"DCLK1 : {metrics.SmuMetrics.AverageDclk1Frequency} MHz"]
dev_content.append(device_line + activity_line + same_line([temps_table, power_table, frequency_table]))
raw_text = 'AM Monitor'.center(terminal_width) + "\n" + "=" * terminal_width + "\n\n"
for i in range(0, len(dev_content), 2):
if i + 1 < len(dev_content): raw_text += '\n'.join(same_line([dev_content[i], dev_content[i+1]]))
else: raw_text += '\n'.join(dev_content[i])
if i + 2 < len(dev_content): raw_text += "\n" + "=" * terminal_width + "\n\n"
sys.stdout.write(f'\033[{self.prev_lines_cnt}A')
sys.stdout.flush()
print(raw_text)
self.prev_lines_cnt = len(raw_text.splitlines()) + 2
if __name__ == "__main__":
try:
os.system('clear')
smi_ctx = SMICtx()
while True:
smi_ctx.rescan_devs()
smi_ctx.draw()
time.sleep(1)
except KeyboardInterrupt: print("Exiting...")
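A worked example of the get_bar0_size arithmetic above, with made-up sysfs values:

# hypothetical first line of /sys/bus/pci/devices/<bus>/resource:
# 0x00000000f0000000 0x00000000f7ffffff 0x0000000000140204
start, end = 0x00000000f0000000, 0x00000000f7ffffff
print(end - start + 1)  # 134217728 bytes, i.e. a 128 MiB BAR0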

View File

@@ -0,0 +1,3 @@
#!/bin/bash
PYTHON_PATH=$(readlink -f $(which python3))
sudo setcap 'cap_dac_override,cap_sys_rawio,cap_sys_admin=ep' $PYTHON_PATH

extra/amdpci/setup_vfio.sh Executable file, 2 lines added
View File

@@ -0,0 +1,2 @@
#!/bin/bash
sudo modprobe vfio-pci disable_idle_d3=1

View File

@@ -63,15 +63,15 @@ class BertForPretraining:
def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
valid = masked_lm_ids != 0
masked_lm_predictions = prediction_logits.log_softmax(dtype=dtypes.float).argmax(-1)
masked_lm_accuracy = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_predictions = prediction_logits.argmax(-1)
masked_lm_correct = (masked_lm_predictions == masked_lm_ids) * valid
masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
seq_relationship_predictions = seq_relationship_logits.log_softmax(dtype=dtypes.float).argmax(-1)
seq_relationship_accuracy = (seq_relationship_predictions == next_sentence_labels)
seq_relationship_predictions = seq_relationship_logits.argmax(-1)
seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels)
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
return masked_lm_accuracy.sum() / valid.sum(), seq_relationship_accuracy.mean(), masked_lm_loss, next_sentence_loss
return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
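The accuracy simplification above is safe because log_softmax is monotonic, so argmax is unchanged; a quick check with hypothetical logits:

from tinygrad import Tensor
logits = Tensor([[2.0, 5.0, 1.0]])
assert logits.argmax(-1).item() == logits.log_softmax(-1).argmax(-1).item()  # both 1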

View File

@@ -1,41 +1,15 @@
import functools, io, math
from typing import cast, Literal
from tinygrad.tensor import Tensor, _broadcast_shape, ConstType, ReductionStr
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.dtype import ImageDType, dtypes, DType
from tinygrad.helpers import prod, flatten, make_tuple
from extra.onnx import dtype_parse, _cached_to_python_const
import numpy as np
# **************** Free Ops ****************
# ***** Property/Graph Ops *****
def Identity(x:Tensor): return x
# TODO: fix buffer_parse
def Add(x:Tensor, other:Tensor, broadcast=None, axis=None): return x + other if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + other).cast(x.dtype)
def Sub(x:Tensor|int, other:Tensor): return x - other # some test has input as int
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
# **************** Simple Ops ****************
# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x:Tensor, other:Tensor): return (x/other).cast(x.dtype)
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None,
value_floats:list[float]|None=None, value_int:int|None=None, value_ints:list[int]|None=None,
value_string:str|None=None, value_strings:list[str]|None=None):
def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:float|None=None, value_floats:list[float]|None=None,
value_int:int|None=None, value_ints:list[int]|None=None, value_string:str|None=None, value_strings:list[str]|None=None):
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
if value_floats is not None: return Tensor(list(value_floats), dtype=dtypes.float32, requires_grad=False)
@@ -44,21 +18,79 @@ def Constant(sparse_value:Tensor|None=None, value:Tensor|None=None, value_float:
if value_string is not None or value_strings is not None or sparse_value is not None:
raise NotImplementedError('Constant OP not implemented for value_string, value_strings and sparse_value')
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed for the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(encoded_stream))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
if value is None: value = Tensor(0, dtype=dtypes.float32)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
def Size(data:Tensor): return data.numel()
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
# ***** Unary Ops (math) *****
def Not(x:Tensor): return x.logical_not()
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None):
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
# ***** Unary Ops (activation) *****
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def FastGelu(x:Tensor, bias:Tensor|None=None):
# this is tanh approximated
return (x + bias).gelu() if bias is not None else x.gelu()
# TODO: fix this
def PRelu(X:Tensor, slope:Tensor):
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope # HACK OnnxBackendPyTorchConvertedModelTest HAS WEIRD SLOPE WHERE IT'S [0.25, 0.25, 0.25] FOR ANY X.SHAPE
slope = slope[0] if slope.shape[-1] != X.shape[-1] else slope
return (X > 0).where(X, X * slope)
def LeakyRelu(X:Tensor, alpha:float=0.01): return X.leakyrelu(alpha)
def ThresholdedRelu(X:Tensor, alpha:float=1.0): return (X > alpha).where(X, 0)
def Softmax_1(x:Tensor, axis:int=1): return x.softmax(axis)
def Softmax_13(x:Tensor, axis:int=-1): return x.softmax(axis)
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
def LogSoftmax(x: Tensor, axis:int=-1): return x.log_softmax(axis)
def Clip(x: Tensor, min:Tensor|None=None, max:Tensor|None=None): # noqa: A002
return x.clip(float('-inf') if min is None else min, float('inf') if max is None else max).cast(x.dtype)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
# ***** Unary Ops (broadcasted) *****
def Add(x:Tensor,y:Tensor, broadcast=None, axis=None): return x + y if x.dtype == dtypes.float or isinstance(x.dtype, ImageDType) else (x + y).cast(x.dtype)
def Sub(x:Tensor|int,y:Tensor): return x - y # some test has input as int
def Div(x:Tensor,y:Tensor): return (x/y).cast(x.dtype)
def Less(x:Tensor,y:Tensor): return x < y
def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def And(x:Tensor,y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor,y:Tensor): return (x==y).where(x, True)
def BitwiseAnd(x:Tensor,y:Tensor): return x & y
def BitwiseOr(x:Tensor,y:Tensor): return x | y
def BitwiseXor(x:Tensor,y:Tensor): return x ^ y
def BitwiseNot(x:Tensor): return ~x
# ***** Casting Ops *****
# TODO: saturate
def Cast(x:Tensor, to:int, saturate:int=1): return x.cast(dtype_parse(to))
def CastLike(x:Tensor, target_type:Tensor, saturate:int=1): return x.cast(target_type.dtype)
# ***** Reduce Ops *****
def Max(*data_0:Tensor): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0:Tensor): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0:Tensor): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0:Tensor): return Sum(*data_0) / len(data_0)
def _axes(axes, noop_with_empty_axes): return axes or ([] if noop_with_empty_axes else None)
def ReduceMax(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return data.max(_axes(axes, noop_with_empty_axes), keepdim=keepdims)
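# sketch of the _axes mapping used by the Reduce* ops above (assuming an empty-axis
# reduction is a no-op in tinygrad): _axes([2,3], 0) -> [2,3]; _axes(None, 0) -> None,
# reducing over all axes; _axes(None, 1) -> [], leaving the input unchanged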
@@ -80,27 +112,27 @@ def ReduceLogSum(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_wit
return ReduceSum(data, axes, keepdims, noop_with_empty_axes).log()
def ReduceLogSumExp(data:Tensor, axes:list[int]|None=None, keepdims:int=1, noop_with_empty_axes:int=0):
return ReduceSum(data.exp(), axes, keepdims, noop_with_empty_axes).log()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0):
return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
def OptionalHasElement(x:Tensor|None=None): return Tensor(x is not None and x.numel() > 0)
def OptionalGetElement(x:Tensor|None=None): return x if x is not None else Tensor([])
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Range(start:float|int, limit:float|int, delta:float|int): return Tensor.arange(start=start, stop=limit, step=delta)
def Shape(data:Tensor, end:int|None=None, start:int=0): return Tensor(data.shape[start:end], dtype=dtypes.int64)
def Size(data:Tensor): return data.numel()
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
# ***** Movement Ops *****
def Reshape(data:Tensor, shape:list[int], allowzero:int=0):
return data.reshape([x if x != 0 else (0 if allowzero else data.shape[i]) for i,x in enumerate(shape)])
def Flatten(x:Tensor, axis:int=1): return x.reshape(prod(x.shape[0:axis]), -1)
def Expand(x:Tensor, shape:list[int]): return x.expand(_broadcast_shape(x.shape, tuple(shape)))
def Shrink(x:Tensor, bias:float=0.0, lambd:float=0.5): return (x < -lambd)*(x+bias) + (x > lambd)*(x-bias)
def And(x:Tensor, y:Tensor): return (x==y).where(x, False)
def Or(x:Tensor, y:Tensor): return (x==y).where(x, True)
def Not(x:Tensor): return x.logical_not()
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Tile(x:Tensor, repeats:list[int]): return x.repeat(repeats)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Slice(data:Tensor, starts:list[int], ends:list[int], axes:list[int]|None=None, steps:list[int]|None=None):
axes = axes or list(range(data.ndim))
steps = steps or [1]*data.ndim
@@ -113,83 +145,9 @@ def Split(data:Tensor, split:list[int]|None=None, num_outputs:int=0, axis:int=0)
if split is None: split = [sz // num_outputs + (1 if i < sz % num_outputs else 0) for i in range(num_outputs)]
return data.split(split, axis)
# TODO: add test for when axes is None
def Squeeze(data:Tensor, axes:list[int]|None=None):
return data.squeeze() if axes is None else functools.reduce(lambda d, dim: d.squeeze(dim), sorted(axes, reverse=True), data)
def Unsqueeze(data:Tensor, axes:list[int]): return functools.reduce(lambda d, dim: d.unsqueeze(dim), sorted(axes), data)
def Binarizer(x:Tensor, threshold:float=0.0): return (x > threshold).float()
def ArgMax(x:Tensor, axis:int=0, keepdims:int=1, select_last_index:int=0):
if select_last_index: return ((x.shape[axis]-1) - x.flip(axis).argmax(axis, keepdim=keepdims)).cast(dtypes.int64)
return x.argmax(axis, keepdim=keepdims).cast(dtypes.int64)
def ArgMin(x, axis:int=0, keepdims:int=1, select_last_index:int=0): return ArgMax(-x, axis=axis, keepdims=keepdims, select_last_index=select_last_index)
def Concat(*xs:Tensor, axis:int): return Tensor.cat(*xs, dim=axis)
def Transpose(x:Tensor, perm:list[int]|None=None): return x.permute(order=list(range(x.ndim)[::-1]) if perm is None else perm)
def ConstantOfShape(shape:list[int], value:Tensor|None=None):
if value is None: value = Tensor(0, dtype=dtypes.float32)
return Tensor.ones(*shape, dtype=value.dtype) * (value if shape != [0] else 1)
# **************** Complex Ops ****************
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
if reverse: X = X.flip(axis)
if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
.shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
# TODO: this is copied from tinygrad/nn/__init__.py
# spatial is from opset 7 and has since been removed
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
training_mode:int=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).rsqrt()
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
invstd = (input_var + epsilon).rsqrt()
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axes, keepdim=True)
return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
def _onnx_pads_to_tiny_pads(pads): return flatten(reversed([(pB,pA) for pB, pA in zip(pads, pads[len(pads)//2:])]))
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _onnx_pads_to_tiny_pads(pads):
# (padding_top, padding_left, ..., padding_bottom, padding_right, ...) -> (padding_left, padding_right, padding_top, padding_bottom, ...)
return tuple(flatten(reversed(list(zip(pads, pads[len(pads)//2:])))))
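# e.g. ONNX pads (top, left, bottom, right) = (1, 2, 3, 4) -> tiny pads (2, 4, 1, 3), i.e. (left, right, top, bottom)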
def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[int]|None=None,
mode:Literal["constant", "reflect", "edge", "wrap"]="constant", value=0):
value = constant_value or value
@@ -198,6 +156,21 @@ def Pad(x:Tensor, pads:list[int], constant_value:ConstType|None=None, axes:list[
for i,axis in enumerate(axes): real_pads[axis%x.ndim], real_pads[axis%x.ndim+x.ndim] = pads[i], pads[i+len(axes)]
return x.pad(padding=_onnx_pads_to_tiny_pads(real_pads), mode={"edge":"replicate", "wrap":"circular"}.get(mode, mode), value=value)
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
for s, x in zip(shape, axes or range(t.ndim)):
tx = t.shape[x]
if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
# ***** Processing Ops *****
AUTO_PAD_OPTIONS = Literal["NOTSET", "SAME_UPPER", "SAME_LOWER", "VALID"]
def _auto_pad(pads, auto_pad: AUTO_PAD_OPTIONS):
# (padding_height, padding_width) -> (padding_top, padding_left, padding_bottom, padding_right)
if auto_pad == "SAME_UPPER": return [pads[i]//2 for i in range(len(pads))] + [pads[i]-pads[i]//2 for i in range(len(pads))]
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
@@ -206,25 +179,22 @@ def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
def AveragePool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, count_include_pad:int=0,
dilations:list[int]|int=1, pads:list[int]|int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
return X.avg_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode, count_include_pad=count_include_pad)
return X.avg_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad),
ceil_mode=ceil_mode, count_include_pad=count_include_pad)
def MaxPool(X: Tensor, kernel_shape:list[int], auto_pad:AUTO_PAD_OPTIONS="NOTSET", ceil_mode:int=0, dilations:list[int]|int=1, pads:list[int]|int=0,
storage_order:int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad)
ret = X.max_pool2d(kernel_shape, strides, dilations, pads, ceil_mode=ceil_mode)
ret = X.max_pool2d(kernel_shape, strides, dilations, _resolve_pool_pads(X, pads, kernel_shape, dilations, strides, auto_pad), ceil_mode=ceil_mode)
# tests expect indices with int64 dtype
# TODO: if there are repeated values, this is wrong
indices = ((ret.reshape(-1, 1) == X.reshape(1, -1)) * Tensor.arange(X.numel(), dtype=dtypes.int64).unsqueeze(0)).sum(1).reshape(ret.shape)
return ret.cast(X.dtype), indices.transpose(-2, -1) if storage_order else indices
def Conv(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
pads = _resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad)
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations, padding=tuple(pads))
kernel_shape:list[int]|None=None, pads:list[int]|int=0, strides:list[int]|int=1):
return X.conv2d(W, B, stride=strides, groups=group, dilation=dilations,
padding=_resolve_pool_pads(X, pads, kernel_shape or W.shape[2:], dilations, strides, auto_pad))
# src: https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
# another src: https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_conv_transpose.py
def ConvTranspose(X: Tensor, W: Tensor, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:list[int]|int=1, group:int=1,
kernel_shape:list[int]|None=None, pads:list[int]|None=None, output_shape:list[int]|None=None, output_padding:list[int]|int=0,
strides:list[int]|int=1):
@@ -247,80 +217,28 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: list[int]|None=None, kernel_shap
if outshape is not None and outshape != ret.shape: pads = _auto_pad([outshape[-2] - ret.shape[-2], outshape[-1] - ret.shape[-1]], "SAME_UPPER")
return ret.pad(_onnx_pads_to_tiny_pads(pads))
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
def GlobalAveragePool(X:Tensor): return X.mean(axis=tuple(range(2, X.ndim)), keepdim=True)
def GlobalMaxPool(X:Tensor): return X.max(axis=tuple(range(2, X.ndim)), keepdim=True)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
# 6 with 'is_test' needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
def Gemm(A:Tensor, B:Tensor, C:Tensor|None=None, alpha:float=1.0, beta:float=1.0, transA:int=0, transB:int=0, broadcast=0):
ret = alpha * (A.transpose(transA) @ B.transpose(transB))
if C is not None: ret = ret + beta * (C if broadcast == 0 else C.reshape([-1 if i < len(C.shape) else 1 for i in range(ret.ndim)][::-1]))
return ret
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
def Einsum(*Inputs:list[Tensor], equation:str): return Tensor.einsum(equation, *Inputs)
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def CumSum(X:Tensor, axis:int|list, exclusive:int=0, reverse:int=0):
axis = X._resolve_dim(axis[0] if isinstance(axis, list) else axis)
if reverse: X = X.flip(axis)
if exclusive: X = X.pad(tuple((1,0) if i == axis else None for i in range(X.ndim)))\
.shrink(tuple((0,X.shape[axis]) if i == axis else None for i in range(X.ndim)))
return X.cumsum(axis).flip(axis) if reverse else X.cumsum(axis)
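# e.g. X=[1,2,3]: CumSum -> [1,3,6]; exclusive=1 -> [0,1,3]; reverse=1 -> [6,5,3]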
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
log_probs = scores.log_softmax(1)
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE fewer kernels for smaller indices, but kernel count grows with the size of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds the kernel limit for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
x_shape, i_shape = x.shape, indices.shape
b = math.prod(x.shape[dim] for dim in range(batch_dims))
# NOTE: each batched dim of both input and indices are equal
x = x.reshape(b, *x.shape[batch_dims:])
indices = indices.reshape(b, *indices.shape[batch_dims:])
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
x = x.contiguous()
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
u = u.squeeze(0)
if reduction == "none": x[i] = u
elif reduction == "add": x[i] += u
elif reduction == "mul": x[i] *= u
else: raise NotImplementedError("reduction doesn't support max or min")
return x
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
def Trilu(x:Tensor, k:int=0, upper:int=1): return x.triu(k) if upper else x.tril(k)
def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, sizes:list[int]|None=None, antialias:int=0,
axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
axes:list[int]|None=None, coordinate_transformation_mode:str='half_pixel', cubic_coeff_a:float=-0.75, exclude_outside:int=0,
extrapolation_value:float=0.0, keep_aspect_ratio_policy:str='stretch', mode:str='nearest', nearest_mode:str='round_prefer_floor'):
def _apply_nearest_mode(index: Tensor, input_dim, mode: str):
if mode == "round_prefer_floor": index = (index - 0.5).ceil()
elif mode == "round_prefer_ceil": index = (index + 0.5).floor()
@@ -377,116 +295,43 @@ def Resize(X:Tensor, roi:list[float]|None=None, scales:list[float]|None=None, si
X = X.gather(i, low).lerp(X.gather(i, high), perc)
if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
def CenterCropPad(t:Tensor, shape:list[int], axes:list[int]|None=None):
shrink_arg:list[None|tuple[int,int]] = [None] * t.ndim
pad_arg:list[None|tuple[int,int]] = [None] * t.ndim
for s, x in zip(shape, axes or range(t.ndim)):
tx = t.shape[x]
if s < tx: shrink_arg[x] = (tx//2 - (s+1)//2, tx//2 + s//2)
elif s > tx: pad_arg[x] = ((s-tx)//2, (s-tx+1)//2)
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth = int(depth[0] if isinstance(depth, list) else depth)
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
def EyeLike(x:Tensor, dtype:int|None=None, k:int=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype is not None else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.shape[0]-k) for d in x.shape))
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode) # deprecated
def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
if axis < 0: axis += x.ndim
if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
if block_size == 0:
shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
return scale.reshape(shape), zero_point.reshape(shape)
return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
# ***** Neural Network Ops *****
# TODO: try to factor out common implementations for these ops
# https://medium.com/@zljdanceholic/groupnorm-then-batchnorm-instancenorm-layernorm-e2b2a1d350a0
def BatchNormalization(X:Tensor, scale:Tensor, B:Tensor, input_mean:Tensor, input_var:Tensor, epsilon:float=1e-05, momentum:float=0.9,
training_mode:int=0, spatial=1, is_test=0):
if training_mode:
x_detached = X.detach()
current_mean = x_detached.mean(axis=(0,2,3))
y = (x_detached - current_mean.reshape(shape=[1, -1, 1, 1]))
current_var = (y*y).mean(axis=(0,2,3))
current_invstd = current_var.add(epsilon).rsqrt()
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return ((x / y_scale).round() + y_zero_point).clamp(dtypes.min(out_dtype), dtypes.max(out_dtype)).cast(out_dtype).contiguous()
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
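# hypothetical round trip: x=[0.0, 1.0, 2.55] with scale=0.01, zero_point=0
# QuantizeLinear -> uint8 [0, 100, 255]; DequantizeLinear -> [0.0, 1.0, 2.55]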
def _quantize_linear(y:Tensor, y_scale:Tensor, y_zero_point:Tensor):
assert y_scale.dtype is dtypes.float32 and y_zero_point.dtype in {dtypes.uint8, dtypes.int8}, "used only for qlinear ops"
y = (y / y_scale + y_zero_point).round()
return y.clamp(dtypes.min(y_zero_point.dtype), dtypes.max(y_zero_point.dtype)).cast(y_zero_point.dtype)
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point: Tensor|int, B:Tensor|None=None, auto_pad:AUTO_PAD_OPTIONS="NOTSET", dilations:int|list[int]=1, group:int=1,
kernel_shape:list[int]|None=None, pads:int|list[int]=0, strides:int|list[int]=1):
x = x.int() - x_zero_point
w = w.int() - w_zero_point
y = Conv(x, w, B, auto_pad, dilations, group, kernel_shape, pads, strides)
y_scale = y_scale / (x_scale * w_scale)
return _quantize_linear(y, y_scale, y_zero_point)
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point:Tensor|int) -> Tensor:
a = a.int() - a_zero_point
b = b.int() - b_zero_point
y = Tensor.matmul(a, b, acc_dtype=dtypes.int32)
y_scale = y_scale / (a_scale * b_scale)
return _quantize_linear(y, y_scale, y_zero_point)
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None,
auto_pad: AUTO_PAD_OPTIONS = "NOTSET", dilations: int | list[int] = 1, group: int = 1, kernel_shape: list[int] | None = None,
pads: int | list[int] = 0, strides: int | list[int] = 1) -> Tensor:
x_int = x.int() - x_zero_point
w_int = w.int() - w_zero_point
return Conv(x_int, w_int, B, auto_pad, dilations, group, kernel_shape, pads, strides)
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
A_int = A.int() - a_zero_point
B_int = B.int() - b_zero_point
return Tensor.matmul(A_int, B_int, acc_dtype=dtypes.int32)
# copied from https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py
def ImageDecoder(encoded_stream:bytes, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(encoded_stream))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
raise ValueError(f"pixel_format={pixel_format!r} is not supported.")
def AffineGrid(theta:Tensor, size:list[int], align_corners:int=0):
N, _, *spatial_dims = size
def generate_grid(steps):
return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device)
grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims))
base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1)
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
# **************** com.microsoft Ops ****************
running_mean = input_mean * momentum + current_mean * (1 - momentum)
running_var = input_var * momentum + current_var * (1 - momentum)
return X.batchnorm(scale, B, current_mean, current_invstd), running_mean, running_var
invstd = (input_var + epsilon).rsqrt()
return X.batchnorm(scale, B, input_mean, invstd)
def InstanceNormalization(x:Tensor, scale:Tensor, bias:Tensor, epsilon:float=1e-05):
axis = tuple(range(2, x.ndim))
mean = x.mean(axis=axis, keepdim=True)
invstd = x.sub(mean).square().mean(axis=axis, keepdim=True).add(epsilon).rsqrt()
return x.sub(mean).mul(scale.reshape(shape=[-1, 1, 1])).mul(invstd).add(bias.reshape(shape=[-1, 1, 1]))
def LayerNormalization(x:Tensor, scale:Tensor, bias:Tensor, axis:int=-1, epsilon:float=1e-05, stash_type:int=1):
assert stash_type == 1, "only float32 is supported"
axes = tuple(i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim))
mean = x.mean(axis=axes, keepdim=True)
return x.layernorm(axes, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).square().mean(axis=axes, keepdim=True).add(epsilon).rsqrt()
def GroupNormalization(x:Tensor, scale:Tensor, bias:Tensor, num_groups:int, epsilon:float=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
x = x + skip + bias if bias is not None else x + skip
out = x.layernorm(eps=epsilon) * gamma
return out + beta if beta is not None else out, None, None, x
def FastGelu(x:Tensor, bias:Tensor|None=None):
# this uses the tanh approximation of gelu
return (x + bias).gelu() if bias is not None else x.gelu()
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
@@ -513,6 +358,45 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embeddin
out = embedding_sum.layernorm(eps=epsilon) * gamma + beta
return out, None, embedding_sum
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth = int(depth[0] if isinstance(depth, list) else depth)
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
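# Reference semantics sketch for OneHot (illustrative NumPy, not part of this
# file): off/on come from values[0]/values[1], and negative indices wrap by depth.
def _one_hot_ref(indices, depth, values, axis=-1):
  idx = np.where(indices < 0, indices + depth, indices)
  hot = np.eye(depth, dtype=bool)[idx]  # one-hot along a new last axis
  if axis != -1: hot = np.moveaxis(hot, -1, axis)
  return np.where(hot, values[1], values[0])
# _one_hot_ref(np.array([1, -1]), 3, np.array([0, 1])) -> [[0, 1, 0], [0, 0, 1]]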
def DepthToSpace(X:Tensor, blocksize:int, mode:str="DCR"):
return X.rearrange("b (c h1 w1) h w -> b c (h h1) (w w1)" if mode=="CRD" else "b (h1 w1 c) h w -> b c (h h1) (w w1)", h1=blocksize, w1=blocksize)
def SpaceToDepth(X:Tensor, blocksize:int):
return X.rearrange("b c (h h1) (w w1) -> b (h1 w1 c) h w", h1=blocksize, w1=blocksize)
# Reimplemented here because you need legacy RNG for passing ONNX tests.
def Dropout_7(data:Tensor, ratio:float=0.5, training_mode:bool=False, seed:int|None=None):
if not training_mode: return data, Tensor.ones(data.shape, dtype=dtypes.bool) # if the mask is requested as an output, it contains all Trues
mask = Tensor(np.random.RandomState(seed).random(cast(tuple[int,...], data.shape)) >= ratio, requires_grad=False, device=data.device)
return data * mask * (1/(1.0 - ratio)), mask
# opset 6 with 'is_test' is needed for https://github.com/MTlab/onnx2caffe/raw/refs/heads/master/model/MobileNetV2.onnx
def Dropout_6(data:Tensor, ratio:float=0.5, is_test=0): return Dropout_7(data, ratio, training_mode=not is_test)
Dropout = {6:Dropout_6, 7:Dropout_7}
def LRN(x:Tensor, size:int, alpha:float=1e-4, beta:float=0.75, bias:float=1.0):
pooled_x = (x**2).rearrange('b c h w -> b 1 c (h w)').pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1)
return x / (pooled_x.reshape(x.shape) * alpha + bias).pow(beta)
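# The rearrange above folds the channel axis into a spatial axis so avg_pool2d
# can slide a `size`-wide window over channels. Since avg_pool2d returns the
# mean, pooled*alpha + bias equals the ONNX bias + (alpha/size)*sum(x^2) term.
# One-pixel sketch with hypothetical numbers (channels [1, 2, 3], size=3):
_sq = [1.0, 4.0, 9.0]  # x^2 per channel
_denom_mid = (1.0 + 1e-4 * sum(_sq) / 3) ** 0.75  # (bias + alpha*mean)^beta at the middle channel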
def NegativeLogLikelihoodLoss(x:Tensor, target:Tensor, weight:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
return x.nll_loss(target, weight, ignore_index, reduction)
def SoftmaxCrossEntropyLoss(scores:Tensor, labels:Tensor, weights:Tensor|None=None, ignore_index:int|None=None, reduction:ReductionStr="mean"):
log_probs = scores.log_softmax(1)
return log_probs.nll_loss(labels, weights, ignore_index, reduction), log_probs
def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
@@ -552,37 +436,132 @@ def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present if past is not None else out
# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
def Gather(x:Tensor, indices:Tensor, axis:int=0):
if indices.numel() < 9: # NOTE: fewer kernels for small indices, but the kernel count grows with the number of indices
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE: faster gather with a fixed number of kernels, but it exceeds the kernel limit for openpilot
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
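# Illustrative NumPy trace of the small-indices path above: Gather(x, [2, 0],
# axis=1) on a (2, 3) array becomes one shrink per index, concatenated on axis 1.
_x = np.array([[1, 2, 3], [4, 5, 6]])
_cols = [_x[:, i:i+1] for i in (2, 0)]  # shrink to each requested index
_out = np.concatenate(_cols, axis=1)  # [[3, 1], [6, 4]]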
def Scatter(*args, **kwargs): return ScatterElements(*args, **kwargs) # deprecated
def GatherND(x:Tensor, indices:Tensor, batch_dims:int=0):
if batch_dims == 0: return x[tuple(i.squeeze(-1) for i in indices.split(1, -1))]
x_shape, i_shape = x.shape, indices.shape
b = math.prod(x.shape[dim] for dim in range(batch_dims))
# NOTE: each batched dim of input and indices is equal
x = x.reshape(b, *x.shape[batch_dims:])
indices = indices.reshape(b, *indices.shape[batch_dims:])
b_idx = Tensor.arange(b, device=x.device).reshape(b, *(1,)*(indices.ndim - 2)).expand(*indices.shape[:-1])
ret = x[(b_idx,) + tuple(i.squeeze(-1) for i in indices.split(1, -1))]
return ret.reshape(*x_shape[:batch_dims], *i_shape[batch_dims:-1], *ret.shape[indices.ndim-1:])
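# GatherND with batch_dims=1, reference semantics (illustrative): each batch
# row of indices selects from the matching batch row of x.
_x = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])  # shape (2, 2, 2)
_idx = np.array([[1], [0]])  # shape (2, 1): one coordinate per batch
_out = np.stack([_x[b][tuple(_idx[b])] for b in range(2)])  # [[2, 3], [4, 5]]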
def ScatterND(x:Tensor, indices:Tensor, updates:Tensor, reduction:Literal["none", "add", "mul"]='none'):
assert updates.shape == indices.shape[:-1] + x.shape[cast(int, indices.shape[-1]):]
x = x.contiguous()
for index, u in zip(indices.split(1, 0), updates.split(1, 0)):
i = tuple(idx.squeeze(-1) for idx in index.squeeze(0).split(1, -1))
u = u.squeeze(0)
if reduction == "none": x[i] = u
elif reduction == "add": x[i] += u
elif reduction == "mul": x[i] *= u
else: raise NotImplementedError("reduction doesn't support max or min")
return x
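# ScatterND semantics sketch (illustrative): each row of indices names a slice
# of x by its leading coordinates, and the matching update is written (or
# reduced) into that slice, exactly like the loop above.
_x = np.array([1., 2., 3., 4.])
for _i, _u in zip(np.array([[3], [1]]), np.array([9., 8.])): _x[tuple(_i)] = _u  # reduction="none"
# _x is now [1., 8., 3., 9.]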
def ScatterElements(x: Tensor, indices: Tensor, updates: Tensor, axis=0, reduction:Literal["none", "add", "mul"]="none"):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.scatter(axis, indices, updates, {"none":None, "mul": "multiply"}.get(reduction, reduction))
def GatherElements(x:Tensor, indices:Tensor, axis:int):
indices = (indices < 0).where(x.shape[axis], 0) + indices
return x.gather(axis, indices)
def Compress(inp:Tensor, condition:list[bool], axis:int|None=None):
if axis is None:
inp = inp.flatten()
axis = 0
if axis < 0: axis += inp.ndim
con = Tensor(np.arange(len(condition))[condition]) # no boolean indexing in Tensor
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
# ***** Quantization Ops *****
def _clamp_cast(x:Tensor, dtype:DType): return x.clamp(dtypes.min(dtype), dtypes.max(dtype)).cast(dtype)
def _prepare_quantize(x, scale, zero_point, axis=1, block_size=0):
if axis < 0: axis += x.ndim
if not isinstance(zero_point, Tensor): zero_point = Tensor(zero_point, dtype=dtypes.uint8)._broadcast_to(scale.shape)
if block_size == 0:
shape = (*[1]*axis, *scale.shape, *[1]*(x.ndim - axis - scale.ndim))
return scale.reshape(shape), zero_point.reshape(shape)
return scale.repeat_interleave(block_size, dim=axis), zero_point.repeat_interleave(block_size, dim=axis)
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
def DequantizeLinear(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int=0, axis:int=1, block_size:int=0):
x_scale, x_zero_point = _prepare_quantize(x, x_scale, x_zero_point, axis, block_size)
return ((x.int() - x_zero_point) * x_scale).cast(x_scale.dtype)
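# Round-trip sketch of the two ops above for a per-tensor uint8 case with
# hypothetical scale/zero point: quantize divides by scale and shifts by the
# zero point, dequantize inverts it, losing only rounding and clamping error.
_scale, _zp = 0.1, 128
_vals = np.array([-1.0, 0.0, 0.34])
_q = np.clip(np.round(_vals / _scale) + _zp, 0, 255).astype(np.uint8)  # [118, 128, 131]
_deq = (_q.astype(np.int32) - _zp) * _scale  # [-1.0, 0.0, ~0.3]: 0.34 snaps to the nearest 0.1 step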
def _op_integer(op, inputs:list[Tensor], zero_points:list[Tensor], **opts):
adjusted_inputs = [inp.int() - zp for inp, zp in zip(inputs, zero_points)]
return op(*adjusted_inputs, **opts)
def _qlinearop_quantized(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in quantized int
out = _op_integer(op, inputs, zero_points, **opts)
assert dtypes.is_int(out.dtype), "quantized op should've done math in int"
out_quantized = (out * prod(scales) / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
def _qlinearop_float(op, inputs:list[Tensor], zero_points:list[Tensor], scales:list[Tensor], out_scale:Tensor, out_zero_point:Tensor, **opts):
# op execution is done in float32
dequantized_inputs = [(inp.int() - zp) * scale for inp, zp, scale in zip(inputs, zero_points, scales)]
out = op(*dequantized_inputs, **opts)
assert dtypes.is_float(out.dtype), "op should've done math in float"
out_quantized = (out / out_scale).round() + out_zero_point
return _clamp_cast(out_quantized, out_zero_point.dtype)
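# Why there are two strategies (illustrative): for Conv/MatMul the input scales
# factor out of the op, so the math can stay in int32 and be rescaled once at
# the end; for ops like Add with unrelated per-input scales that doesn't hold,
# so inputs are dequantized first. Sketch with one input and a linear op:
_xq, _zp, _sx, _sy = np.array([10, 20], dtype=np.int32), 5, 0.5, 0.25
_op = lambda v: v + v  # any op linear in its input
_int_path = np.round(_op(_xq - _zp) * (_sx / _sy))  # int math, scale applied once at the end
_float_path = np.round(_op((_xq - _zp) * _sx) / _sy)  # dequantize, op in float, requantize
# both paths give [20., 60.]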
def QLinearConv(x:Tensor, x_scale:Tensor, x_zero_point:Tensor|int, w:Tensor, w_scale:Tensor, w_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point: Tensor|int, B:Tensor|None=None, **opts):
return _qlinearop_quantized(Conv, [x,w], [x_zero_point,w_zero_point], [x_scale,w_scale], y_scale, y_zero_point, **{"B":B, **opts})
def QLinearMatMul(a:Tensor, a_scale:Tensor, a_zero_point:Tensor|int, b:Tensor, b_scale:Tensor, b_zero_point:Tensor|int, y_scale:Tensor,
y_zero_point:Tensor|int) -> Tensor:
return _qlinearop_quantized(Tensor.matmul, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], y_scale, y_zero_point)
def QLinearAdd(a:Tensor, a_scale:Tensor, a_zero_point:Tensor, b:Tensor, b_scale:Tensor, b_zero_point:Tensor, c_scale:Tensor, c_zero_point:Tensor):
return _qlinearop_float(Tensor.add, [a,b], [a_zero_point,b_zero_point], [a_scale,b_scale], c_scale, c_zero_point)
def QLinearGlobalAveragePool(X:Tensor, x_scale:Tensor, x_zero_point:Tensor, y_scale:Tensor, y_zero_point:Tensor, channels_last:int):
assert channels_last == 0, "unsure what this does"
return _qlinearop_float(GlobalAveragePool, [X], [x_zero_point], [x_scale], y_scale, y_zero_point)
def ConvInteger(x: Tensor, w: Tensor, x_zero_point: Tensor | int = 0, w_zero_point: Tensor | int = 0, B: Tensor | None = None, **opts) -> Tensor:
return _op_integer(Conv, [x,w], [x_zero_point,w_zero_point], **{"B":B, **opts})
def MatMulInteger(A: Tensor, B: Tensor, a_zero_point: Tensor | int = 0, b_zero_point: Tensor | int = 0) -> Tensor:
return _op_integer(Tensor.matmul, [A,B], [a_zero_point,b_zero_point])
# ***** Training Ops *****
# NOTE: onnx test coverage only exercises the `T==0` cases, so the `T>0` paths are untested
# NOTE: onnx training ops don't actually need optimizer state; all the ops work functionally, but we can still reuse the optim.py code
def _onnx_training(input_group_size):
def __decorator(func):
def ___wrapper(R:Tensor, T:int, *inputs:Tensor, **kwargs):
R = R.detach()
groups = len(inputs) // input_group_size
ret = [func(R, T, *inps, **kwargs) for inps in (inputs[i::groups] for i in range(groups))]
return tuple(flatten(zip(*ret)))
return ___wrapper
return __decorator
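# How the grouping works (illustrative): the ONNX training ops take flattened
# inputs like (X1, X2, G1, G2, H1, H2); with input_group_size=3 and two trained
# tensors, the wrapper regroups them into (X1, G1, H1) and (X2, G2, H2), runs
# func on each group, then re-flattens the per-group outputs with zip.
_inputs = ("X1", "X2", "G1", "G2", "H1", "H2")
_groups = len(_inputs) // 3  # two trained tensors
# [_inputs[i::_groups] for i in range(_groups)] -> [("X1","G1","H1"), ("X2","G2","H2")]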
@_onnx_training(3)
def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:float=0.0, norm_coefficient:float=0.0):
X, G, H = (i.detach() for i in inputs)
grad = norm_coefficient * X + G
@@ -592,9 +571,10 @@ def Adagrad(R:Tensor, T:int, *inputs:Tensor, decay_factor:float=0.0, epsilon:flo
X.assign(X.detach() - r * up)
return [X, H]
@_onnx_training(4)
def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, epsilon:float=0.0, norm_coefficient:float=0.0,
norm_coefficient_post:float=0.0):
from tinygrad.nn.optim import Adam as TinyAdam
X, G, V, H = inputs
G, V, H = G.detach(), V.detach(), H.detach() # TODO we shouldn't need these detaches
X.grad = norm_coefficient * X.detach() + G
@@ -610,8 +590,9 @@ def Adam(R:Tensor, T:int, *inputs:Tensor, alpha:float=0.9, beta:float=0.999, eps
X = (1 - norm_coefficient_post) * X
return [X, V, H]
@_onnx_training(3)
def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str, norm_coefficient:float):
from tinygrad.nn.optim import SGD
X, G, V = inputs
G, V = G.detach(), V.detach()
X.grad = (norm_coefficient * X.detach() + G) * (beta if T > 0 else 1)
@@ -620,6 +601,6 @@ def Momentum(R:Tensor, T:int, *inputs:Tensor, alpha:float, beta:float, mode:str,
opt.step()
return [X, V]
def Gradient(*inputs:Tensor, y:str, intermediate_tensors:dict[str, Tensor], **_):
intermediate_tensors[y].backward()
return tuple([t.grad for t in inputs])

View File

@@ -22,7 +22,6 @@ nav:
- Runtime: runtime.md
- Developer:
- Intro: developer/developer.md
- Function (autodiff): developer/function.md
- UOp: developer/uop.md
- Runtime:
- developer/runtime.md

View File

@@ -44,8 +44,8 @@ def main():
else:
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus)
(naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus)
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")

View File

@@ -25,7 +25,7 @@ class AMPTFuzzer:
_vaddr = va.va_addr + _offset
for i in range(_n_ptes):
pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
self.d.vram[pte['paddr']] = pattern # Mark this page
assert pte['valid'] == 1
@@ -41,7 +41,7 @@ class AMPTFuzzer:
frags_l = list(ctx.next(contig_range))
for f_offset, f_pt, f_pte_idx, f_n_ptes, f_pte_covers in frags_l:
for j in range(f_n_ptes):
f_pte = helper_read_entry_components(f_pt.entries[f_pte_idx + j])
assert f_pte['valid'] == 1
assert f_pte['paddr'] == start_paddr+f_offset+j*f_pte_covers, f"paddr {f_pte['paddr']:#x} not {start_paddr+f_offset+j*f_pte_covers:#x}"
@@ -53,7 +53,7 @@ class AMPTFuzzer:
def verify_memory(self, pages, pattern: int) -> bool:
for _offset, _pt, _pte_idx, _n_ptes, _pte_covers in pages:
for i in range(_n_ptes):
pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
if self.d.vram[pte['paddr']] != pattern: return False
if pte['valid'] == 0: return False

View File

@@ -3,7 +3,9 @@ from tinygrad.runtime.support.am.amdev import AMMemoryManager, AMPageTableTraver
from tinygrad.helpers import mv_address
class FakeGMC:
def __init__(self):
self.vm_base = 0x0
self.address_space_mask = (1 << 44) - 1
def flush_tlb(self, *args, **kwargs): pass
class FakePCIRegion:
@@ -14,7 +16,7 @@ class FakePCIDev:
class FakeAM:
def __init__(self):
self.is_booting, self.smi_dev = True, False
self.pcidev = FakePCIDev()
self.vram = memoryview(bytearray(4 << 30))
self.gmc = FakeGMC()
@@ -72,7 +74,7 @@ class TestAMPageTable(unittest.TestCase):
for tup in results:
_offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
for i in range(_n_ptes):
pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
assert pte['paddr'] == va + _offset + i * _pte_covers, f"Expected paddr {pte['paddr']:#x} to be {va + _offset + i * _pte_covers:#x}"
assert pte['valid'] == 1
@@ -81,7 +83,7 @@ class TestAMPageTable(unittest.TestCase):
for tup in results:
_offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
for i in range(_n_ptes):
pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
assert pte['paddr'] == 0
assert pte['valid'] == 0
@@ -113,7 +115,7 @@ class TestAMPageTable(unittest.TestCase):
for tup in ctx.next(0x100000):
_offset, _pt, _pte_idx, _n_ptes, _pte_covers = tup
for i in range(_n_ptes):
pte = helper_read_entry_components(_pt.entries[_pte_idx + i])
assert pte['paddr'] == 0xdead0000 + _offset + i * _pte_covers, f"paddr {pte['paddr']:#x} not {0xdead0000 + _offset + i * _pte_covers:#x}"
assert pte['valid'] == 1

View File

@@ -2,7 +2,7 @@
# compare kernels created by HEAD against master
from collections import defaultdict
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, schedule_uop
from tinygrad.codegen.kernel import Kernel, Opt
@@ -33,15 +33,15 @@ class ProcessReplayWarning(Warning): pass
def recreate_sched(ast:UOp) -> UOp:
# NOTE: process replay isn't meant to actually schedule anything
return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:list[Opt], name:str) -> str:
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)
# NOTE: replay with the captured renderer, not the one in master
return k.opts.render(name, cast(list,k.to_program().uops))
# *** diff a "good" recreation against the generated version
def diff(offset:int, name:str, fxn:Callable) -> tuple[int, int]|bool:
if early_stop.is_set(): return True
conn = db_connection()
cur = conn.cursor()
@@ -95,7 +95,7 @@ def _pmap(name:str, fxn:Callable, maxtasksperchild:int=16) -> None:
cur.close()
with multiprocessing.get_context("spawn").Pool(multiprocessing.cpu_count(), maxtasksperchild=maxtasksperchild) as pool:
inputs = list(range(0, row_count, PAGE_SIZE))
ret: list[tuple[int, int]|bool] = list(tqdm(pool.imap_unordered(functools.partial(diff, name=name, fxn=fxn), inputs), total=len(inputs)))
pool.close()
pool.join()
pool.terminate()

View File

@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 63)
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 65)
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
def test_train_cifar(self):

View File

@@ -40,6 +40,7 @@ class TestTrain(unittest.TestCase):
check_gc()
@unittest.skipIf(CI, "slow")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_efficientnet(self):
model = EfficientNet(0)
X = np.zeros((BS,3,224,224), dtype=np.float32)
@@ -56,6 +57,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "too many buffers for webgpu and metal")
def test_transformer(self):
# this should be small GPT-2, but the param count is wrong
# (real ff_dim is 768*4)

View File

@@ -160,13 +160,14 @@ class TestIndexing(unittest.TestCase):
# llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
emb = nn.Embedding(vocab_size, embed_size)
# TODO: why is a new realize needed here
emb_w = emb.weight.realize().numpy()
x = Tensor([1,2,3,4])
with Context(NOOPT=noopt, FUSE_ARANGE=1):
GlobalCounters.reset()
z = emb(x).realize()
self.assertLessEqual(GlobalCounters.global_ops, op_limit)
self.assertEqual(GlobalCounters.kernel_count, 2)
if getenv("CHECK", 1):
import torch
with torch.no_grad():

View File

@@ -8,9 +8,11 @@ from tinygrad.ops import UOp
from tinygrad.tensor import Tensor
def tensors_allocated():
gc.collect()
return sum([isinstance(x, Tensor) for x in gc.get_objects()])
def bufs_allocated():
gc.collect()
return sum([isinstance(x, Buffer) for x in gc.get_objects()])
class TestGC(unittest.TestCase):
@@ -31,7 +33,7 @@ class TestGC(unittest.TestCase):
base = tensors_allocated()
a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
b = Tensor.rand(4, 4, requires_grad=True)
assert (tensors_allocated()-base == 4)
(a*b).mean().backward()
assert (tensors_allocated()-base == 6)
del b

View File

@@ -134,5 +134,39 @@ class TestImageDType(unittest.TestCase):
print(lst)
assert not np.any(np.isnan(lst))
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
class TestImageRealization(unittest.TestCase):
def test_image_dtype_expand(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous().realize()
self.assertEqual(it_expanded.dtype, dtypes.float32)
def test_image_dtype_expand_and_back(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4))
it2 = it_expanded.sum(3).realize()
self.assertEqual(it2.dtype, dtypes.imagef((9,27,4)))
def test_image_alu_children(self):
data = Tensor.randn(9*27*4).realize()
it = data.cast(dtypes.imagef((9,27,4))).contiguous().realize()
self.assertEqual(it.dtype, dtypes.imagef((9,27,4)))
it_expanded = it.reshape((9,27,4,1)).expand((9,27,4,4)).contiguous()
alu1 = it_expanded+1
alu2 = it_expanded.sum(3)
it_expanded.realize()
# NOTE: the parent becomes float, but the alu child will stay image until its output cannot fit the image
self.assertEqual(alu1.dtype, dtypes.imagef((9,27,4)))
alu1.realize()
self.assertEqual(alu1.dtype, dtypes.float32)
# alu2 is back in image because it fits the dtype again
self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
alu2.realize()
self.assertEqual(alu2.dtype, dtypes.imagef((9,27,4)))
if __name__ == '__main__':
unittest.main()

View File

@@ -1694,7 +1694,6 @@ class TestHandCodedOpts(unittest.TestCase):
# should upcast the two Tensor.stacks
assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
@unittest.expectedFailure # requires contiguous folding
def test_masked_upcast_wino_full(self):
with Context(WINO=1):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()

View File

@@ -997,7 +997,7 @@ class TestLinearizerFailures(unittest.TestCase):
UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"], atol=0.02)
# llama3 8B failure with BEAM=2 https://github.com/tinygrad/tinygrad/actions/runs/10150118124/job/28066519425#step:14:1, these don't compile
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test needs local")

View File

@@ -1,11 +1,11 @@
import unittest, functools, random
from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes, Variable
from tinygrad.ops import Ops, UOp
from tinygrad.helpers import CI, getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
from tinygrad.multi import all_reduce
import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.device import is_dtype_supported
@@ -30,7 +30,7 @@ N = 128
def _test_allreduce(t:Tensor):
aa = (t[0:64] + t[64:128] + t[128:192] + t[192:256]).repeat([4,1]).realize()
ts = t.shard(devices_4, 0).realize()
b = Tensor(UOp.multi(*all_reduce(Ops.ADD, ts.lazydata.src), axis=0))
b.realize()
return aa, b
@@ -39,7 +39,7 @@ class TestMultiTensor(unittest.TestCase):
def test_to(self):
X = Tensor.ones(256).contiguous().realize()
X.to_(devices_2)
for lb in X.lazydata.src:
assert lb.shape == (256,)
(X + X).realize()
@@ -52,7 +52,7 @@ class TestMultiTensor(unittest.TestCase):
def test_shard(self):
X = Tensor.ones(256).contiguous().realize()
X.shard_(devices_2, 0)
for lb in X.lazydata.src:
assert lb.shape == (128,)
(X + X).realize()
@@ -218,9 +218,9 @@ class TestMultiTensor(unittest.TestCase):
shape = tuple([(n if i == 0 else 1) * random.randint(1, 10) for i in range(random.randint(1, 4))])
t = Tensor.rand(shape).shard_(tuple([d0, d1, d2, d3][:n]), 0)
with Context(RING=0):
a = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
with Context(RING=2):
b = Tensor(UOp.multi(*all_reduce(Ops.ADD, t.lazydata.src), axis=0))
diff = a - b
mean_err = diff.reshape((prod(diff.shape),)).abs().mean().numpy()
max_err = diff.reshape((prod(diff.shape),)).abs().max().numpy()
@@ -344,9 +344,7 @@ class TestMultiTensor(unittest.TestCase):
# NOTE: this is failing on LLVM CI, no idea why. Works locally.
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow")
def test_data_parallel_resnet(self):
from extra.models.resnet import ResNet18
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
fake_image_sharded = fake_image.shard(devices_2, axis=0)
@@ -356,16 +354,14 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(m): p.shard_(devices_2).realize()
GlobalCounters.reset()
shard_output = m(fake_image_sharded).log_softmax().realize()
assert shard_output.lazydata.src[0].shape == (1, 1000)
assert shard_output.lazydata.src[1].shape == (1, 1000)
shard_output_np = shard_output.numpy()
np.testing.assert_allclose(real_output, shard_output_np, atol=1e-6, rtol=1e-6)
@unittest.skipIf(CI and Device.DEFAULT in ("CUDA", "NV", "LLVM"), "slow, and flaky on LLVM")
def test_data_parallel_resnet_train_step(self):
from extra.models.resnet import ResNet18
from tinygrad.nn.optim import LARS
fake_image = Tensor.rand((2, 3, 224//8, 224//8))
@@ -386,12 +382,35 @@ class TestMultiTensor(unittest.TestCase):
GlobalCounters.reset()
optimizer.zero_grad()
shard_output = m(fake_image_sharded).sparse_categorical_crossentropy(labels_sharded, label_smoothing=0.1)
assert shard_output.lazydata.axis is None
shard_output.backward()
shard_grad = m.conv1.weight.grad.numpy()
# sometimes there are zeros in these grads... why?
np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)
def test_assign_kv_cache_multi(self):
bsz, max_context = 2, 8
class Attn:
@TinyJit
def __call__(self, xk:Tensor, start_pos:UOp):
seqlen = xk.shape[1]
if not hasattr(self, "cache_k"):
self.cache_k = Tensor.zeros(bsz, max_context, 1, 1).shard(devices_2).contiguous().realize()
keys = self.cache_k.shrink((None, (0, start_pos), None, None)).cat(xk, dim=1).contiguous() if start_pos > 0 else xk
self.cache_k.assign(keys.pad((None,(0,max_context-start_pos-seqlen),None,None)).contiguous()).realize()
attn = Attn()
xk = Tensor.ones(bsz, 3, 1, 1).shard(devices_2).contiguous()
attn(xk, 0)
for i in range(3,6):
# copied from LLaMA
start_pos = Variable("start_pos", 1, max_context).bind(i)
xk = Tensor.ones(bsz, 1, 1, 1).shard(devices_2).contiguous()
attn(xk, start_pos)
out = attn.cache_k.flatten().numpy()
np.testing.assert_allclose(out, [1.,1.,1.,1.,1.,1.,0.,0.,1.,1.,1.,1.,1.,1.,0.,0.])
def test_multi_tensor_jit_param(self):
@TinyJit
def jf(a, b) -> Tensor:
@@ -532,13 +551,13 @@ class TestMultiTensor(unittest.TestCase):
t4 = t2.reshape((26, 105,))
for t in [t0, t1, t2, t3, t4]:
np.testing.assert_allclose(t.numpy().flatten(), t0.numpy().flatten())
assert t.lazydata.axis == 1
# test shape-one axis
t5 = t4.reshape((26, 1, 105))
np.testing.assert_allclose(t.numpy().flatten(), t5.numpy().flatten())
assert t5.lazydata.axis == 2
# test split and rejoin to the right and reshape to the left
t5 = t0.reshape((2, 13, 3, 5, 7))
@@ -553,7 +572,7 @@ class TestMultiTensor(unittest.TestCase):
# test no left join
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7)).schedule()
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
@@ -588,6 +607,7 @@ class TestMultiTensor(unittest.TestCase):
with self.assertRaises(AssertionError):
# don't allow assigns that change axes
t_none.assign(t_zero)
t_none.schedule()
def test_init_rand_with_multiple_devices_fail(self):
# init rand with multi device is not allowed
@@ -627,6 +647,16 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
def test_rand_like_from_alu(self):
# TODO: fix this, which will also fix multi device dropout
a = Tensor.ones(4, 4).shard(devices_2, axis=0)
with self.assertRaises(AssertionError):
(a + a).rand_like()
b = Tensor.empty(4, 4).shard(devices_2, axis=None)
with self.assertRaises(AssertionError):
(a + b).rand_like()
@unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
@@ -635,7 +665,7 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.device, t2.device)
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
assert all(tlb.shape == t2lb.shape for tlb, t2lb in zip(t.lazydata.src, t2.lazydata.src))
def test_rand_like_none_shard(self):
t = Tensor.empty((16, 16)).shard(devices_2)
@@ -718,7 +748,7 @@ class TestMultiTensor(unittest.TestCase):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.src])
@unittest.skip("this is unreliable on OSX")
def test_clone(self):
@@ -774,25 +804,31 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
with self.assertRaises(AssertionError):
# sharded axis shrink on a non-device boundary is not allowed
a = t.shrink(((0, 3), (0, 8)))
a.schedule()
with self.assertRaises(AssertionError):
# cannot shrink sharded and non-sharded axis at the same time
a = t.shrink(((0, 2), (2, 4)))
a.schedule()
a = t.shrink(((0, 2), (0, 8)))
a.schedule()
assert a.shape == (2, 8)
assert a.lazydata.real == (True, False, False, False)
with self.assertRaises(AssertionError):
# cannot pad sharded and non-sharded axis at the same time
p = a.pad(((0, 6), (0, 1)))
p.schedule()
with self.assertRaises(AssertionError):
# can only pad to whole axis
p = a.pad(((1, 5), (0, 0)))
p.schedule()
p = a.pad(((0, 6), (0, 0)))
p.schedule()
assert p.shape == (8, 8)
assert p.lazydata.real == (True, True, True, True)
@given(strat.sampled_from([dtypes.float, dtypes.int, dtypes.int64, dtypes.int16]))
def test_ops(self, dtype):
@@ -804,8 +840,8 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
a = t.shrink(((0+2*i,2+2*i),None))
b = Tensor(t.numpy()[0+2*i:2+2*i])
assert a.shape == b.shape == (2, 8)
np.testing.assert_allclose(a.numpy(), b.numpy())
assert a.lazydata.real == tuple(i==j for j in range(4))
# cast
np.testing.assert_allclose(a.float().numpy(), b.float().numpy())
@@ -859,24 +895,27 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
np.testing.assert_equal((a+a).numpy(), na+na)
np.testing.assert_equal((b+b).numpy(), nb+nb)
@unittest.skip("why didn't this work?")
def test_add_two_partitions(self):
t = Tensor.arange(64).reshape(8, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(4)], axis=0)
a = t.shrink(((2, 4), None))
b = t.shrink(((6, 8), None))
na = t.numpy()[2:4]
nb = t.numpy()[6:8]
np.testing.assert_equal(a.numpy(), na)
np.testing.assert_equal(b.numpy(), nb)
self.assertEqual(a.lazydata.real, (False, True, False, False))
self.assertEqual(b.lazydata.real, (False, False, False, True))
with self.assertRaises(AssertionError):
# cannot add directly
c = a + b
c.schedule()
c = a.pad(((2, 4), None)) + b.pad(((6, 0), None))
c.realize()
self.assertEqual(c.lazydata.real, (True, True, True, True))
expected = np.concatenate([np.zeros_like(t.numpy()[0:2]), na, np.zeros_like(t.numpy()[4:6]), nb])
np.testing.assert_equal(c.numpy(), expected)
@@ -937,8 +976,9 @@ class TestBatchNorm(unittest.TestCase):
def __call__(self, x:Tensor):
bn_ts = []
each = x.shape[0]//len(self.bns)
for i, bn in enumerate(self.bns):
xi = x.shrink(((each*(i), each*(i+1)), None, None, None))
bni = bn(xi)
bn_ts.append(bni)
return bn_ts[0].cat(*bn_ts[1:])

View File

@@ -352,12 +352,15 @@ class TestOps(unittest.TestCase):
def test_cmp_le(self): self._test_cmp(lambda x,y: x<=y)
def test_cmp_ne_backwards(self):
# new grad zeroes these out
"""
t1 = torch.ones(4, requires_grad=True)
t2 = torch.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (t1 != t2).sum().backward)
tt1 = Tensor.ones(4, requires_grad=True)
tt2 = Tensor.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (tt1 != tt2).sum().backward)
"""
tt = Tensor.randn(4, requires_grad=True)
(tt*(tt != 0)).sum().backward()
t = torch.tensor(tt.numpy(), requires_grad=True)
@@ -365,12 +368,15 @@ class TestOps(unittest.TestCase):
np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5)
def test_cmp_lt_backwards(self):
# new grad zeroes these out
"""
t1 = torch.ones(4, requires_grad=True)
t2 = torch.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (t1 < t2).sum().backward)
tt1 = Tensor.ones(4, requires_grad=True)
tt2 = Tensor.ones(4, requires_grad=True)
self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward)
"""
tt = Tensor.randn(4, requires_grad=True)
(tt*(tt < 0)).sum().backward()
t = torch.tensor(tt.numpy(), requires_grad=True)

View File

@@ -14,9 +14,9 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, prod, all_same, temp
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -67,7 +67,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
@track_rewrites(named=True)
def schedule_graph_rewrite(big_sink:UOp): return graph_rewrite(big_sink, remove_movement_ops+sym, {})
class TestSchedule(unittest.TestCase):
def test_basic_binop_fusion(self):
@@ -323,7 +323,7 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 18), (nn.optim.SGD, 11)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
@@ -609,6 +609,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2)
# multireduce spec
@unittest.skip("these two Tensors are the same")
def test_example_matmul(self):
x = Tensor.eye(64, requires_grad=True)
y = Tensor.eye(64, requires_grad=True)
@@ -618,6 +619,15 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_contig(self):
x = Tensor.eye(64, requires_grad=True).contiguous().realize()
y = Tensor.eye(64, requires_grad=True).contiguous().realize()
z = y.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_same(self):
x = Tensor.eye(64, requires_grad=True)
z = x.matmul(x).sum()
@@ -1050,7 +1060,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 14)
def test_sgd_conv_fuse(self):
with Tensor.train():
@@ -1060,7 +1070,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters(c1))
opt.zero_grad()
c1(img).relu().sum().backward()
check_schedule(opt.schedule_step(), 3)
def test_sgd_2convs_fuse(self):
with Tensor.train():
@@ -1071,7 +1081,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 7)
def test_fold_2convs_sgd_nesterov_momentum_wd(self):
with Tensor.train():
@@ -1082,7 +1092,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
opt.zero_grad()
c2(c1(img).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 9)
def test_sgd_4convs_fuse(self):
with Tensor.train():
@@ -1095,7 +1105,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
check_schedule(opt.schedule_step(), 17)
def test_sgd_4convs_fuse_conv_bw(self):
with Tensor.train():
@@ -1108,7 +1118,7 @@ class TestSchedule(unittest.TestCase):
opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
opt.zero_grad()
c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 14)
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_prefer_half_buffer(self):
@@ -1844,7 +1854,7 @@ def run_tensor_ast(r:Tensor):
sink = UOp(Ops.STORE, src=(glbl, ShapeTracker.from_shape(r.lazydata.base.shape).to_uop(), r.lazydata.base)).sink()
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+load_buffers+view_left, bufs:=[output])
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+view_right)
si = ScheduleItem(sink, tuple(x.buffer for x in bufs), ())
run_schedule([si])
return output.realized.as_buffer().cast(output.dtype.fmt, r.shape).tolist()
@@ -1940,19 +1950,19 @@ class TestSwizzle(unittest.TestCase):
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
@unittest.skip("this swizzle can't be decided after the ADD")
def test_swizzle_failure_permute(self):
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(20, 65), src=(UOp(Ops.DEVICE, arg="METAL"),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(8, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
@@ -1971,13 +1981,13 @@ class TestSwizzle(unittest.TestCase):
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x15,)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(2, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
x10,)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float, arg=(4, 2925), src=(UOp(Ops.DEVICE, arg="METAL"),)),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
@@ -2269,6 +2279,54 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.empty(4).lazydata
check_schedule(a.clone(), 1, filter_sink=False)
# NOTE: moving copy before view might change this
def test_shrink_copy(self):
a = Tensor.arange(4)
view = a.shrink(((0, 2),))
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 2)
self.assertListEqual(b.tolist(), [0, 1])
def test_expanded_copy(self):
a = Tensor.arange(2)
view = a.reshape(2, 1).expand(2, 2)
b = view.clone()
run_schedule(check_schedule(b, 2, filter_sink=False))
self.assertEqual(b.lazydata.base.buffer.size, 2)
self.assertEqual(b.lazydata.size, 4)
self.assertListEqual(b.tolist(), [[0, 0], [1, 1]])
def test_permuted_copy(self):
a = Tensor.arange(4)
b = a.reshape(2, 2).permute(1, 0)
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_on_disk(self):
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
b = a.reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
def test_permute_after_shrink(self):
a = Tensor.arange(5)
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
# NOTE: disk permute must come after COPY
# TODO: this is wrong because of the permute
@unittest.expectedFailure
def test_permute_after_shrink_on_disk(self):
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
b.realize()
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
class TestTensorUOpSpec(unittest.TestCase):
def test_const_must_be_unmasked(self):
a = Tensor.ones((4, 4)).pad((2, 2))
@@ -2377,6 +2435,17 @@ class TestContiguous(unittest.TestCase):
b = a.expand((4, 4)).contiguous().contiguous()
check_schedule(b, 1)
def test_view_does_not_realize(self):
a = Tensor.empty(4)
b = a.expand((4, 4))
check_schedule(b, 0)
self.assertEqual(b.lazydata.base.buffer.size, 4)
def test_contiguous_view_realizes(self):
a = Tensor.empty(4)
b = a.expand((4, 4)).contiguous()
check_schedule(b, 1)
self.assertEqual(b.lazydata.base.buffer.size, 16)
class TestUOpBecome(unittest.TestCase):
# the simplest case, if we create a new BUFFER for this UOp

View File

@@ -652,19 +652,26 @@ class TestZeroShapeTensor(unittest.TestCase):
def test_clone(self):
a = Tensor.rand(16, 16).realize()
self.assertIsNot(a.lazydata, a.clone().lazydata)
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
a = Tensor.rand(16, 16).mul(5.0).add(5.0)
self.assertIsNot(a.lazydata, a.clone().lazydata)
b = a.clone()
np.testing.assert_allclose(a.numpy(), b.numpy())
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_clone_with_shrink(self):
a = Tensor.rand(16, 16)
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_clone_with_shrink_realized(self):
a = Tensor.rand(16, 16).realize()
b = a.shrink(((2, 10), None)).clone()
b.realize()
self.assertIsNot(a.lazydata.base.buffer, b.lazydata.base.buffer)
def test_clone_with_grad(self):
a = Tensor.rand(16, 16, requires_grad=True)
@@ -763,8 +770,8 @@ class TestTensorMetadata(unittest.TestCase):
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
def test_complex_backward(self):
x = Tensor.rand(3, requires_grad=True).realize()
y = Tensor.rand(3, requires_grad=True).realize()
out = (x.relu() * y.sigmoid()).sum()
self.assertEqual(out.lazydata.metadata.name, "sum")
out.backward()

View File

@@ -14,7 +14,7 @@ class TestGradient(unittest.TestCase):
def _test_one_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
gx = compute_gradient(f(x), UOp.const(dtypes.float, 1.0), set([x]))[x]
gf = jax.grad(f if jf is None else jf)
for val in [-5., -2.0, 0.0, 2.0, 5.]:
@@ -24,7 +24,7 @@ class TestGradient(unittest.TestCase):
def _test_two_input_function(self, f:Callable, jf:Callable|None=None):
x = UOp.variable('x', -math.inf, math.inf, dtype=dtypes.float)
y = UOp.variable('y', -math.inf, math.inf, dtype=dtypes.float)
grads = compute_gradient(f(x, y), UOp.const(dtypes.float, 1.0), set([x, y]))
gx, gy = grads[x], grads[y]
gf = jax.grad(f if jf is None else jf, argnums=(0, 1))

View File

@@ -1,118 +1,104 @@
from typing import Dict, List, Optional
import unittest, decimal, json
from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher as PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, symbolic
from tinygrad.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json, to_perfetto
from tinygrad.viz.serve import get_metadata, uop_to_json, to_perfetto
@track_rewrites(named=True)
def rewrite(sink:UOp, pm:PatternMatcher, **kwargs): return graph_rewrite(sink, pm, **kwargs)
def helper_test_viz(sink:UOp, pm:PatternMatcher, **kwargs) -> List[UOp]:
rewrite(sink, pm, **kwargs)
assert len(contexts) == 1
assert len(contexts[0]) == 1
k = get_metadata(keys, contexts)[0][0]
g = get_details(*k)
return g.uops[1:]
# NOTE: VIZ tests always use the tracked PatternMatcher instance
symbolic = TrackedPatternMatcher(symbolic.patterns)
class TestViz(unittest.TestCase):
def setUp(self):
# clear the global context
contexts.clear()
keys.clear()
_name_cnt.clear()
self.tms = TRACK_MATCH_STATS.value
TRACK_MATCH_STATS.value = 2
def tearDown(self): TRACK_MATCH_STATS.value = self.tms
def test_viz_simple(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
def test(sink): return graph_rewrite(sink, symbolic)
test(a*1)
ret = get_metadata(keys, contexts)
self.assertEqual(len(ret), 1)
key, val = ret[0]
self.assertEqual(key, "test_1")
self.assertEqual(val[0]["match_count"], 1)
def test_rewrite_twice(self):
pm = PatternMatcher([
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),
])
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm)
self.assertEqual(len(uops), 2)
self.assertEqual(uops[0], a*2)
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
def test_track_two_rewrites(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
def test(sink): return graph_rewrite(sink, symbolic)
test((a+a)*1)
ret = get_metadata(keys, contexts)
key, val = ret[0]
self.assertEqual(len(ret), 1) # one context
self.assertEqual(len(val), 1) # one graph_rewrite call in context
self.assertEqual(key, "test_1")
self.assertEqual(val[0]["match_count"], 2) # two upats applied
def test_rewrite_with_ctx(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), ShapeTracker.from_shape((1, 1)).to_uop()))
b = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), ShapeTracker.from_shape((1, 1)).to_uop()))
def store_load(ctx:Dict[UOp, None], glbl, st) -> Optional[UOp]:
if glbl in ctx: return None
ctx[glbl] = None
return UOp.store(glbl, ShapeTracker.from_shape(st.shape).to_uop())
pm = PatternMatcher([
(UPat.load(UPat(Ops.DEFINE_GLOBAL, name="glbl"), UPat.var("st")), store_load),
])
uops = helper_test_viz(a+b, pm, ctx={})
self.assertEqual(len(uops), 2)
self.assertEqual(uops[-1], graph_rewrite(a+b, pm, {}))
def test_track_multiple_calls_one_ctx(self):
a = UOp.variable("a", 0, 10)
@track_rewrites(named=True)
def test(a, b):
a = graph_rewrite(a, symbolic)
b = graph_rewrite(b, symbolic)
test(a*1, a*5)
ret = get_metadata(keys, contexts)
key, val = ret[0]
self.assertEqual(len(ret), 1) # one context
self.assertEqual(len(val), 2) # two graph_rewrite calls in context
self.assertEqual(key, "test_1")
self.assertEqual(val[0]["match_count"], 1) # one rewrite for a*0
self.assertEqual(val[1]["match_count"], 0) # no rewrites for a*5
def test_track_rewrites(self):
@track_rewrites(named=True)
def do_rewrite(x:UOp): return graph_rewrite(x, symbolic)
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 4)
do_rewrite(a*1)
do_rewrite(a*b)
ret = get_metadata(keys, contexts)
self.assertEqual(len(ret), 2)
key, m = ret[0]
self.assertEqual(key, "do_rewrite_1")
self.assertEqual(m[0]["match_count"], 1)
key, m = ret[1]
self.assertEqual(key, "do_rewrite_2")
self.assertEqual(m[0]["match_count"], 0)
def test_track_rewrites_with_exception(self):
@track_rewrites()
def do_rewrite(x:UOp):
x = graph_rewrite(x, symbolic) # NOTE: viz tracks this
raise Exception("test")
a = UOp.variable("a", 0, 10)
with self.assertRaises(Exception): do_rewrite(a*1)
ret = get_metadata(keys, contexts)
self.assertEqual(len(ret), 1)
def test_fold_const(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
graph = uop_to_json(a)
assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
# NOTE: CONST UOps do not get nodes in the graph
def test_dont_create_const_nodes(self):
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 4)
self.assertEqual(len(uop_to_json(a*1)), 2)
self.assertEqual(len(uop_to_json(a*b)), 3)
@unittest.skip("TODO: bring this back with better testing")
def test_bottom_up_rewrite(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
n1 = a.sin()
uop = n1.sin()
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True)
self.assertEqual(len(ret), 2)
self.assertIs(ret[0], a.sin().sqrt()) # first rewrite
self.assertIs(ret[1], a.sqrt().sqrt()) # second one
def test_top_down_rewrite(self):
a = UOp(Ops.LOAD, dtypes.int, (UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
n1 = a.sin()
uop = n1.sin()
pm = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
# if it wasn't bottom_up, it's rewritten once
ret = helper_test_viz(uop, pm, ctx={a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=False)
a = UOp.variable("a", 0, 10)
b = UOp.variable("b", 0, 10)
c = UOp.variable("c", 0, 10)
UOp.substitute(a+b, {a+b:c})
ret = get_metadata(keys, contexts)
self.assertEqual(len(ret), 1)
self.assertIs(ret[0], a.sqrt().sin()) # only rewrite
_, _, vals = ret[0][0]
self.assertEqual(len(vals.upats), 1)
# NOTE: calling graph_rewrite when the function isn't decorated with track_rewrites should not VIZ
def test_rewrite_without_context(self):
@@ -211,6 +197,5 @@ class TextVizProfiler(unittest.TestCase):
self.assertEqual(j['traceEvents'][7]['dur'], 4)
self.assertEqual(j['traceEvents'][7]['pid'], j['traceEvents'][3]['pid'])
if __name__ == "__main__":
unittest.main()
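The viz tests above all drive graph_rewrite through track_rewrites; as a minimal standalone sketch of the same fixed-point rewrite loop (using only the UOp/PatternMatcher API the tests already import), the two patterns from test_rewrite_twice chain like this:
# minimal sketch, assuming the same tinygrad.ops API the tests above use
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, Ops, PatternMatcher, UPat, graph_rewrite
pm = PatternMatcher([
  (UPat.var("x")+UPat.var("x"), lambda x: x*2),                                       # x+x -> x*2
  (UPat.var("x", dtypes.int)*2, lambda x: x.alu(Ops.SHL, UOp.const(dtypes.int, 1))),  # x*2 -> x<<1
])
a = UOp.variable("a", 0, 10)  # an int DEFINE_VAR
# graph_rewrite keeps applying matches until a fixed point: a+a -> a*2 -> a<<1
print(graph_rewrite(a+a, pm))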

View File

@@ -3,14 +3,14 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution
cpu_time_execution, colored, Context
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
# **************** Device ****************
ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"]
class _Device:
def __init__(self) -> None:
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@@ -25,7 +25,7 @@ class _Device:
cpn = multiprocessing.current_process().name
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{__name__.split(".")[0]}.runtime.ops_{x.lower()}')) \
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) \
if (cname.lower() == x.lower() + "device")][0](ix)
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
self._opened_devices.add(ix)
@@ -33,7 +33,7 @@ class _Device:
@property
def default(self) -> Compiled: return self[self.DEFAULT]
def get_available_devices(self) -> Iterator[str]:
for device in ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM"]:
for device in ALL_DEVICES:
with contextlib.suppress(Exception): yield self[device].device
@functools.cached_property
def DEFAULT(self) -> str:
@@ -220,9 +220,10 @@ MAP_JIT = 0x0800
# CPUProgram is a jit/shellcode program that can just be mmapped and jumped to
class CPUProgram:
helper_handle = ctypes.CDLL(ctypes.util.find_library('System') if OSX else 'libgcc_s.so.1')
helper_handle = ctypes.CDLL(ctypes.util.find_library('System' if OSX else 'kernel32' if sys.platform == "win32" else 'gcc_s'))
def __init__(self, name:str, lib:bytes):
assert sys.platform != "win32", "clang is not supported for windows yet"
from mmap import mmap, PROT_READ, PROT_WRITE, PROT_EXEC, MAP_ANON, MAP_PRIVATE
# On apple silicon with SPRR enabled (it always is in macos) RWX pages are unrepresentable: https://blog.svenpeter.dev/posts/m1_sprr_gxf/
# MAP_JIT allows us to easily flip pages from RW- to R-X and vice versa. It is a noop on intel cpus. (man pthread_jit_write_protect_np)
self.mem = mmap(-1, len(lib), MAP_ANON | MAP_PRIVATE | (MAP_JIT if OSX else 0), PROT_READ | PROT_WRITE | PROT_EXEC)
@@ -314,3 +315,18 @@ if PROFILE:
from tinygrad.ops import launch_viz
launch_viz("PROFILE", fn)
if __name__ == "__main__":
for device in ALL_DEVICES:
try:
_ = Device[device].device
try:
from tinygrad import Tensor
with Context(CACHELEVEL=0): test = (Tensor([1,2,3], device=device) * 2).tolist()
if test != [2,4,6]: raise ValueError(f"got {test} instead of [2, 4, 6]")
result = colored("PASS", "green")
except Exception as e:
result = f"{colored('FAIL', 'yellow')} {e}"
except Exception as e:
result = f"{colored('FAIL', 'red')} {e}"
print(f"{'*' if device == Device.DEFAULT else ' '} {device:10s}: {result}")

View File

@@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, BEAM, getenv, colored, JIT, dedup, partition, unwrap
from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, Variable, sym_infer
from tinygrad.ops import UOp, Variable, sym_infer, Ops
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner, Estimates
from tinygrad.engine.memory import _internal_memory_planner
@@ -194,7 +194,8 @@ def _prepare_jit_inputs(args, kwargs):
input_tensors: list[tuple[int|str, Tensor]] = [(name,t) for name,t in list(enumerate(args))+sorted(kwargs.items()) if t.__class__ is Tensor]
names, tensors = [name for name,_ in input_tensors], [t for _,t in input_tensors]
if tensors: Tensor.realize(*tensors)
lbs: list[UOp] = flatten([t.lazydata.lbs for t in tensors])
# TODO: should we be unpacking multi here?
lbs: list[UOp] = flatten([t.lazydata.src if t.lazydata.op is Ops.MULTI else [t.lazydata] for t in tensors])
input_buffers: list[Buffer] = [lb.base.realized for lb in lbs if lb.base.realized is not None]
assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
st_varval_dtype_device = [(*unwrap(lb.st).unbind(), lb.dtype, lb.device) for lb in lbs]
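With MultiLazyBuffer gone, a sharded tensor's lazydata is an Ops.MULTI UOp whose src holds one UOp per device, which is what the flatten above unpacks. A hedged sketch of that shape (device names are illustrative):
# hedged sketch: what _prepare_jit_inputs sees for a sharded input
from tinygrad import Tensor
from tinygrad.ops import Ops
t = Tensor.ones(8, 8).shard(("CLANG:0", "CLANG:1"), axis=0)  # devices are illustrative
assert t.lazydata.op is Ops.MULTI
assert len(t.lazydata.src) == 2  # one lazy UOp per shard, as unpacked above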

View File

@@ -47,4 +47,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
# Exclude buffers involved in load ops (e.g. transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.assign_preloads) for si in schedule]
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]

View File

@@ -1,10 +1,10 @@
import sys, atexit, functools, pickle
from collections import defaultdict, deque
from dataclasses import dataclass, field
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element, buffers, symbolic_simple, type_verify, graph_rewrite_map
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY, ContextVar
from tinygrad.ops import UOp, Variable, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, track_rewrites, type_verify, buffers
from tinygrad.ops import can_pad, identity_element, resolve, symbolic_simple, view_left, merge_views
from tinygrad.helpers import Context, ContextVar, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, CAPTURE_PROCESS_REPLAY
from tinygrad.dtype import DType, ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -37,7 +37,7 @@ tensor_uop_spec = PatternMatcher([
# DETACH and CONTIGUOUS change how we interpret the source UOp
# CONTIGUOUS ensures the source UOp realizes
(UPat((Ops.DETACH, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
(UPat((Ops.DETACH, Ops.CONTIGUOUS, Ops.CONTIGUOUS_BACKWARD), name="root", src=(UPat.var("x"),), arg=None), lambda root,x: root.dtype == x.dtype),
# COPY
# NOTE: the arg here specifies clone=True, which prevents folding same device copy
@@ -60,7 +60,6 @@ class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
assign_preloads: tuple[UOp, ...]
@property
def outputs(self) -> tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -82,9 +81,8 @@ class ScheduleContext:
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
ops_metadata: dict[UOp, Metadata] = field(default_factory=dict) # this maps fused ops to Metadata
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
becomes_map: dict[UOp, UOp] = field(default_factory=dict)
preloads: defaultdict[Buffer, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
@@ -104,6 +102,7 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
if DEBUG >= 2: print(f"forcing image {dtype} with shape {buf.shape} to {dtype.base}")
dtype = buf.dtype.base
# ASSIGN already has a target buffer, otherwise we create a new one
assert isinstance(buf.device, str), f"buf device must be str, not {buf.device}"
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
# track the underlying tensor uop for this buffer
@@ -152,9 +151,9 @@ def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
# push VIEW to stores
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> STORE(.., new_val).view()
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
@@ -200,6 +199,8 @@ to_si = PatternMatcher([
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# CONST(VIEW) becomes VALID too, TODO: doesn't have to
(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: x.replace(src=()).valid(st.st)),
])
# LOAD(BUFFER) -> the STORE value if we're doing the STORE in the same kernel
@@ -211,25 +212,27 @@ def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
# deal with ASSIGN
assign_preloads: list[UOp] = []
if len(ctx.assigns) != 0:
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
for x in list(sink.toposort)[::-1]:
# we only allow a kernel to depend on either the pre-ASSIGN or the post-ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD:
assign_preloads.append(x.buf_uop)
assign_preloads[x.buf_uop] = None
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
# if it has a single view that equals a contiguous ShapeTracker shrunk by its mask, it's fine
if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
# otherwise, it's not fine
else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
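The reworked check above accepts a non-contiguous self operand only when its single masked view shrinks back to a contiguous ShapeTracker; everything else raises with the .contiguous() hint. A hedged sketch of the failing and passing forms:
# hedged sketch of the augmented-assign rule enforced in schedule_uop above
from tinygrad import Tensor
a = Tensor.rand(4, 4).contiguous().realize()
try:
  a.assign(a + a.T).realize()  # self operand is a permuted (non-contiguous) view
except RuntimeError as e:
  print(e)  # "self operand of augmented assign must be contiguous..."
b = Tensor.rand(4, 4).contiguous().realize()
b.assign(b + b.T.contiguous()).realize()  # transpose realized into its own buffer first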
@@ -329,7 +332,7 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
# maybe fuse arange with its children
for rbuf in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf}
if any(luop.op is Ops.CONTIGUOUS for tr in group for luop in ctx.tensor_uops[tr]): continue
if any(tensor_uop.op is Ops.CONTIGUOUS for tr in group for tensor_uop in ctx.tensor_uops[tr]): continue
kernel_children = {c for tr in group for c in ctx.children[tr] if uval(ctx.allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group: del ctx.realizes[tr]
@@ -354,20 +357,20 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
case _: return None
return reduce.const_like(ret)
def found_contiguous(ctx:ScheduleContext, contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx.contiguous[src.base] = contig.view(sti)
def replace_contiguous(ctx:ScheduleContext, alu:UOp):
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.contiguous.get(s, None)) is not None: new_src[i] = replace_src
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH is a NOOP here
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
@@ -389,10 +392,10 @@ sym = symbolic_simple+PatternMatcher([
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER/VIEW from SINK
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x.base for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
# ** this decides which ops get realized
@@ -419,7 +422,7 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, x:UOp, **kwargs)
return x.view(unwrap(view.st))
def create_subbuffer(base:UOp, b:UOp, root:UOp, x:UOp):
if not b.device.startswith("DISK"): return None
if isinstance(b.device, tuple) or not b.device.startswith("DISK"): return None
buffers[b] = x.buf_uop.buffer.view(b.size, b.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
return base.replace(src=(b, root.replace(op=Ops.BUFFER_VIEW)))
@@ -440,12 +443,12 @@ do_realize = PatternMatcher([
(UPatScheduled((Ops.BITCAST, Ops.CONTIGUOUS), name="root", src=(UPat.var("x"),)), create_subbuffer),
])
# **** rewrite VIEW into LOAD/STORE/VALID or fuse the underlying UOp
# **** rewrite VIEW into LOAD/STORE or fuse the underlying UOp
def unbind_variable(ctx:ScheduleContext, bind:UOp, var:UOp, val:UOp):
assert isinstance(val.src[1].const_arg, int), f"expected BIND value to be int {val}"
ctx.var_vals[ret:=var.replace(src=())] = val.src[1].const_arg
return ret.valid(unwrap(bind.st))
assert isinstance(val.const_arg, int), f"expected BIND value to be int {val}"
ctx.var_vals[var.replace(src=())] = val.const_arg
return var
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
@@ -458,8 +461,6 @@ def store_or_fuse(ctx:ScheduleContext, b:UOp, x:UOp, st:UOp):
return UOp(Ops.LOAD, x.dtype, (b, unwrap(st.st).to_uop()))
break_sched = PatternMatcher([
# CONST is always fused and generated
(UPat(Ops.CONST, name="x", src=(UPat(Ops.VIEW, name="st"),)), lambda x,st: UOp.const(x.dtype, x.const_arg).valid(st.st)),
(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.var("val"))), unbind_variable),
# VIEW of BUFFER either becomes a LOAD/STORE or we fuse it
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"),)), load_realized),
@@ -473,38 +474,43 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
buf_uop.buffer.ref(1)
create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf_uop"), UPat())), append_uop)])
# **** movement ops
remove_movement_ops = PatternMatcher([
remove_movement_ops = merge_views+PatternMatcher([
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
# merge nested views (VIEW of a single-src VIEW).
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
# merge unmasked const views
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
])
@track_rewrites(named=True)
def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
if not skip_check: type_verify(list(big_sink.toposort), tensor_uop_spec)
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx:=ScheduleContext())
tensor_map = graph_rewrite_map(big_sink, remove_movement_ops+sym, ctx={})
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: becomes_map[k] = v
# we group the rest of UOps into ScheduleItems
rev_tensor_map: dict[UOp, list[UOp]] = {}
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
# add BUFFER uops
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx, cache={})
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
# add realizes
sink = graph_rewrite(sink, do_realize+create_ctx, ctx)
# group realizes into kernels
store_groups = group_realizes(ctx)
graph_rewrite(sink, break_sched, ctx)
# preschedule realize groups
# create schedule items + map buffers to realized tensors
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
small_sink = UOp.sink(*[ctx.realizes[u] for u in store_uops])
@@ -512,16 +518,9 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
prescheduled.append(schedule_uop(small_sink, ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# tensors can become an existing buffer or simplify to a const, no ScheduleItem needed
for k,v in tensor_map.items():
# NOOP
if k.base is v.base: continue
# NOTE: only the base tensors get a BUFFER UOp
if v.is_realized and k is k.base: ctx.becomes_map[k] = v.view(unwrap(k.st))
# otherwise if it simplified to a CONST the UOp just becomes that CONST
elif v.op is Ops.CONST: ctx.becomes_map[k] = v
for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
# increment refcount for this buffer
buf_uop.buffer.ref(1)
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
@@ -529,7 +528,7 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
for si in prescheduled:
# realize outputs before a parent is assigned to
parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
parents_assigns = dedup(xsi for x in ctx.preloads[si.bufs[0]] if (xsi:=schedule_targets.get(x.buffer)) and xsi is not si)
for assign in parents_assigns:
graph[si].append(assign)
in_degree[assign] += 1
@@ -550,4 +549,4 @@ def create_schedule_with_vars(big_sink:UOp, skip_check:bool=not __debug__) -> tu
# confirm everything was scheduled correctly
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
return schedule, ctx.var_vals, ctx.becomes_map
return schedule, ctx.var_vals, becomes_map
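becomes_map is now assembled locally: realized-or-const tensors are recorded before grouping, and each scheduled store group maps its tensor uops to views of the new buffer. A hedged sketch of driving this from the public Tensor API, which funnels into create_schedule_with_vars:
# hedged sketch: Tensor.schedule() lowers to create_schedule_with_vars above
from tinygrad import Tensor
out = (Tensor.ones(16) + 1).sum()
for si in out.schedule():  # list[ScheduleItem]
  print(si.ast.op, [b.size for b in si.bufs], si.metadata)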

View File

@@ -1,203 +0,0 @@
"""This is where the forwards and backwards passes live."""
import math
from tinygrad.helpers import argsort
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
from tinygrad.ops import Ops, resolve, sint, UOp
from tinygrad.tensor import Function
class Contiguous(Function):
def forward(self, x:UOp) -> UOp: return x.contiguous()
def backward(self, grad_output:UOp) -> UOp: return grad_output
class ContiguousBackward(Function):
def forward(self, x:UOp) -> UOp: return x
def backward(self, grad_output:UOp) -> UOp: return grad_output.contiguous()
class Cast(Function):
def forward(self, x:UOp, dtype:DType, bitcast:bool=False) -> UOp:
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.bitcast(dtype) if self.bitcast else x.cast(dtype)
def backward(self, grad_output:UOp) -> UOp:
if self.bitcast: raise RuntimeError("bitcast cannot backward")
return grad_output.cast(self.input_dtype)
# ************* unary ops *************
class Reciprocal(Function):
def forward(self, x:UOp) -> UOp:
self.ret = x.reciprocal()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return -grad_output * self.ret * self.ret
class Sin(Function):
def forward(self, x:UOp) -> UOp:
self.x = x
return x.sin()
def backward(self, grad_output:UOp) -> UOp: return (math.pi/2 - self.x).sin() * grad_output
class Relu(Function):
def forward(self, x:UOp) -> UOp:
self.ret = (x>0).where(x, 0)
return self.ret
def backward(self, grad_output:UOp) -> UOp: return (self.ret>0).cast(grad_output.dtype) * grad_output
class Log(Function):
def forward(self, x:UOp) -> UOp:
self.x = x
return x.log2() * math.log(2)
def backward(self, grad_output:UOp) -> UOp: return grad_output / self.x
class Exp(Function):
def forward(self, x:UOp) -> UOp:
self.ret = (x * (1/math.log(2))).exp2()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return self.ret * grad_output
class Sqrt(Function):
def forward(self, x:UOp) -> UOp:
self.ret = x.sqrt()
return self.ret
def backward(self, grad_output:UOp) -> UOp: return grad_output / (self.ret*2)
class Sign(Function):
# NOTE: the x*0 is to match torch behavior without function.py
def forward(self, x:UOp) -> UOp: return x.ne(0).where((x<0).where(x.const_like(-1), x.const_like(1)), x.const_like(0)) + x*0
# backward always return 0 to match torch
def backward(self, grad_output:UOp) -> UOp: return grad_output.const_like(0)
# ************* binary ops *************
class Less(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x<y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
class Neq(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x.ne(y)
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]: return None, None
class Xor(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x^y
class BitwiseAnd(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x&y
class BitwiseOr(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x|y
class Threefry(Function):
def forward(self, x:UOp, seed:UOp) -> UOp: return x.threefry(seed)
class Add(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x+y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
return grad_output if self.needs_input_grad[0] else None, \
grad_output if self.needs_input_grad[1] else None
class Mul(Function):
def forward(self, x:UOp, y:UOp) -> UOp:
self.x, self.y = x, y
return x * y
def backward(self, grad_output:UOp) -> tuple[UOp|None, UOp|None]:
return (self.y * grad_output) if self.needs_input_grad[0] else None, \
(self.x * grad_output) if self.needs_input_grad[1] else None
class IDiv(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x // y
class Mod(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x % y
# ************* ternary ops *************
class Where(Function):
def forward(self, x:UOp, y:UOp, z:UOp) -> UOp:
self.x = x
return self.x.where(y, z)
def backward(self, grad_output:UOp) -> tuple[None, UOp|None, UOp|None]:
return None, \
self.x.where(grad_output, grad_output.const_like(0)) if self.needs_input_grad[1] else None, \
self.x.where(grad_output.const_like(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
class Sum(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.input_shape = x.shape
return x.r(Ops.ADD, axis)
def backward(self, grad_output:UOp) -> UOp: return grad_output.expand(self.input_shape)
class Prod(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.x, self.ret = x, x.r(Ops.MUL, axis)
return self.ret
def backward(self, grad_output:UOp) -> UOp:
return (grad_output * self.ret).expand(self.x.shape) / self.x
class Max(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.x, self.ret, self.axis = x, x.r(Ops.MAX, axis), axis
return self.ret
def backward(self, grad_output:UOp) -> UOp:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.ne(self.ret.expand(self.x.shape)).ne(self.x.const_like(1).cast(dtypes.bool)).cast(grad_output.dtype)
div = max_is_1s.r(Ops.ADD, self.axis).expand(self.x.shape)
return (max_is_1s/div) * grad_output.expand(self.x.shape)
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if resolve(si != so))
return x.expand(shape)
def backward(self, grad_output:UOp) -> UOp:
return grad_output.cast(sum_acc_dtype(grad_output.dtype)).r(Ops.ADD, self.expanded_axis).cast(grad_output.dtype)
class Reshape(Function):
def forward(self, x:UOp, shape:tuple[int, ...]) -> UOp:
self.input_shape = x.shape
return x.reshape(shape)
def backward(self, grad_output:UOp) -> UOp: return grad_output.reshape(self.input_shape)
class Permute(Function):
def forward(self, x:UOp, order:tuple[int, ...]) -> UOp:
self.input_order = order
return x.permute(order)
def backward(self, grad_output:UOp) -> UOp: return grad_output.permute(argsort(self.input_order))
class Pad(Function):
def forward(self, x:UOp, arg:tuple[tuple[int, int], ...]) -> UOp:
self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)])
return x.pad(arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.shrink(self.narg)
class Shrink(Function):
def forward(self, x:UOp, arg:tuple[tuple[sint, sint], ...]) -> UOp:
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
return x.shrink(arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.pad(self.narg)
class Flip(Function):
def forward(self, x:UOp, axis:tuple[int, ...]) -> UOp:
self.arg = tuple([-1 if i in axis else 1 for i in range(len(x.shape))])
return x.stride(self.arg)
def backward(self, grad_output:UOp) -> UOp: return grad_output.stride(self.arg)
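The whole Function-based autograd file above is deleted; each forward/backward pair survives as a UPat rule in pm_gradient (next file). For instance Sin's backward uses the identity d/dx sin(x) = cos(x) = sin(pi/2 - x); a hedged sketch of that rule in the pattern style (the exact rule in gradient.py may differ):
# hedged sketch: Sin's backward as a pm_gradient-style rule
import math
from tinygrad.ops import Ops, UPat
sin_grad = (UPat(Ops.SIN, name="ret"),
            lambda ctx, ret: ((math.pi/2 - ret.src[0]).sin() * ctx,))  # ctx is the incoming gradient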

View File

@@ -1,7 +1,7 @@
from typing import cast, Iterator
import math, functools
import math, functools, dataclasses
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, all_metadata
from tinygrad.helpers import argsort
def reduce_gradient(ctx:UOp, ret:UOp):
@@ -28,6 +28,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.WHERE, name="ret"), lambda ctx, ret: (None, ret.src[0].where(ctx, ctx.const_like(0)), ret.src[0].where(ctx.const_like(0), ctx))),
(UPat(Ops.REDUCE_AXIS, name="ret"), reduce_gradient),
(UPat(Ops.CONTIGUOUS), lambda ctx: (ctx,)),
(UPat(Ops.CONTIGUOUS_BACKWARD), lambda ctx: (ctx.contiguous(),)),
(UPat(Ops.RESHAPE, name="ret"), lambda ctx, ret: (ctx.reshape(ret.src[0].shape),)),
(UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.arg)),)),
(UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.arg)])),)),
@@ -36,7 +37,7 @@ pm_gradient = PatternMatcher([
# TODO: this cast can be removed by putting the casts around the EXPAND
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),
(UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
# there's no gradient for bitcast
(UPat(Ops.BITCAST), lambda ctx: (None,)),
])
@@ -65,4 +66,5 @@ def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp
if v is None: continue
if k in grads: grads[k] = grads[k] + v
else: grads[k] = v
if (forward_metadata:=all_metadata.get(t0)) is not None: all_metadata[v] = dataclasses.replace(forward_metadata, backward=True)
return grads
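The added line tags every gradient UOp with the forward op's Metadata, with backward=True, so profiling and viz can attribute backward kernels to the op that produced them. A hedged sketch of calling compute_gradient directly (normally Tensor.backward does this; the module path is assumed to be tinygrad.gradient):
# hedged sketch: symbolic gradient of x*x with respect to x
from tinygrad.dtype import dtypes
from tinygrad.gradient import compute_gradient
from tinygrad.ops import UOp
x = UOp.variable("x", -10.0, 10.0, dtypes.float)
root = x * x
grads = compute_gradient(root, root.const_like(1.0), targets={x})
print(grads[x])  # x*1 + x*1, i.e. 2*x up to simplification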

View File

@@ -109,6 +109,7 @@ USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT",
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL = ContextVar("CACHELEVEL", 2)
@dataclass(frozen=True)
class Metadata:
@@ -165,7 +166,6 @@ class Profiling(contextlib.ContextDecorator):
cache_dir: str = os.path.join(getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache")), "tinygrad")
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(cache_dir, "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 17
_db_connection = None
@@ -186,7 +186,7 @@ def diskcache_clear():
cur.executescript("\n".join([s[0] for s in drop_tables] + ["VACUUM;"]))
def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
if CACHELEVEL == 0: return None
if CACHELEVEL < 1: return None
if isinstance(key, (str,int)): key = {"key": key}
conn = db_connection()
cur = conn.cursor()
@@ -199,7 +199,7 @@ def diskcache_get(table:str, key:Union[dict, str, int]) -> Any:
_db_tables = set()
def diskcache_put(table:str, key:Union[dict, str, int], val:Any, prepickled=False):
if CACHELEVEL == 0: return val
if CACHELEVEL < 1: return val
if isinstance(key, (str,int)): key = {"key": key}
conn = db_connection()
cur = conn.cursor()
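CACHELEVEL moves from a plain getenv read to a ContextVar, so the disk cache can be toggled per scope (as the device self-test above does), and the guards change from == 0 to < 1 to match. A hedged usage sketch (the table name is hypothetical):
# hedged sketch: scoping the disk cache off with the new ContextVar
from tinygrad.helpers import Context, diskcache_get, diskcache_put
with Context(CACHELEVEL=0):
  diskcache_put("demo_table", "k", {"v": 1})  # returns the value, writes nothing
  assert diskcache_get("demo_table", "k") is None  # cache is disabled in this scope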

View File

@@ -1,8 +1,7 @@
from __future__ import annotations
import functools, itertools, operator
from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
from tinygrad.dtype import DType
from tinygrad.ops import Ops, MathTrait, UOp, sint
from tinygrad.ops import Ops, UOp, sint
def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
assert all_int(lbs[0].shape), f"does not support symbolic shape {lbs[0].shape}"
@@ -40,133 +39,127 @@ def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) ->
if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
class MultiLazyBuffer(MathTrait):
def __init__(self, lbs:list[UOp], axis:int|None, real:list[bool]|None=None):
assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
# ***** multi functions *****
@property
def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
from tinygrad.ops import PatternMatcher, UPat, GroupOp, graph_rewrite_map, track_rewrites
@property
def size(self): return sum(x.size for x in self.real_lbs)
def alu_multi(root:UOp):
msrcs = root.src
assert all(x.op is Ops.MULTI for x in msrcs), f"all buffers must be MultiLazyBuffer {[x.op for x in msrcs]}"
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
@property
def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
# NOTE: they all have to share an axis, we always choose [-1]
axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
srcs:list[list[UOp]] = []
not_all_real = not all(all(mlb.real) for mlb in msrcs)
new_real = tuple(all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])) if not_all_real else msrcs[0].real
for mlb in msrcs:
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(list(mlb.src))
else:
assert axis is not None and bounds is not None
if mlb.axis is None: srcs.append(to_sharded(list(mlb.src), axis, bounds))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.src], axis, bounds))
new_lbs = [lsrcs[0].alu(root.op, *lsrcs[1:]) for lsrcs in zip(*srcs)]
new_lbs = [x if r else x.const_like(0) for r,x in zip(new_real, new_lbs)] # TODO: is this needed?
return UOp.multi(*new_lbs, axis=axis, real=new_real)
@property
def bounds(self):
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
def reduce_multi(root:UOp, multi:UOp):
op, axis = root.arg
if multi.axis is not None and multi.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(multi.src, multi.real)]
# if all partitions are real, do all_reduce
if all(multi.real): return UOp.multi(*all_reduce(op, reduced_parts), axis=None)
# only one partition is real, keep it
return UOp.multi(*reduced_parts, axis=None, real=multi.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return UOp.multi(*[x.r(op, axis) for x in multi.src], axis=multi.axis, real=multi.real)
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
def _shape_to_single_shard(axis, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
return tuple(lb.shape[axis] if a == axis else s for a,s in enumerate(shape))
def copy_to_device(self, device:str) -> UOp:
# if we already have a copy on the device, return that
if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
# copy lbs to device, pad to final shape, and sum
llbs:list[UOp] = []
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
if not real: continue
pad_arg = tuple((0,0) if a != self.axis else (start, self.bounds[-1][1]-end) for a in range(len(lb.shape)))
llbs.append(lb.copy_to_device(device).pad(pad_arg))
return functools.reduce(operator.add, llbs)
def reshape_multi(root:UOp, multi:UOp):
arg = root.arg
if multi.axis is None: return UOp.multi(*[x.reshape(arg) for x in multi.src], axis=None, real=multi.real)
assert prod(multi.shape) == prod(arg), "reshape must maintain prod(shape)"
arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(multi.shape[:multi.axis])) - 1
assert all(prod(lb.shape[multi.axis:])%prod(arg[new_axis+1:])==0 for lb in multi.src), \
f"reshape cannot move items between shards {multi.shape} -> {root.arg=}"
lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[multi.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in multi.src]
return UOp.multi(*lbs, axis=new_axis, real=multi.real)
# passthroughs
@property
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
def cast(self, dtype:DType): return MultiLazyBuffer([x.cast(dtype) for x in self.lbs], self.axis, self.real)
def bitcast(self, dtype:DType): return MultiLazyBuffer([x.bitcast(dtype) for x in self.lbs], self.axis, self.real)
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real)
@property
def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort}
def expand_multi(root:UOp, multi:UOp):
# NOTE: this assert isn't needed, sharded axis can have dim 1
assert multi.axis is None or root.arg[multi.axis] == multi.shape[multi.axis], f"expand not supported on sharded axis {root.arg=}"
return UOp.multi(*[x.expand(_shape_to_single_shard(multi.axis, root.arg, x)) for x in multi.src], axis=multi.axis, real=multi.real)
# elementwise is simple
def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
msrcs = (self,)+in_srcs
assert all(isinstance(x, MultiLazyBuffer) for x in msrcs), f"all buffers must be MultiLazyBuffer {msrcs}"
assert all_same([x.device for x in msrcs]), f"all buffers must have the same device {[x.device for x in msrcs]}"
def pad_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == (0,0) or not all(multi.real), f"padding not supported for {root.arg=}"
# pad on shard axis -> fill others with zeros and set real to all True
if multi.axis is not None and root.arg[multi.axis] != (0,0):
# pad back to whole axis, remove real mask
assert all(root.arg[i] == (0, 0) for i in range(len(multi.shape)) if i != multi.axis), "cannot pad sharded and non-sharded axis at the same time"
dim, bound = sum(lb.shape[multi.axis] for lb in multi.src), multi.bounds[multi.real.index(True)]
assert root.arg[multi.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
return UOp.multi(*[x if r else x.const_like(0) for x,r in zip(multi.src, multi.real)], axis=multi.axis)
return UOp.multi(*[x.pad(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
# NOTE: they all have to share an axis, we always choose [-1]
axis, bounds = axes[-1] if len(axes := dedup([(x.axis, x.bounds) for x in msrcs if x.axis is not None])) else (None, None)
srcs:list[list[UOp]] = []
not_all_real = not all(all(mlb.real) for mlb in msrcs)
new_real = [all(transposed) for transposed in zip(*[mlb.real for mlb in msrcs])] if not_all_real else self.real
assert any(new_real), "output contains no real lb"
for mlb in msrcs:
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
else:
assert axis is not None and bounds is not None
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
# NOTE: const dtype should match real
new_dtype = next(iter(new_real_lbs.values())).dtype
return MultiLazyBuffer([new_real_lbs.get(i, lsrcs[0].const_like(0).cast(new_dtype)) for i,lsrcs in enumerate(zip(*srcs))], axis, new_real)
def permute_multi(root:UOp, multi:UOp):
# all permutes supported!
return UOp.multi(*[x.permute(root.arg) for x in multi.src], axis=root.arg.index(multi.axis) if multi.axis is not None else None, real=multi.real)
def r(self, op:Ops, axis:tuple[int, ...]) -> MultiLazyBuffer:
if self.axis is not None and self.axis in axis:
# all-reduce on sharded axes
reduced_parts = [(x if r else x.const_like(0)).r(op, axis) for x,r in zip(self.lbs, self.real)]
# if all partitions are real, do all_reduce
if all(self.real): return MultiLazyBuffer(all_reduce(op, reduced_parts), None)
# only one partition is real, keep it
return MultiLazyBuffer(reduced_parts, None, self.real)
# reduce on non sharded axes, piecewise is fine. if axis is None this is also correct
return MultiLazyBuffer([x.r(op, axis) for x in self.lbs], self.axis, self.real)
def shrink_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == (0, multi.shape[multi.axis]) or root.arg[multi.axis] in multi.bounds, \
f"shrinking not supported for {root.arg=}"
if multi.axis is not None and root.arg[multi.axis] in multi.bounds and root.arg[multi.axis] != (0, multi.shape[multi.axis]):
assert all(root.arg[i] == (0, s) or i == multi.axis for i,s in enumerate(multi.shape)), \
"cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
idx = multi.bounds.index(root.arg[multi.axis])
# zero out other lbs to not create lb reference
return UOp.multi(*[lb if i==idx else lb.const_like(0) for i,lb in enumerate(multi.src)],
axis=multi.axis, real=tuple(i==idx for i in range(len(multi.src))))
return UOp.multi(*[x.shrink(tuple((0, x.shape[multi.axis]) if a == multi.axis else s for a,s in enumerate(root.arg))) for x in multi.src],
axis=multi.axis, real=multi.real)
# *** movement ops ***
def stride_multi(root:UOp, multi:UOp):
assert multi.axis is None or root.arg[multi.axis] == 1, "flipping not supported on sharded axis"
return UOp.multi(*[x.stride(root.arg) for x in multi.src], axis=multi.axis, real=multi.real)
def _shape_to_single_shard(self, shape:tuple[sint, ...], lb:UOp) -> tuple[sint, ...]:
return tuple(lb.shape[self.axis] if a == self.axis else s for a,s in enumerate(shape))
def copy_multi(multi:UOp, device:UOp):
# if we already have a copy on the device, return that
if multi.axis is None: return next((lb for lb in multi.real_lbs if lb.device == device.arg), multi.real_lbs[0].copy_to_device(device.arg))
# copy lbs to device, pad to final shape, and sum
llbs:list[UOp] = []
for lb,real,(start,end) in zip(multi.src, multi.real, multi.bounds):
if not real: continue
pad_arg = tuple((0,0) if a != multi.axis else (start, multi.bounds[-1][1]-end) for a in range(len(lb.shape)))
llbs.append(lb.copy_to_device(device.arg).pad(pad_arg))
return functools.reduce(operator.add, llbs)
def reshape(self, arg:tuple[sint, ...]):
if self.axis is None: return MultiLazyBuffer([x.reshape(arg) for x in self.lbs], None, self.real)
assert prod(self.shape) == prod(arg), "reshape must maintain prod(shape)"
arg_acc:list[sint] = list(itertools.accumulate(arg, operator.mul, initial=1))
# new_axis is the last one that preserves prod(prior to new_axis) and must not move items between shards
# todo: what to do about shrinking to self.shape[self.axis]==1 len(self.real_lbs)==1?
new_axis = len(arg_acc) - arg_acc[::-1].index(prod(self.shape[:self.axis])) - 1
assert all(prod(lb.shape[self.axis:])%prod(arg[new_axis+1:])==0 for lb in self.lbs), f"reshape cannot move items between shards {self=} {arg=}"
lbs = [x.reshape(tuple(s if a!=new_axis else prod(x.shape[self.axis:])//prod(arg[new_axis+1:]) for a,s in enumerate(arg))) for x in self.lbs]
return MultiLazyBuffer(lbs, new_axis, self.real)
def assign_multi(dest:UOp, src:UOp):
assert dest.axis == src.axis and dest.real == src.real, f"axis/real must match in assign {dest.axis} != {src.axis} or {dest.real} != {src.real}"
return UOp.multi(*[x.assign(y) for x,y in zip(dest.src, src.src)], axis=src.axis, real=src.real)
def pad(self, arg:tuple[tuple[sint, sint], ...]):
assert self.axis is None or arg[self.axis] == (0,0) or not all(self.real), f"padding not supported for {arg=}"
# pad on shard axis -> fill others with zeros and set real to all True
if self.axis is not None and arg[self.axis] != (0,0):
# pad back to whole axis, remove real mask
assert all(arg[i] == (0, 0) for i in range(len(self.shape)) if i != self.axis), "cannot pad sharded and non-sharded axis at the same time"
dim, bound = sum(lb.shape[self.axis] for lb in self.lbs), self.bounds[self.real.index(True)]
assert arg[self.axis] == (bound[0], dim-bound[1]), "can only pad to whole axis"
return MultiLazyBuffer([x if r else x.const_like(0) for x,r in zip(self.lbs, self.real)], self.axis)
return MultiLazyBuffer([x.pad(arg) for x in self.lbs], self.axis, self.real)
def passthrough_multi(root:UOp, multi:UOp): return UOp.multi(*[root.replace(src=(m,)) for m in multi.src], axis=multi.axis, real=multi.real)
def expand(self, arg:tuple[sint, ...]):
# NOTE: this assert isn't needed, sharded axis can have dim 1
assert self.axis is None or arg[self.axis] == self.shape[self.axis], f"expand not supported on sharded axis {arg=}"
return MultiLazyBuffer([x.expand(self._shape_to_single_shard(arg, x)) for x in self.lbs], self.axis, self.real)
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([
(UPat(GroupOp.ALU, name="root", custom_early_reject=set([Ops.MULTI])), alu_multi),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reduce_multi),
(UPat(Ops.RESHAPE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), reshape_multi),
(UPat(Ops.EXPAND, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), expand_multi),
(UPat(Ops.PAD, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), pad_multi),
(UPat(Ops.PERMUTE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), permute_multi),
(UPat(Ops.SHRINK, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), shrink_multi),
(UPat(Ops.STRIDE, src=(UPat(Ops.MULTI, name="multi"), ), name="root"), stride_multi),
(UPat(Ops.ASSIGN, src=(UPat(Ops.MULTI, name="dest"), UPat(Ops.MULTI, name="src"))), assign_multi),
(UPat(Ops.COPY, src=(UPat(Ops.DEVICE, name="device"), UPat(Ops.MULTI, name="multi"), )), copy_multi),
(UPat((Ops.CAST, Ops.BITCAST, Ops.CONTIGUOUS, Ops.DETACH), src=(UPat(Ops.MULTI, name="multi"), ), name="root"), passthrough_multi),
])
def permute(self, arg:tuple[int, ...]):
# all permutes supported!
return MultiLazyBuffer([x.permute(arg) for x in self.lbs], arg.index(self.axis) if self.axis is not None else None, self.real)
def shrink(self, arg:tuple[tuple[sint, sint], ...]):
assert self.axis is None or arg[self.axis] == (0, self.shape[self.axis]) or arg[self.axis] in self.bounds, f"shrinking not supported for {arg=}"
if self.axis is not None and arg[self.axis] in self.bounds and arg[self.axis] != (0, self.shape[self.axis]):
assert all(arg[i] == (0, s) or i == self.axis for i,s in enumerate(self.shape)), "cannot shrink sharded and non-sharded axis at the same time"
# NOTE: shrink on the shard axis is only allowed when result is a single partition, denoted by the new real
idx = self.bounds.index(arg[self.axis])
# zero out other lbs to not create lb reference
return MultiLazyBuffer([lb if i==idx else lb.const_like(0) for i,lb in enumerate(self.lbs)], self.axis, [i==idx for i in range(len(self.lbs))])
return MultiLazyBuffer([x.shrink(tuple((0, x.shape[self.axis]) if a == self.axis else s for a,s in enumerate(arg))) for x in self.lbs],
self.axis, self.real)
def stride(self, arg:tuple[int, ...]):
assert self.axis is None or arg[self.axis] == 1, "flipping not supported on sharded axis"
return MultiLazyBuffer([x.stride(arg) for x in self.lbs], self.axis, self.real)
@track_rewrites(named=True)
def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]: return {k:v for k,v in graph_rewrite_map(big_sink, multi_pm).items() if k is not v}
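MultiLazyBuffer's methods become the standalone *_multi rewrite functions above, applied by multi_pm over the whole sink in get_multi_map. A hedged sketch of the user-facing path that exercises them (device names are illustrative):
# hedged sketch: sharding from the Tensor API lowers to the multi_pm rewrites above
from tinygrad import Tensor
t = Tensor.arange(8).shard(("CLANG:0", "CLANG:1"), axis=0)  # devices are illustrative
print(t.lazydata.axis, t.lazydata.real)  # sharded axis and per-shard real mask
print((t + 1).sum().item())  # reducing over the sharded axis runs all_reduce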

View File

@@ -77,9 +77,6 @@ class LARS(Optimizer):
def schedule_step_with_grads(self, grads:list[Tensor]) -> list[Tensor]:
for i, (t, g) in enumerate(zip(self.params, grads)):
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = g.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()

View File

@@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm, round_up, T
from tinygrad.shape.view import strides_for_shape
from tinygrad.multi import MultiLazyBuffer
class TensorIO(io.RawIOBase, BinaryIO):
def __init__(self, t: Tensor):
@@ -152,9 +151,9 @@ def load_state_dict(model, state_dict:dict[str, Tensor], strict=True, verbose=Tr
continue
if v.shape != state_dict[k].shape:
raise ValueError(f'Shape mismatch in layer `{k}`: Expected shape {v.shape}, but found {state_dict[k].shape} in state dict.')
if isinstance((mlb:=v.lazydata), MultiLazyBuffer):
if isinstance(state_dict[k].lazydata, MultiLazyBuffer): v.replace(state_dict[k]).realize()
else: v.replace(state_dict[k].shard(mlb.device, mlb.axis)).realize()
if isinstance(v.device, tuple):
if isinstance(state_dict[k].device, tuple): v.replace(state_dict[k]).realize()
else: v.replace(state_dict[k].shard(v.device, v.lazydata.axis)).realize()
else: v.replace(state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]

View File

@@ -93,7 +93,7 @@ class MathTrait(SimpleMathTrait):
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
# uops that aren't rendered
SINK = auto(); CONTIGUOUS = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702
SINK = auto(); CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto(); PRELOAD = auto() # noqa: E702
# TODO: empty continues to exist because of tensor
EMPTY = auto()
@@ -150,6 +150,7 @@ class Ops(FastEnum):
# device
DEVICE = auto()
MULTI = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -281,6 +282,10 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@functools.cached_property
def st(self) -> ShapeTracker|None:
from tinygrad.shape.shapetracker import ShapeTracker
if self.op is Ops.MULTI:
return ShapeTracker.from_shape(
tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape)))
# these ops define a ShapeTracker from the arg
if self.op is Ops.VIEW: return self.arg
if self.op in GroupOp.Movement: return unwrap(self.src[0].st).mop(self.op, self.arg)
@@ -294,7 +299,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# only reduce ops are allowed to change shape, everything else derives shape from sources
elif self.op in {Ops.REDUCE_AXIS, Ops.WMMA}: shape = src_sts[0].reduce(self.axis_arg)
else: shape = src_sts[0].shape
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(shape)
@functools.cached_property
@@ -350,7 +354,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def index(self, idx:UOp, valid:UOp|None=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def const_like(self, b:ConstLike):
# constants can optionally have a DEVICE source
return UOp.const(self.dtype, b) if self._device is None else UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
if self._device is None: return UOp.const(self.dtype, b)
if isinstance(self.device, tuple): return UOp.multi(*[UOp.metaop(Ops.CONST, self.shape, self.dtype, d, b) for d in self.device], axis=None)
return UOp.metaop(Ops.CONST, self.shape, self.dtype, self.device, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
if count == 1: return self
@@ -389,7 +395,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
new_shape = unwrap(self.st).reduce(axis)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
# TODO: this shouldn't be here, it belongs in scheduler! that's why it broke multi
if not SPLIT_REDUCEOP or isinstance(self._device, tuple) or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis)
@@ -409,6 +416,46 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return splitted._reduce_op(op, axis)._reduce_op(op, (len(new_shape),)).reshape(new_shape) # reduce original axes, then split
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def contiguous(self): return self.alu(Ops.CONTIGUOUS)
def contiguous_backward(self): return self.alu(Ops.CONTIGUOUS_BACKWARD)
# *** from MultiLazyBuffer ***
def multi(self, *more:UOp, axis:int|None, real:tuple[bool,...]|None=None):
parents = (self,)+more
assert all_same([x.dtype for x in parents]), "multi parents must have the same dtype"
return UOp(Ops.MULTI, self.dtype, parents, (axis, real if real is not None else (True,)*len(parents)))
@property
def bounds(self):
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.src], initial=0)))
@property
def axis(self):
assert self.op is Ops.MULTI
return self.arg[0]
@property
def real(self):
assert self.op is Ops.MULTI
return self.arg[1]
@property
def real_lbs(self): return [lb for lb,r in zip(self.src, self.real) if r]
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> UOp:
if axis is None: lbs = [self] * len(devices)
else:
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
sz = self.shape[axis] // len(devices)
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
lbs, off = [], 0
for sz in sizes:
lbs.append(self.shrink(tuple((0,s) if i != axis else (off,off+sz) for i,s in enumerate(self.shape))))
off += sz
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
# NOTE: this contiguous is making it impossible for the scheduler to do late const folding
return UOp.multi(*[lb.contiguous() for lb in sharded_lbs], axis=axis)
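The split-and-offset arithmetic in `shard` and `bounds` is easy to check in isolation; a minimal plain-Python sketch (the helper name is ours):
```python
import itertools

def shard_bounds(dim_size: int, n_devices: int) -> list[tuple[int, int]]:
    # mirrors UOp.shard + UOp.bounds: even split along the axis, then
    # per-device (start, end) offsets from a running sum
    if dim_size % n_devices != 0: raise RuntimeError(f"multi axis uneven: {dim_size=} {n_devices=}")
    sz = dim_size // n_devices
    sizes = [max(0, min(sz, dim_size - sz*i)) for i in range(n_devices)]
    return list(itertools.pairwise(itertools.accumulate(sizes, initial=0)))

print(shard_bounds(8, 2))  # [(0, 4), (4, 8)]
```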
# *** from LazyBuffer ***
@@ -426,7 +473,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return var.replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg=device),), ShapeTracker.from_shape(shape)),)).bind(val)
# otherwise it's just a VIEW(BUFFER)
return UOp(Ops.VIEW, dtype, (UOp.new_buffer(device, (st:=ShapeTracker.from_shape(shape)).size, dtype),), st)
def copy_to_device(self, device:str, clone:bool=False) -> UOp:
def copy_to_device(self, device:str|tuple[str, ...], clone:bool=False) -> UOp:
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.shape) < prod(self.base.shape): return self.contiguous().copy_to_device(device)
# COPY is COPY(DEVICE, copyin.base) -> VIEW(copyin.st)
@@ -440,8 +487,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def clone(self) -> UOp: return self.copy_to_device(self.device, clone=True)
@property
def lbs(self): return [self]
@property
def metadata(self): return all_metadata.get(self, None)
# *** uop movement ops ***
@@ -470,10 +515,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@staticmethod
def new_buffer(device:str, size:int, dtype:DType): return UOp(Ops.BUFFER, dtype, (UOp(Ops.DEVICE, arg=device),), (next(UOp.buffer_num), size))
@property
def device(self) -> str: return unwrap(self._device)
def device(self) -> str|tuple[str, ...]: return cast(str|tuple[str, ...], unwrap(self._device))
@functools.cached_property
def _device(self) -> Optional[str]:
def _device(self) -> Optional[str|tuple[str, ...]]:
if self.op is Ops.DEVICE: return self.arg
if self.op is Ops.MULTI: return tuple(cast(str, x.device) for x in self.src)
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
@@ -489,6 +535,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
if (cret:=buffers.get(self)) is not None: return cret
from tinygrad.device import Buffer
assert isinstance(self.device, str), f"buffer not supported on multi {self.device}"
buffers[self] = ret = Buffer(self.device, self.size, self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base)
return ret
@property
@@ -496,7 +543,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is Ops.BUFFER: return self.src[0].realized
return self.buffer if self.op is Ops.BUFFER else None
@property
def is_realized(self) -> bool: return self.base.realized is not None
def is_realized(self) -> bool:
return all(x.base.realized is not None for x in self.base.real_lbs) if self.base.op is Ops.MULTI else self.base.realized is not None
# *** uop Variable stuff ***
@@ -639,7 +687,7 @@ def print_uops(uops:list[UOp]):
def get_location() -> tuple[str, int]:
frm = sys._getframe(1)
# find the real frame in the file that has the UPat, TODO: is there a better way to do this?
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py",
while frm.f_back is not None and pathlib.Path(frm.f_back.f_code.co_filename).name in {"ops.py", "rewriter.py", "schedule.py", "multi.py",
"lowerer.py", "cstyle.py", "linearize.py"}:
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno
@@ -775,9 +823,9 @@ TRACK_MATCH_STATS = ContextVar("TRACK_MATCH_STATS", 2 if getenv("VIZ") else 0)
match_stats:dict[UPat, list[Union[int, float]]] = dict()
@dataclass(frozen=True)
class TrackedGraphRewrite:
loc: tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink input to graph_rewrite
matches: list[tuple[UOp, Optional[UOp], Optional[UPat], float]] = field(default_factory=list) # before+after of all the matches
loc: tuple[str, int] # location that called graph_rewrite
sink: UOp # the sink input to graph_rewrite
matches: list[tuple[UOp, UOp, UPat]] = field(default_factory=list) # before+after of all the matches
tracked_keys:list[Any] = []
tracked_ctxs:list[list[TrackedGraphRewrite]] = []
_name_cnt:dict[str, int] = {}
@@ -808,10 +856,9 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][0] += 1
match_stats[p][3] += (et:=time.perf_counter()-st)
if TRACK_MATCH_STATS >= 3: print(f"{et*1e6:7.2f} us -- ", p.printable())
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p, et))
if TRACK_MATCH_STATS >= 2 and isinstance(ret, UOp) and len(tracked_ctxs) != 0: tracked_ctxs[-1][-1].matches.append((uop, ret, p))
return ret # NOTE: if it returns None, we keep trying to match
match_stats[p][2] += time.perf_counter()-st
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0 and len(tracked_ctxs[-1]) != 0: tracked_ctxs[-1][-1].matches.append((uop, None, None, 0))
return None
if TRACK_MATCH_STATS:
@@ -849,7 +896,7 @@ class RewriteContext:
self.replace: dict[UOp, UOp] = {}
def top_down_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_src = tuple(map(self.top_down_rewrite, n.src))
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
new_n = self.pm.rewrite(n, self.ctx) if new_src == n.src else UOp(n.op, n.dtype, new_src, n.arg)
self.replace[n] = ret = n if new_n is None else self.top_down_rewrite(new_n)
return ret
@@ -857,7 +904,7 @@ class RewriteContext:
if (rn := self.replace.get(n)) is not None: return rn
new_n: UOp|None = n
while new_n is not None: last_n, new_n = new_n, self.pm.rewrite(new_n, self.ctx)
new_src = tuple(map(self.bottom_up_rewrite, last_n.src))
new_src = tuple([self.bottom_up_rewrite(x) for x in last_n.src])
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
return ret
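For context, this is how these rewrite passes are driven end to end; a hedged, self-contained sketch with a one-rule `PatternMatcher` (the x+0 rule is illustrative, not part of this diff):
```python
from tinygrad.ops import UOp, UPat, PatternMatcher, graph_rewrite
from tinygrad.dtype import dtypes

pm = PatternMatcher([
    # fold x + 0 -> x; UPat.cvar captures a constant, whose value is in .arg
    (UPat.var("x") + UPat.cvar("c"), lambda x, c: x if c.arg == 0 else None),
])
sink = UOp.const(dtypes.int, 3) + 0
print(graph_rewrite(sink, pm))  # the ADD folds away, leaving the bare CONST
```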
@@ -1269,18 +1316,24 @@ Variable = UOp
ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
# *** uop swizzling ***
# *** UOp merge views and swizzling ***
merge_views = PatternMatcher([
(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)),
(UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None),
# VIEW(VIEW) merges to a single VIEW
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.VIEW, name="vm2"),)), lambda vm1,vm2: vm2.replace(arg=vm2.st+vm1.st)),
(UPat(Ops.VIEW, name="vm", src=(UPat.var("x"),)), lambda vm,x: x if vm.st.contiguous and x.st is not None and x.shape == vm.shape else None),
# merge unmasked const views
(UPat(Ops.VIEW, name="view", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
])
# push VIEW to loads
# push VIEW to parents
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# VIEW(CONST) becomes VALID
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.replace(src=()).valid(vm.st)),
# VIEW before elementwise/buffer ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
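The `vm2.st+vm1.st` in the VIEW(VIEW) rule is plain ShapeTracker addition; a small sketch with illustrative shapes:
```python
from tinygrad.shape.shapetracker import ShapeTracker

inner = ShapeTracker.from_shape((4, 4)).permute((1, 0))  # what vm2 carries
outer = ShapeTracker.from_shape((4, 4)).reshape((16,))   # what vm1 stacks on top
print((inner + outer).shape)  # (16,): two VIEWs collapse into one tracker
```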

View File

@@ -198,12 +198,12 @@ class AMDCopyQueue(HWQueue):
return self
def bind(self, dev:AMDDevice):
if not dev.driverless: return
if not getenv("AMD_SDMA_BIND", 0) or not dev.driverless: return
self.binded_device = dev
self.hw_page = dev.allocator.alloc((qsz:=round_up(len(self._q), 8)) * 4, BufferSpec(cpu_access=True, nolru=True, uncached=True))
hw_view = to_mv(self.hw_page.va_addr, self.hw_page.size).cast("I")
for i, value in enumerate(self._q): hw_view[i] = value
for i in range(qsz): hw_view[i] = self._q[i] if i < len(self._q) else 0
self.indirect_cmd = [amd_gpu.SDMA_OP_INDIRECT | amd_gpu.SDMA_PKT_INDIRECT_HEADER_VMID(0), *data64_le(self.hw_page.va_addr), qsz, *data64_le(0)]
self._q, self.cmd_sizes = hw_view, [len(self.indirect_cmd)]
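The new zero-fill makes the padding explicit: the queue is rounded up to a multiple of 8 dwords and the tail cleared. A tiny sketch (`round_up` here stands in for `tinygrad.helpers.round_up`):
```python
def round_up(x: int, a: int) -> int: return (x + a - 1) // a * a

q = [0xDE, 0xAD, 0xBE]
qsz = round_up(len(q), 8)
print(qsz, [q[i] if i < len(q) else 0 for i in range(qsz)])  # 8 [222, 173, 190, 0, 0, 0, 0, 0]
```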
@@ -464,7 +464,7 @@ class PCIIface:
iommu_group = HWInterface.readlink(f"/sys/bus/pci/devices/{self.pcibus}/iommu_group").split('/')[-1]
except OSError:
if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (not inserted or no-iommu mode is not supported).")
if DEBUG >= 1: print(f"am {self.pcibus}: failed to init vfio-pci module (run `sudo modprobe vfio-pci`).")
PCIIface.vfio = False
# Init vfio for the device
@@ -487,21 +487,18 @@ class PCIIface:
self.pagemap = HWInterface("/proc/self/pagemap", os.O_RDONLY)
self.bar_fds = {bar: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource{bar}", os.O_RDWR | os.O_SYNC) for bar in [0, 2, 5]}
self.adev = AMDev(self.pcidev, self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
self.adev = AMDev(self.pcibus, self._map_pci_range(0), dbell:=self._map_pci_range(2).cast('Q'), self._map_pci_range(5).cast('I'))
self.doorbell_cpu_addr = mv_address(dbell)
libpciaccess.pci_device_cfg_read_u16(self.adev.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_write_u16(self.adev.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_read_u16(self.pcidev, ctypes.byref(val:=ctypes.c_uint16()), libpciaccess.PCI_COMMAND)
libpciaccess.pci_device_cfg_write_u16(self.pcidev, val.value | libpciaccess.PCI_COMMAND_MASTER, libpciaccess.PCI_COMMAND)
# TODO: this is for 7900xtx, the only tested card.
self.props = {'simd_count': 192, 'simd_per_cu': 2, 'max_waves_per_simd': 16, 'gfx_target_version': 110000, 'max_slots_scratch_cu': 32,
'array_count': 12, 'simd_arrays_per_engine': 2, 'lds_size_in_kb': 64}
def _map_pci_range(self, bar, off=0, addr=0, size=None):
if PCIIface.vfio:
vfio.VFIO_DEVICE_GET_REGION_INFO(self.vfio_dev, reg:=vfio.struct_vfio_region_info(argsz=ctypes.sizeof(vfio.struct_vfio_region_info), index=bar))
fd, sz, off = self.vfio_dev, size or reg.size, reg.offset + off
else: fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size
fd, sz = self.bar_fds[bar], size or self.pcidev.regions[bar].size
return to_mv(fd.mmap(addr, sz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | (MAP_FIXED if addr else 0), off), sz)
def alloc(self, size:int, host=False, uncached=False, cpu_access=False):

View File

@@ -102,19 +102,16 @@ class AMFirmware:
class AMMapping: va_addr:int; size:int; paddrs:list[tuple[int, int]]; uncached:bool=False; system:bool=False; snooped:bool=False # noqa: E702
class AMPageTableEntry:
def __init__(self, adev, paddr, lv): self.paddr, self.view, self.lv = paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
def __init__(self, adev, paddr, lv): self.adev, self.paddr, self.entries, self.lv = adev, paddr, to_mv(adev.paddr2cpu(paddr), 0x1000).cast('Q'), lv
def set_table(self, entry_id, pte:AMPageTableEntry, valid=True):
self.view[entry_id] = (pte.paddr & 0x0000FFFFFFFFF000) | (am.AMDGPU_PTE_VALID if valid else 0)
def set_entry(self, entry_id:int, paddr:int, table=False, uncached=False, system=False, snooped=False, frag=0, valid=True):
assert paddr & self.adev.gmc.address_space_mask == paddr, f"Invalid physical address {paddr:#x}"
def set_page(self, entry_id, paddr, uncached=False, system=False, snooped=False, frag=0, valid=True):
f = (am.AMDGPU_PTE_VALID if valid else 0) | am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE \
| am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if self.lv != am.AMDGPU_VM_PTB else 0) \
f = (am.AMDGPU_PTE_VALID if valid else 0) | ((am.AMDGPU_PTE_WRITEABLE | am.AMDGPU_PTE_READABLE | am.AMDGPU_PTE_EXECUTABLE) if not table else 0) \
| am.AMDGPU_PTE_FRAG(frag) | (am.AMDGPU_PDE_PTE if not table and self.lv != am.AMDGPU_VM_PTB else 0) \
| ((am.AMDGPU_PTE_SYSTEM) if system else 0) | ((am.AMDGPU_PTE_SNOOPED) if snooped else 0) \
| (am.AMDGPU_PTE_MTYPE_NV10(0, am.MTYPE_UC) if uncached else 0)
self.view[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f
def get_entry(self, entry_id): return self.view[entry_id]
self.entries[entry_id] = (paddr & 0x0000FFFFFFFFF000) | f
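A hedged sketch of the merged flag math: table entries drop the RWX bits, and `AMDGPU_PDE_PTE` marks a leaf at a non-PTB level (huge page). The bit positions below follow the usual amdgpu layout and are assumptions; the real values come from the generated `am` module:
```python
PTE_VALID, PTE_EXECUTABLE, PTE_READABLE, PTE_WRITEABLE, PDE_PTE = 1 << 0, 1 << 4, 1 << 5, 1 << 6, 1 << 54
def pte_frag(x: int) -> int: return x << 7  # fragment size field

def entry_flags(table: bool, ptb_level: bool, frag: int = 0) -> int:
    f = PTE_VALID | (0 if table else (PTE_WRITEABLE | PTE_READABLE | PTE_EXECUTABLE))
    return f | pte_frag(frag) | (0 if table or ptb_level else PDE_PTE)

print(hex(entry_flags(table=True,  ptb_level=False)))            # 0x1: PDE pointing at a child table
print(hex(entry_flags(table=False, ptb_level=True)))             # 0x71: 4KB RWX leaf in a PTB
print(hex(entry_flags(table=False, ptb_level=False, frag=0x9)))  # 2MB huge-page leaf, PDE_PTE set
```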
class AMPageTableTraverseContext:
def __init__(self, adev, pt, vaddr, create_pts=False, free_pts=False):
@@ -126,22 +123,23 @@ class AMPageTableTraverseContext:
def level_down(self):
pt, pte_idx, _ = self.pt_stack[-1]
if (entry:=pt.get_entry(pte_idx)) & am.AMDGPU_PTE_VALID:
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
else:
if (entry:=pt.entries[pte_idx]) & am.AMDGPU_PTE_VALID == 0:
assert self.create_pts, "Not allowed to create new page table"
pt.set_table(pte_idx, child_page_table:=AMPageTableEntry(self.adev, self.adev.mm.palloc(0x1000, zero=True), lv=pt.lv+1))
pt.set_entry(pte_idx, self.adev.mm.palloc(0x1000, zero=True), table=True, valid=True)
entry = pt.entries[pte_idx]
assert entry & am.AMDGPU_PDE_PTE == 0, f"Must be table pt={pt.paddr:#x}, {pte_idx=} {entry=:#x}"
child_page_table = AMPageTableEntry(self.adev, entry & 0x0000FFFFFFFFF000, lv=pt.lv+1)
self.pt_stack.append((child_page_table, self._pt_pte_idx(child_page_table, self.vaddr), self._pt_pte_size(child_page_table)))
return self.pt_stack[-1]
def _try_free_pt(self) -> bool:
pt, _, _ = self.pt_stack[-1]
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.get_entry(i) & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
if self.free_pts and pt != self.adev.mm.root_page_table and all(pt.entries[i] & am.AMDGPU_PTE_VALID == 0 for i in range(512)):
self.adev.mm.pfree(pt.paddr)
parent_pt, parent_pte_idx, _ = self.pt_stack[-2]
parent_pt.set_page(parent_pte_idx, 0x0, valid=False)
parent_pt.set_entry(parent_pte_idx, 0x0, valid=False)
return True
return False
@@ -156,7 +154,7 @@ class AMPageTableTraverseContext:
if self.create_pts:
while pte_covers > size: pt, pte_idx, pte_covers = self.level_down()
else:
while pt.lv!=am.AMDGPU_VM_PTB and (pt.get_entry(pte_idx)&am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
while pt.lv!=am.AMDGPU_VM_PTB and (pt.entries[pte_idx] & am.AMDGPU_PDE_PTE != am.AMDGPU_PDE_PTE): pt, pte_idx, pte_covers = self.level_down()
entries = min(size // pte_covers, 512 - pte_idx)
assert entries > 0, "Invalid entries"
@@ -173,7 +171,7 @@ class AMMemoryManager:
self.adev, self.vram_size = adev, vram_size
self.boot_allocator = TLSFAllocator(32 << 20, base=vram_size - (64 << 20)) # per device
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=True, boot=True), lv=am.AMDGPU_VM_PDB1)
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
assert size == sum(p[1] for p in paddrs), f"Size mismatch {size=} {sum(p[1] for p in paddrs)=}"
@@ -181,10 +179,10 @@ class AMMemoryManager:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, create_pts=True)
for paddr, psize in paddrs:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
frag = 0 if pte_covers == 0x1000 else 0x9
for pte_off in range(pte_cnt):
assert pt.get_entry(pte_idx + pte_off) & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.get_entry(pte_idx + pte_off):#x}"
pt.set_page(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped, frag=frag, valid=True)
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
# Invalidate TLB after mappings.
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
@@ -197,8 +195,8 @@ class AMMemoryManager:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, vaddr, free_pts=True)
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(size):
for pte_id in range(pte_idx, pte_idx + pte_cnt):
assert pt.get_entry(pte_id) & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.get_entry(pte_id):#x}"
pt.set_page(pte_id, paddr=0x0, valid=False)
assert pt.entries[pte_id] & am.AMDGPU_PTE_VALID == am.AMDGPU_PTE_VALID, f"PTE not mapped: {pt.entries[pte_id]:#x}"
pt.set_entry(pte_id, paddr=0x0, valid=False)
@staticmethod
def alloc_vaddr(size:int, align=0x1000) -> int: return AMMemoryManager.va_allocator.alloc(size, max((1 << (size.bit_length() - 1)), align))
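The `alloc_vaddr` alignment rule aligns to the largest power of two not exceeding the size, floored at 0x1000; a quick check:
```python
def va_align(size: int, align: int = 0x1000) -> int:
    return max(1 << (size.bit_length() - 1), align)

print(hex(va_align(0x3000)))  # 0x2000
print(hex(va_align(0x800)))   # 0x1000
```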
@@ -233,8 +231,8 @@ class AMMemoryManager:
def pfree(self, paddr:int): self.pa_allocator.free(paddr)
class AMDev:
def __init__(self, pcidev, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
self.pcidev, self.devfmt = pcidev, devfmt
def __init__(self, devfmt, vram_bar:memoryview, doorbell_bar:memoryview, mmio_bar:memoryview):
self.devfmt = devfmt
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
@@ -258,8 +256,8 @@ class AMDev:
# To enable this, AM uses a separate boot memory that is guaranteed not to be overwritten. This physical memory is utilized for
# all blocks that are initialized only during the initial AM boot.
# To determine if the GPU is in the third state, AM uses regSCRATCH_REG7 as a flag.
self.is_booting = True # During boot only boot memory can be allocated. This flag is to validate this.
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000001)) and (getenv("AM_RESET", 0) != 1)
self.is_booting, self.smi_dev = True, False # During boot only boot memory can be allocated. This flag is to validate this.
self.partial_boot = (self.reg("regSCRATCH_REG7").read() == (am_version:=0xA0000002)) and (getenv("AM_RESET", 0) != 1)
# Memory manager & firmware
self.mm = AMMemoryManager(self, self.vram_size)

View File

@@ -25,6 +25,9 @@ class AM_GMC(AM_IP):
self.vm_base = self.adev.mm.va_allocator.base
self.vm_end = self.vm_base + self.adev.mm.va_allocator.size - 1
# GFX11 has 44-bit address space
self.address_space_mask = (1 << 44) - 1
self.memscratch_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.dummy_page_paddr = self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True)
self.hub_initted = {"MM": False, "GC": False}
@@ -99,7 +102,13 @@ class AM_GMC(AM_IP):
if self.adev.reg(f"reg{ip}VM_L2_PROTECTION_FAULT_STATUS").read(): raise RuntimeError(f"{ip}VM_L2_PROTECTION_FAULT_STATUS: {st:#x} {va:#x}")
class AM_SMU(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)
def init(self):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
for clck in [0x00000C94, 0x000204E1, 0x000105DC, 0x00050B76, 0x00070B76, 0x00040898, 0x00060898, 0x000308FD]:
@@ -115,6 +124,11 @@ class AM_SMU(AM_IP):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms
def read_table(self, table_t, cmd):
self._smu_cmn_send_smc_msg_with_param(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)
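A hedged usage sketch of the new table readback, assuming an initialized `AMDev` exposed as `adev` with the SMU IP at `adev.smu`; both the attribute path and the metrics field are assumptions taken from the smu_v13_0_0 headers:
```python
metrics = adev.smu.read_metrics()             # SmuMetricsExternal_t via the driver table
print(metrics.SmuMetrics.AverageSocketPower)  # field name assumed from smu_v13_0_0
```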
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
self.adev.mmMP1_SMN_C2PMSG_90.write(0) # resp reg
@@ -237,8 +251,9 @@ class AM_IH(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.ring_size = 512 << 10
self.rings = [(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "", 0),
(self.adev.mm.palloc(self.ring_size, boot=True), self.adev.mm.palloc(0x1000, boot=True), "_RING1", 1)]
def _alloc_ring(size): return (self.adev.mm.palloc(size, zero=not self.adev.partial_boot, boot=True),
self.adev.mm.palloc(0x1000, zero=not self.adev.partial_boot, boot=True))
self.rings = [(*_alloc_ring(self.ring_size), "", 0), (*_alloc_ring(self.ring_size), "_RING1", 1)]
def interrupt_handler(self):
_, rwptr_vm, suf, _ = self.rings[0]
@@ -315,6 +330,9 @@ class AM_PSP(AM_IP):
self.ring_size = 0x10000
self.ring_paddr = self.adev.mm.palloc(self.ring_size, zero=not self.adev.partial_boot, boot=True)
self.max_tmr_size = 0x1300000
self.tmr_paddr = self.adev.mm.palloc(self.max_tmr_size, align=am.PSP_TMR_ALIGNMENT, zero=not self.adev.partial_boot, boot=True)
def is_sos_alive(self): return self.adev.regMP0_SMN_C2PMSG_81.read() != 0x0
def init(self):
sos_components_load_order = [
@@ -359,7 +377,7 @@ class AM_PSP(AM_IP):
# Load TOC and calculate TMR size
self._prep_msg1(fwm:=self.adev.fw.sos_fw[am.PSP_FW_TYPE_PSP_TOC])
self.tmr_size = self._load_toc_cmd(len(fwm)).resp.tmr_size
self.tmr_paddr = self.adev.mm.palloc(self.tmr_size, align=am.PSP_TMR_ALIGNMENT, boot=True)
assert self.tmr_size <= self.max_tmr_size
def _ring_create(self):
# If the ring is already created, destroy it

View File

@@ -115,8 +115,9 @@ class ShapeTracker:
def unit_stride_axes(self, ignore_valid=False) -> list[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
with Context(TRACK_MATCH_STATS=0):
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, symbolic_flat).toposort if x.op is Ops.RANGE]
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:

View File

@@ -1,15 +1,15 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import List, Tuple, Callable, Optional, ClassVar, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import MultiLazyBuffer
from tinygrad.helpers import IMAGE, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
from tinygrad.multi import get_multi_map
from tinygrad.gradient import compute_gradient
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
from tinygrad.device import Device, Buffer, BufferSpec
from tinygrad.device import Device, BufferSpec
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
@@ -30,45 +30,23 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
# link the found UOps back to Tensors. exit early if there are no Tensors to realize
# NOTE: this uses all_tensors, but it's fast
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and any(x in all_uops for x in t.lazydata.lbs)]
fixed_tensors: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and t.lazydata in all_uops]
if len(fixed_tensors):
# potentially rewrite all the discovered Tensors
sink = UOp.sink(*[UOp.sink(*t.lazydata.lbs) if isinstance(t.lazydata, MultiLazyBuffer) else t.lazydata for t in fixed_tensors])
sink = UOp.sink(*[t.lazydata for t in fixed_tensors])
new_sink = sink.substitute(applied_map)
# set the relevant lazydata to the realized UOps
for t,s,ns in zip(fixed_tensors, sink.src, new_sink.src):
if s is ns: continue
if isinstance(t.lazydata, MultiLazyBuffer): t.lazydata.lbs = list(ns.src)
else: t.lazydata = ns
t.lazydata = ns
# **** start with two base classes, Tensor and Function ****
class Function:
def __init__(self, device:Union[str, tuple[str, ...]], *tensors:Tensor, metadata:Optional[Metadata]=None):
self.device = device
self.needs_input_grad = [t.requires_grad for t in tensors]
self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False
if self.requires_grad: self.parents = tensors
self.metadata = metadata
def forward(self, *args, **kwargs): raise NotImplementedError(f"forward not implemented for {type(self)}")
def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implemented for {type(self)}")
@classmethod
def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor:
ctx = fxn(x[0].device, *x, metadata=_METADATA.get())
ret = Tensor.__new__(Tensor)
ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
return ret
import tinygrad.function as F
# **** Tensor helper functions ****
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
import numpy as np
@@ -148,8 +126,7 @@ class Tensor(SimpleMathTrait):
np.set_printoptions(precision=4)
```
"""
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
__deletable__ = ('_ctx',)
__slots__ = "lazydata", "requires_grad", "grad"
training: ClassVar[bool] = False
no_grad: ClassVar[bool] = False
@@ -159,7 +136,7 @@ class Tensor(SimpleMathTrait):
return instance
def __del__(self): all_tensors.discard(weakref.ref(self))
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
if dtype is not None: dtype = to_dtype(dtype)
if device is None and isinstance(data, pathlib.Path): device = f"DISK:{data.resolve()}" # keep it on the disk if device is None
@@ -172,11 +149,8 @@ class Tensor(SimpleMathTrait):
# None (the default) will be updated to True if it's put in an optimizer
self.requires_grad: Optional[bool] = requires_grad
# internal variable used for autograd graph construction
self._ctx: Optional[Function] = None
# create a LazyBuffer from the different types of inputs
if isinstance(data, (UOp, MultiLazyBuffer)):
if isinstance(data, UOp):
assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported"
# NOTE: this is here because LazyBuffer = UOp
if isinstance(data, UOp) and data.op is Ops.BIND: data = _metaop(Ops.BIND, tuple(), dtype or data.dtype, device, data)
@@ -199,12 +173,12 @@ class Tensor(SimpleMathTrait):
data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}")
# by this point, it has to be a LazyBuffer
if not isinstance(data, (UOp, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
if not isinstance(data, UOp): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}")
# data might be on a different device
if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
if isinstance(device, str): self.lazydata:UOp = data if data.device == device else data.copy_to_device(device)
# if device is a tuple, we should have/construct a MultiLazyBuffer
elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata
elif isinstance(data, UOp) and isinstance(data.device, str): self.lazydata = Tensor(data).shard(device).lazydata
else:
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
self.lazydata = data
@@ -224,8 +198,8 @@ class Tensor(SimpleMathTrait):
def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev
def __repr__(self):
if isinstance(ld:=self.lazydata, MultiLazyBuffer): ld_repr = f"{ld!r}"
else: ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
ld = self.lazydata
ld_repr = f"<UOp {ld.device} {ld.shape} {str(ld.dtype)[7:]} {ld.st if ld.base is not ld else (ld.op, ld.realized)}>"
return f"<Tensor {ld_repr} on {self.device} with grad {(self.grad.lazydata if self.grad is not None else None)!r}>"
# Python has a non-moving GC, so this should be okay
@@ -246,6 +220,17 @@ class Tensor(SimpleMathTrait):
@property
def dtype(self) -> DType: return self.lazydata.dtype
def _apply_uop(self, fxn:Callable, *x:Tensor, **kwargs) -> Tensor:
ret = Tensor.__new__(Tensor)
needs_input_grad = [t.requires_grad for t in (self,)+x]
ret.requires_grad, ret.grad = True if any(needs_input_grad) else None if None in needs_input_grad else False, None
ret.lazydata = fxn(*[t.lazydata for t in (self,)+x], **kwargs)
return ret
def _apply_broadcasted_uop(self, fxn:Callable, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
lhs,rhs = self._broadcasted(x, reverse)
return lhs._apply_uop(fxn, rhs)
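With `Function` gone, an elementwise op is just a `UOp` method applied through these helpers; a minimal sketch (normally you would call `Tensor.add`, which now routes through the same path):
```python
from tinygrad import Tensor
from tinygrad.ops import UOp

a, b = Tensor([1., 2.]), Tensor([3., 4.])
print(a._apply_broadcasted_uop(UOp.add, b).numpy())  # [4. 6.], the graph Tensor.add builds
```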
# ***** data handlers ****
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
@@ -254,7 +239,14 @@ class Tensor(SimpleMathTrait):
NOTE: A Tensor can only be scheduled once.
"""
big_sink = UOp.sink(*flatten([x.lazydata.lbs for x in (self,)+lst]))
big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
# TODO: move this to scheduler tensor_map pass
if any(x.op is Ops.MULTI for x in big_sink.toposort):
# multi fixup
_apply_map_to_tensors(get_multi_map(big_sink))
big_sink = UOp.sink(*flatten([x.lazydata.src if x.lazydata.op is Ops.MULTI else [x.lazydata] for x in (self,)+lst]))
schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
_apply_map_to_tensors(becomes_map)
return memory_planner(schedule), var_vals
@@ -275,7 +267,6 @@ class Tensor(SimpleMathTrait):
Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
"""
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
assert getattr(self, '_ctx', None) is None
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
self.lazydata = x.lazydata
return self
@@ -287,13 +278,11 @@ class Tensor(SimpleMathTrait):
self.contiguous().realize().lazydata.base.realized.copyin(x._data())
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata is x.lazydata: return self # a self assign is a NOOP
# NOTE: we allow cross device assign
assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}"
assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer"
assert not x.requires_grad # self requires_grad is okay?
if not self.lazydata.is_realized: return self.replace(x)
self.lazydata = self.lazydata.assign(x.lazydata)
@@ -309,7 +298,8 @@ class Tensor(SimpleMathTrait):
if 0 in self.shape: return memoryview(bytearray(0))
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
buf = cast(Buffer, cast(UOp, cpu.lazydata).base.realized)
buf = cast(UOp, cpu.lazydata).base.realized
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
@@ -373,7 +363,6 @@ class Tensor(SimpleMathTrait):
"""
ret = Tensor(self.lazydata.clone(), self.device, requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.clone()
if hasattr(self, '_ctx'): ret._ctx = self._ctx
return ret
def to(self, device:Optional[Union[str, tuple[str, ...]]]) -> Tensor:
@@ -385,7 +374,6 @@ class Tensor(SimpleMathTrait):
if not isinstance(device, str): return self.shard(device)
ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad)
if self.grad is not None: ret.grad = self.grad.to(device)
if hasattr(self, '_ctx'): ret._ctx = self._ctx
return ret
def to_(self, device:Optional[Union[str, tuple[str, ...]]]):
@@ -405,18 +393,9 @@ class Tensor(SimpleMathTrait):
print(t.shard((t.device, t.device), axis=1).lazydata)
```
"""
assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
assert isinstance(self.device, str), "can't shard a MultiLazyBuffer"
devices = tuple(Device.canonicalize(x) for x in devices)
if axis is None: lbs = [self.lazydata] * len(devices)
else:
axis = self._resolve_dim(axis)
if self.shape[axis] % len(devices) != 0: raise RuntimeError(f"multi axis uneven: {self.shape[axis]=} {axis=} {len(devices)=}")
sz = self.shape[axis] // len(devices)
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
# NOTE: this contiguous is making it impossible for the scheduler to do late const folding
mlb = MultiLazyBuffer([lb.contiguous() for lb in sharded_lbs], axis)
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
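A hedged usage sketch of the UOp-backed shard, using two copies of the default device so no real multi-GPU setup is needed:
```python
from tinygrad import Tensor, Device

t = Tensor.arange(8).reshape(4, 2)
s = t.shard((Device.DEFAULT, Device.DEFAULT), axis=0)
print(s.lazydata.op, s.lazydata.axis, s.lazydata.bounds)  # Ops.MULTI 0 ((0, 2), (2, 4))
```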
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
@@ -439,7 +418,7 @@ class Tensor(SimpleMathTrait):
def _metaop(op, shape, device:Optional[Union[tuple[str, ...], str]]=None, dtype:Optional[DTypeLike]=None, arg=None, **kwargs):
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(MultiLazyBuffer([UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], None),
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
device, dtype, **kwargs)
return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
@@ -510,7 +489,7 @@ class Tensor(SimpleMathTrait):
@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
@@ -750,12 +729,12 @@ class Tensor(SimpleMathTrait):
```
"""
dtype = kwargs.pop("dtype", self.dtype)
if isinstance(self.device, tuple) and isinstance(self.lazydata, MultiLazyBuffer):
if isinstance(self.device, tuple):
if kwargs.get("device") is not None: raise RuntimeError("cannot specify `device` on `rand_like` of a multi device tensor")
if self.lazydata.axis is None: return Tensor.rand(*self.shape, dtype=dtype, **kwargs).shard(self.device)
contiguous = kwargs.pop("contiguous", True)
rands = [Tensor.rand(*lb.shape, device=lb.device, dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.lbs]
return Tensor(MultiLazyBuffer(cast(list[UOp], rands), self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
rands = [Tensor.rand(*lb.shape, device=cast(str, lb.device), dtype=dtype, contiguous=contiguous, **kwargs).lazydata for lb in self.lazydata.src]
return Tensor(UOp.multi(*rands, axis=self.lazydata.axis), device=self.device, dtype=dtype, **kwargs)
return Tensor.rand(*self.shape, device=kwargs.pop("device", self.device), dtype=dtype, **kwargs)
# ***** rng hlops *****
@@ -904,7 +883,7 @@ class Tensor(SimpleMathTrait):
# ***** toposort and backward pass *****
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None) -> list[Tensor]:
def gradient(self, *targets:Tensor, gradient:Optional[Tensor]=None, materialize_grads=False) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
@@ -921,29 +900,17 @@ class Tensor(SimpleMathTrait):
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
rets = []
for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
target_uops = [x.lazydata.lbs[i] for x in targets]
grads = compute_gradient(uop, grad, set(target_uops))
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
ret.append(y)
rets.append(ret)
target_uops = [x.lazydata for x in targets]
grads = compute_gradient(self.lazydata, gradient.lazydata, set(target_uops))
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None:
if materialize_grads: y = x.const_like(0)
else: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
ret.append(y)
rets.append(ret)
# create returned Tensors
if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
device=t.device) for t,u in zip(targets, zip(*rets))]
def _deepwalk(self) -> list[Tensor]:
def _walk(node:Tensor, visited:set[Tensor]):
visited.add(node)
# if tensor is not leaf, reset grad
if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
if ctx:
for i in cast(Function, node._ctx).parents:
if i not in visited: yield from _walk(i, visited)
yield node
return list(_walk(self, set()))
return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
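A quick sketch of the reworked API, including the new `materialize_grads` path for targets outside the graph:
```python
from tinygrad import Tensor

x, w = Tensor([2.0], requires_grad=True), Tensor([3.0], requires_grad=True)
y = (x * w).sum()
gx, gw = y.gradient(x, w)
print(gx.numpy(), gw.numpy())  # [3.] [2.]

z = Tensor([1.0], requires_grad=True)  # not reachable from y
print(y.gradient(z, materialize_grads=True)[0].numpy())  # [0.] instead of a RuntimeError
```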
def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
"""
@@ -956,31 +923,13 @@ class Tensor(SimpleMathTrait):
print(t.grad.numpy())
```
"""
toposorted = self._deepwalk()
if gradient is None:
assert self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
# fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
# this is "implicit gradient creation"
gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
toposort_uop = self.lazydata.toposort
assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
self.grad = gradient
for t0 in reversed(toposorted):
if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
ctx = cast(Function, t0._ctx)
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
grads = ctx.backward(t0.grad.lazydata)
_METADATA.reset(token)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(ctx.parents) == 1 else grads)]
for t, g in zip(ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
t.grad = g if t.grad is None else (t.grad + g)
if not retain_graph: del t0._ctx
all_uops = self.lazydata.toposort
tensors_need_grad: list[Tensor] = [t for tref in all_tensors if (t:=tref()) is not None and \
t.lazydata in all_uops and t.requires_grad and not Tensor.no_grad]
# clear contexts
for t,g in zip(tensors_need_grad, self.gradient(*tensors_need_grad, gradient=gradient, materialize_grads=True)):
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
t.grad = g if t.grad is None else (t.grad + g)
return self
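`backward` now discovers every tensor needing a grad from the UOp toposort instead of walking `_ctx` parents; observable behavior is unchanged:
```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0], requires_grad=True)
(x * x).sum().backward()
print(x.grad.numpy())  # [2. 4.]
```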
# ***** movement low level ops *****
@@ -1004,7 +953,7 @@ class Tensor(SimpleMathTrait):
# resolve -1
if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
return F.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self
return self._apply_uop(UOp.reshape, arg=new_shape) if new_shape != self.shape else self
def expand(self, shape, *args) -> Tensor:
"""
@@ -1037,7 +986,7 @@ class Tensor(SimpleMathTrait):
"""
order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
return F.Permute.apply(self, order=order_arg)
return self._apply_uop(UOp.permute, arg=order_arg)
def flip(self, axis, *args) -> Tensor:
"""
@@ -1057,7 +1006,7 @@ class Tensor(SimpleMathTrait):
"""
axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
return F.Flip.apply(self, axis=axis_arg)
return self._apply_uop(UOp.stride, arg=tuple([-1 if i in axis_arg else 1 for i in range(len(self.shape))]))
def shrink(self, arg:tuple[Optional[tuple[sint, sint]], ...]) -> Tensor:
"""
@@ -1077,7 +1026,7 @@ class Tensor(SimpleMathTrait):
```
"""
if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
return F.Shrink.apply(self, arg=tuple(shrink_arg))
return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
def pad(self, padding:Union[Sequence[sint], Sequence[Optional[tuple[sint, sint]]]], mode:str="constant", value:float=0.0) -> Tensor:
"""
@@ -1121,7 +1070,8 @@ class Tensor(SimpleMathTrait):
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0,v)
def _constant(x:Tensor,px,v):
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
@@ -1278,7 +1228,7 @@ class Tensor(SimpleMathTrait):
self._getitem(indices).assign(v)
return
# NOTE: check that setitem target is valid first
if not all(unwrap(lb.st).contiguous for lb in self.lazydata.lbs): raise RuntimeError("setitem target needs to be contiguous")
if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
if not isinstance(v, (Tensor, float, int, bool)): raise TypeError(f"can't set a {type(v).__name__} to a Tensor")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
if self.requires_grad or v.requires_grad: raise NotImplementedError("setitem with requires_grad is not supported")
@@ -1611,10 +1561,10 @@ class Tensor(SimpleMathTrait):
# ***** reduce ops *****
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
def _reduce(self, op:Ops, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor:
axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1)))
if self.ndim == 0: axis = ()
ret = fxn.apply(self, axis=axis)
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1641,7 +1591,7 @@ class Tensor(SimpleMathTrait):
print(t.sum(axis=1).numpy())
```
"""
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim)
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None):
@@ -1668,7 +1618,7 @@ class Tensor(SimpleMathTrait):
print(t.prod(axis=1).numpy())
```
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim)
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False):
"""
@@ -1691,7 +1641,7 @@ class Tensor(SimpleMathTrait):
print(t.max(axis=1, keepdim=True).numpy())
```
"""
return self._reduce(F.Max, axis, keepdim)
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
@@ -2528,7 +2478,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([False, True]).logical_not().numpy())
```
"""
return F.Neq.apply(*self.cast(dtypes.bool)._broadcasted(True))
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
def neg(self):
"""
Negates the tensor element-wise.
@@ -2542,12 +2492,12 @@ class Tensor(SimpleMathTrait):
"""
Returns a contiguous tensor.
"""
return F.Contiguous.apply(self)
return self._apply_uop(UOp.contiguous)
def contiguous_backward(self):
"""
Inserts a contiguous operation in the backward pass.
"""
return F.ContiguousBackward.apply(self)
return self._apply_uop(UOp.contiguous_backward)
def log(self):
"""
Computes the natural logarithm element-wise.
@@ -2558,7 +2508,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 4., 8.]).log().numpy())
```
"""
return F.Log.apply(self.cast(least_upper_float(self.dtype)))
return self.log2()*math.log(2)
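The rewrite leans on the identity ln(x) = log2(x)·ln(2); a one-line check:
```python
import math
print(math.log(8.0), math.log2(8.0) * math.log(2))  # both 2.0794...
```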
def log2(self):
"""
Computes the base-2 logarithm element-wise.
@@ -2569,7 +2519,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 4., 8.]).log2().numpy())
```
"""
return self.log()/math.log(2)
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self):
"""
Computes the exponential function element-wise.
@@ -2580,7 +2530,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., 1., 2., 3.]).exp().numpy())
```
"""
return F.Exp.apply(self.cast(least_upper_float(self.dtype)))
return self.mul(1/math.log(2)).exp2()
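Likewise exp is reduced to exp2 via e^x = 2^(x/ln 2):
```python
import math
print(math.exp(3.0), 2 ** (3.0 / math.log(2)))  # both 20.0855...
```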
def exp2(self):
"""
Computes the base-2 exponential function element-wise.
@@ -2591,8 +2541,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., 1., 2., 3.]).exp2().numpy())
```
"""
return F.Exp.apply(self*math.log(2))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self):
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2603,7 +2552,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).relu().numpy())
```
"""
return F.Relu.apply(self)
return (self>0).where(self, 0)
def sigmoid(self):
"""
@@ -2639,7 +2588,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 3., 4.]).sqrt().numpy())
```
"""
return F.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self):
"""
Computes the reciprocal of the square root of the tensor element-wise.
@@ -2657,7 +2606,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([0., math.pi/2, math.pi, 3*math.pi/2, 2*math.pi]).sin().numpy())
```
"""
return F.Sin.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self):
"""
Computes the cosine of the tensor element-wise.
@@ -2816,7 +2765,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-3., -2., -1., 0., 1., 2., 3.]).sign().numpy())
```
"""
return F.Sign.apply(self)
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
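The composition reads: nonzero values pick ±1 by comparison, zeros stay 0, and the trailing `self*0` keeps the result attached to the input (we read it as preserving the autograd/NaN path; that reading is ours). Behavior is unchanged:
```python
from tinygrad import Tensor
print(Tensor([-3., 0., 2.]).sign().numpy())  # [-1.  0.  1.]
```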
def abs(self):
"""
Computes the absolute value of the tensor element-wise.
@@ -2834,7 +2783,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([1., 2., 3., 4.]).reciprocal().numpy())
```
"""
return F.Reciprocal.apply(self.cast(least_upper_float(self.dtype)))
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.reciprocal)
# ***** activation functions *****
@@ -3112,7 +3061,7 @@ class Tensor(SimpleMathTrait):
# for each dimension, check either dim is 1, or it does not change
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return F.Expand.apply(self.reshape(shape), shape=new_shape)
return self.reshape(shape)._apply_uop(UOp.expand, arg=new_shape)
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> tuple[Tensor, Tensor]:
x: Tensor = self
@@ -3156,7 +3105,7 @@ class Tensor(SimpleMathTrait):
print(t.add(Tensor([[2.0], [3.5]])).numpy())
```
"""
return F.Add.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.add, x, reverse)
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3197,7 +3146,7 @@ class Tensor(SimpleMathTrait):
print(t.mul(Tensor([[-1.0], [2.0]])).numpy())
```
"""
return F.Mul.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.mul, x, reverse)
def idiv(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3210,7 +3159,7 @@ class Tensor(SimpleMathTrait):
print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy())
```
"""
return F.IDiv.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.idiv, x, reverse)
def div(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3245,7 +3194,7 @@ class Tensor(SimpleMathTrait):
```
"""
a, b = self._broadcasted(x, reverse)
return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
return (r := a._apply_uop(UOp.mod, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3261,7 +3210,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.Xor.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.xor, x, reverse)
def bitwise_and(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3276,7 +3225,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.BitwiseAnd.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.bitwise_and, x, reverse)
def bitwise_or(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
@@ -3291,7 +3240,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return F.BitwiseOr.apply(*self._broadcasted(x, reverse))
return self._apply_broadcasted_uop(UOp.bitwise_or, x, reverse)
def bitwise_not(self) -> Tensor:
"""
@@ -3422,7 +3371,7 @@ class Tensor(SimpleMathTrait):
elif isinstance(y, Tensor): y, x = y._broadcasted(x)
cond, x = self._broadcasted(x, match_dtype=False)
cond, y = cond._broadcasted(y, match_dtype=False)
return F.Where.apply(cond.cast(dtypes.bool), *x._broadcasted(y))
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
def masked_fill(self:Tensor, mask:Tensor, value:Union[Tensor, ConstType]): return mask.where(value, self)
@@ -3452,9 +3401,9 @@ class Tensor(SimpleMathTrait):
def __ilshift__(self, x) -> Tensor: return self.assign(self.lshift(x))
def __irshift__(self, x) -> Tensor: return self.assign(self.rshift(x))
def __lt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, False))
def __gt__(self, x) -> Tensor: return F.Less.apply(*self._broadcasted(x, True))
def ne(self, x) -> Tensor: return F.Neq.apply(*self._broadcasted(x))
def __lt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, False)
def __gt__(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.__lt__, x, True)
def ne(self, x) -> Tensor: return self._apply_broadcasted_uop(UOp.ne, x, False)
def __eq__(self, x) -> Tensor: return self.eq(x) # type: ignore[override]
@@ -3800,8 +3749,8 @@ class Tensor(SimpleMathTrait):
"""
if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
# NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
return self._apply_uop(UOp.cast, dtype=dtypes.int32)._apply_uop(UOp.cast, dtype=dt)
return self if self.dtype == dt else self._apply_uop(UOp.cast, dtype=dt)
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""
@@ -3826,7 +3775,7 @@ class Tensor(SimpleMathTrait):
tmp = self.bitcast(old_uint)
if ns > os: return functools.reduce(Tensor.add, (tmp[..., i::ns//os].cast(new_uint) << 8*i*os for i in range(ns//os))).bitcast(dtype)
return Tensor.stack(*(tmp>>8*i*ns for i in range(os//ns)), dim=-1).flatten(-2).cast(new_uint).bitcast(dtype)
return F.Cast.apply(self, dtype=dt, bitcast=True) if self.dtype != dt else self
return self._apply_uop(UOp.bitcast, dtype=dt) if self.dtype != dt else self
def float(self) -> Tensor:
"""
@@ -4000,5 +3949,5 @@ def _metadata_wrapper(fn):
if TRACEMETA >= 1:
for name, fn in inspect.getmembers(Tensor, inspect.isfunction):
if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential"]: continue
if name in ["__class__", "__init__", "__new__", "__repr__", "backward", "sequential", "gradient"]: continue
setattr(Tensor, name, functools.wraps(fn)(_metadata_wrapper(fn)))
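The loop above rebinds every public Tensor function to a metadata-recording wrapper; "gradient" joins the skip list so it keeps its raw identity. A toy sketch of the rebinding pattern (class T and its method are invented for illustration):

    import functools, inspect

    class T:
        def add(self, x): return x + 1

    def _metadata_wrapper(fn):
        def wrapped(*args, **kwargs): return fn(*args, **kwargs)  # real version records call metadata here
        return wrapped

    for name, fn in inspect.getmembers(T, inspect.isfunction):
        if name in ["__init__", "gradient"]: continue
        setattr(T, name, functools.wraps(fn)(_metadata_wrapper(fn)))
    assert T().add(1) == 2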

View File

@@ -301,17 +301,17 @@
const kernelListParent = document.querySelector(".container.kernel-list-parent");
const kernelList = document.querySelector(".container.kernel-list");
kernelList.innerHTML = "";
kernels.forEach((k, i) => {
kernels.forEach(([key, items], i) => {
const kernelUl = Object.assign(document.createElement("ul"), { key: `kernel-${i}`, className: i === currentKernel ? "active" : "",
style: "overflow-x: auto; cursor: initial;" });
if (i === currentKernel) {
requestAnimationFrame(() => kernelUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
}
const p = Object.assign(document.createElement("p"), { id: `kernel-${k[0].kernel_name}`, innerText: k[0].kernel_name ?? "UNPARENTED",
const p = Object.assign(document.createElement("p"), { id: `kernel-${key}`, innerText: key ?? "UNPARENTED",
style: "cursor: pointer;"});
kernelUl.appendChild(p)
k.forEach((u, j) => {
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.upats.length}`, key: `uop-rewrite-${j}`,
items.forEach((u, j) => {
const rwUl = Object.assign(document.createElement("ul"), { innerText: `${toPath(u.loc)} - ${u.match_count}`, key: `uop-rewrite-${j}`,
className: (j === currentUOp && i == currentKernel) ? "active" : "" })
if (j === currentUOp) {
requestAnimationFrame(() => rwUl.scrollIntoView({ behavior: "auto", block: "nearest" }));
@@ -460,7 +460,7 @@
event.preventDefault()
currentUOp = 0;
currentRewrite = 0;
currentKernel = Math.min(Array.from(Object.keys(kernels)).length-1, currentKernel+1)
currentKernel = Math.min(kernels.length-1, currentKernel+1);
return main()
}
}
@@ -486,7 +486,7 @@
if (event.key == "ArrowDown") {
event.preventDefault()
currentRewrite = 0;
const totalUOps = kernels[currentKernel].length-1;
const totalUOps = kernels[currentKernel][1].length-1;
currentUOp = Math.min(totalUOps, currentUOp+1)
main()
}

View File

@@ -2,8 +2,7 @@
import multiprocessing, pickle, functools, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, decimal
from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from dataclasses import asdict, dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, TypedDict
from tinygrad.helpers import colored, getenv, to_function_name, tqdm, unwrap, word_wrap
from tinygrad.ops import TrackedGraphRewrite, UOp, Ops, lines, GroupOp
from tinygrad.codegen.kernel import Kernel
@@ -13,54 +12,30 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0",
Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4",
Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", Ops.MULTI: "#f6ccff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BLOCK: "#C4A484", Ops.BLOCKEND: "#C4A4A4", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0"}
# ** API spec
# VIZ API
@dataclass
class GraphRewriteMetadata:
"""Overview of a tracked rewrite to viz the sidebar"""
loc: tuple[str, int]
"""File_path, Lineno"""
code_line: str
"""The Python line calling graph_rewrite"""
kernel_name: str
"""The kernel calling graph_rewrite"""
upats: list[tuple[tuple[str, int], str, float]]
"""List of all the applied UPats"""
class GraphRewriteMetadata(TypedDict):
loc: tuple[str, int] # [path, lineno] calling graph_rewrite
match_count: int # total match count in this context
@dataclass
class GraphRewriteDetails(GraphRewriteMetadata):
"""Full details about a single call to graph_rewrite"""
uops: list[UOp]
graphs: list[dict]
"""Sink at every step of graph_rewrite + the json serialized version"""
diffs: list[list[str]]
""".diff style before and after of the rewritten UOp child"""
changed_nodes: list[list[int]]
"""Nodes that changed at every step of graph_rewrite"""
kernel_code: Optional[str]
"""The program after all rewrites"""
# ** API functions
graphs: list[dict] # JSON serialized UOp at every rewrite step
uops: list[str] # stringified UOp at every rewrite step
diffs: list[list[str]] # string diff of the single UOp that changed
changed_nodes: list[list[int]] # the changed UOp id + all of its parents' ids
code_line: str # source code calling graph_rewrite
kernel_code: str|None # optionally render the final kernel code
upats: list[tuple[tuple[str, int], str]]
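Switching these API types from dataclass to TypedDict means the payloads are plain dicts at runtime: json.dumps serializes them directly, with no asdict deep-copy pass. A small sketch of the difference (MetaTD/MetaDC are illustrative, not the viz types):

    import json
    from dataclasses import dataclass, asdict
    from typing import TypedDict

    class MetaTD(TypedDict):
        loc: tuple[str, int]
        match_count: int

    @dataclass
    class MetaDC:
        loc: tuple[str, int]
        match_count: int

    td: MetaTD = {"loc": ("ops.py", 1), "match_count": 3}
    assert json.dumps(td) == json.dumps(asdict(MetaDC(("ops.py", 1), 3)))  # the dict needs no conversion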
# NOTE: if any extra rendering in VIZ fails, we don't crash
def pcall(fxn:Callable[..., str], *args, **kwargs) -> str:
try: return fxn(*args, **kwargs)
except Exception as e: return f"ERROR: {e}"
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]]:
kernels: dict[str, list[tuple[Any, TrackedGraphRewrite, GraphRewriteMetadata]]] = {}
for k,ctxs in tqdm(zip(keys, contexts), desc="preparing kernels"):
name = to_function_name(k.name) if isinstance(k, Kernel) else str(k)
for ctx in ctxs:
if ctx.sink.op is Ops.CONST: continue
upats = [(upat.location, upat.printable(), tm) for _,_,upat,tm in ctx.matches if upat is not None]
kernels.setdefault(name, []).append((k, ctx, GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats)))
return list(kernels.values())
def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
assert isinstance(x, UOp)
graph: dict[int, tuple[str, str, list[int], str, str]] = {}
@@ -80,27 +55,27 @@ def uop_to_json(x:UOp) -> dict[int, tuple[str, str, list[int], str, str]]:
else: label += f"\n{x.op.name}{idx} {x.arg}"
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x not in excluded], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph
def get_metadata(keys:list[Any], contexts:list[list[TrackedGraphRewrite]]) -> list[tuple[str, list[GraphRewriteMetadata]]]:
return [(to_function_name(k.name) if isinstance(k, Kernel) else str(k),
[{"loc": v.loc, "match_count": len(v.matches)} for v in vals]) for k,vals in zip(keys, contexts)]
@functools.lru_cache(None)
def _prg(k:Kernel): return k.to_program().src
def get_details(k:Any, ctx:TrackedGraphRewrite, metadata:GraphRewriteMetadata) -> GraphRewriteDetails:
g = GraphRewriteDetails(**asdict(metadata), uops=[ctx.sink], diffs=[], changed_nodes=[],
kernel_code=pcall(_prg, k) if isinstance(k, Kernel) else None, graphs=[])
ret:GraphRewriteDetails = {"uops":[pcall(str, sink:=ctx.sink)], "graphs":[uop_to_json(sink)], "code_line":lines(ctx.loc[0])[ctx.loc[1]-1].strip(),
"kernel_code":pcall(_prg, k) if isinstance(k, Kernel) else None, "diffs":[], "upats":[], "changed_nodes":[], **metadata}
replaces: dict[UOp, UOp] = {}
g.graphs.append(uop_to_json(sink:=g.uops[0]))
for i,(u0,u1,upat,_) in enumerate(tqdm(ctx.matches)):
# if the match didn't result in a rewrite we move forward
if u1 is None: continue
for i,(u0,u1,upat) in enumerate(tqdm(ctx.matches)):
replaces[u0] = u1
# first, rewrite this UOp with the current rewrite + all the matches in replaces
new_sink = sink.substitute(replaces)
# sanity check
if new_sink is sink: raise AssertionError(f"rewritten sink wasn't rewritten! {i} {unwrap(upat).location}")
# update ret data
g.graphs.append(new_sink_js:=uop_to_json(new_sink))
g.changed_nodes.append([id(x) for x in u1.toposort if id(x) in new_sink_js])
g.diffs.append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
g.uops.append(sink:=new_sink)
return g
ret["graphs"].append(new_sink_js:=uop_to_json(new_sink))
ret["changed_nodes"].append([id(x) for x in u1.toposort if id(x) in new_sink_js])
ret["diffs"].append(list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())))
ret["upats"].append((upat.location, upat.printable()))
# TODO: this is O(n^2)!
ret["uops"].append(str(sink:=new_sink))
return ret
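The loop above replays a rewrite trace: each match (u0 -> u1) is folded into replaces and the whole sink is re-substituted, so every step yields a complete intermediate graph (hence the O(n^2) TODO on re-stringifying it). A toy replay, with one-level tuple substitution standing in for UOp.substitute:

    matches = [("a", "b"), ("b", "c")]                       # recorded (u0, u1) pairs
    replaces: dict[str, str] = {}
    snapshots = [sink := ("add", "a", "x")]
    for u0, u1 in matches:
        replaces[u0] = u1
        new_sink = tuple(replaces.get(s, s) for s in sink)   # stand-in for sink.substitute(replaces)
        assert new_sink != sink                              # same sanity check as above
        snapshots.append(sink := new_sink)
    assert snapshots[-1] == ("add", "c", "x")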
# Profiler API
devices:dict[str, tuple[decimal.Decimal, decimal.Decimal, int]] = {}
@@ -150,13 +125,9 @@ class Handler(BaseHTTPRequestHandler):
elif url.path == "/kernels":
query = parse_qs(url.query)
if (qkernel:=query.get("kernel")) is not None:
g = get_details(*kernels[int(qkernel[0])][int(query["idx"][0])])
# TODO: this is O(n^2)!
uops_strs = [pcall(str,x) for x in tqdm(g.uops)]
# NOTE: don't use asdict because it's reserializing the uops
jret: Any = {"loc": g.loc, "code_line": g.code_line, "kernel_name": g.kernel_name, "upats": g.upats,
"uops": uops_strs, "graphs": g.graphs, "diffs": g.diffs, "changed_nodes": g.changed_nodes, "kernel_code": g.kernel_code}
else: jret = [list(map(lambda x:asdict(x[2]), v)) for v in kernels]
kidx, ridx = int(qkernel[0]), int(query["idx"][0])
jret:Any = get_details(contexts[0][kidx], contexts[1][kidx][ridx], kernels[kidx][1][ridx])
else: jret = kernels
ret, content_type = json.dumps(jret).encode(), "application/json"
elif url.path == "/get_profile" and perfetto_profile is not None: ret, content_type = perfetto_profile, "application/json"
else: status_code = 404
@@ -198,10 +169,11 @@ if __name__ == "__main__":
contexts, profile = load_pickle(args.kernels), load_pickle(args.profile)
# NOTE: this context is a tuple of list[keys] and list[values]
kernels = get_metadata(*contexts) if contexts is not None else []
if getenv("FUZZ_VIZ"):
ret = [get_details(*args) for v in tqdm(kernels) for args in v]
ret = [get_details(contexts[0][i], contexts[1][i][j], args) for i,v in tqdm(enumerate(kernels)) for j,args in enumerate(v[1])]
print(f"fuzzed {len(ret)} rewrite details")
perfetto_profile = to_perfetto(profile) if profile is not None else None