Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2025-03-07 12:04:38 -08:00
41 changed files with 6861 additions and 345 deletions

View File

@@ -455,8 +455,8 @@ jobs:
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Insert amdgpu
run: sudo modprobe amdgpu
- name: Remove amdgpu
run: sudo rmmod amdgpu || true
- name: Symlink models and datasets
run: |
mkdir -p weights
@@ -474,10 +474,6 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: setup perflevel
run: |
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
rocm-smi
- name: Train MNIST
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps

View File

@@ -323,8 +323,8 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
- name: Run unit tests
run: PYTHONPATH="." python -m pytest -n=auto test/unit/
- name: Repo line count < 11300 lines
run: MAX_LINE_COUNT=11300 python sz.py
- name: Repo line count < 11500 lines
run: MAX_LINE_COUNT=11500 python sz.py
fuzzing:
name: Fuzzing
@@ -347,7 +347,7 @@ jobs:
testgpuimage:
name: 'GPU IMAGE Tests'
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
timeout-minutes: 10
steps:
- name: Checkout Code
@@ -371,7 +371,7 @@ jobs:
testopenpilot:
name: 'openpilot Compile Tests'
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
timeout-minutes: 10
steps:
- name: Checkout Code
@@ -644,7 +644,7 @@ jobs:
backend: [metal, llvm, cpu]
name: MacOS (${{ matrix.backend }})
runs-on: macos-15
timeout-minutes: 10
timeout-minutes: 20
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -664,6 +664,9 @@ jobs:
run: python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
- name: Run process replay tests
uses: ./.github/actions/process-replay
- name: Run macOS-specific unit test
if: matrix.backend == 'cpu'
run: python3 -m pytest test/unit/test_disk_tensor.py::TestDiskTensor::test_copy_to_cpu_not_truncated
# ****** Windows Tests ******

1
.gitignore vendored
View File

@@ -10,6 +10,7 @@ notebooks
*.so
*.txt
build
!examples/tinychat/assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/
/dist
*.egg-info
/env

View File

@@ -171,6 +171,7 @@ generate_amd() {
extra/hip_gpu_driver/sdma_v6_0_0_pkt_open.h \
extra/hip_gpu_driver/gc_11_0_0_offset.h \
extra/hip_gpu_driver/gc_10_3_0_offset.h \
extra/hip_gpu_driver/sienna_cichlid_ip_offset.h \
--clang-args="-I/opt/rocm/include -x c++" \
-o $BASE/amd_gpu.py
@@ -353,6 +354,12 @@ generate_am() {
extra/amdpci/headers/amdgpu_smu.h \
-o $BASE/am/smu_v13_0_0.py
fixup $BASE/am/smu_v13_0_0.py
clang2py -k cdefstum \
extra/amdpci/headers/hdp_6_0_0_offset.h \
extra/amdpci/headers/hdp_6_0_0_sh_mask.h \
-o $BASE/am/hdp_6_0_0.py
fixup $BASE/am/hdp_6_0_0.py
}
generate_webgpu() {

View File

@@ -851,7 +851,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te
optimizer.step()
scheduler.step()
return loss.realize(), global_norm.realize()
# TODO: no to("CPU") here because it blocks and messes the python time
Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
return loss, global_norm, optimizer.optimizers[0].lr
@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
@@ -862,7 +864,10 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()
for t in [masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss]:
t.to_("CPU")
Tensor.realize(masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss)
return masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss
def train_bert():
# NOTE: pip install tensorflow, wandb required
@@ -1031,47 +1036,49 @@ def train_bert():
MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
while train_data is not None and i < train_steps and not achieved:
Tensor.training = True
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
GlobalCounters.reset()
loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
if getenv("TRAIN", 1):
Tensor.training = True
BEAM.value = TRAIN_BEAM
st = time.perf_counter()
GlobalCounters.reset()
loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
pt = time.perf_counter()
pt = time.perf_counter()
try:
next_data = next(train_it)
except StopIteration:
next_data = None
try:
next_data = next(train_it)
except StopIteration:
next_data = None
dt = time.perf_counter()
dt = time.perf_counter()
device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
loss = loss.item()
device_str = parameters[0].device if isinstance(parameters[0].device, str) else f"{parameters[0].device[0]} * {len(parameters[0].device)}"
loss = loss.item()
lr = lr.item()
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
cl = time.perf_counter()
if BENCHMARK: step_times.append(cl - st)
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
tqdm.write(
f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {lr:.6f} LR, "
f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
if WANDB:
wandb.log({"lr": lr, "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
"train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
"train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
train_data, next_data = next_data, None
i += 1
train_data, next_data = next_data, None
i += 1
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * train_steps / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
if i == BENCHMARK:
median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
estimated_total_minutes = int(median_step_time * train_steps / 60)
print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
# ** eval loop **
if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:

View File

@@ -0,0 +1,11 @@
/*!
Pure v3.0.0
Copyright 2013 Yahoo!
Licensed under the BSD License.
https://github.com/pure-css/pure/blob/master/LICENSE
*/
/*!
normalize.css v | MIT License | https://necolas.github.io/normalize.css/
Copyright (c) Nicolas Gallagher and Jonathan Neal
*/
/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}

View File

@@ -0,0 +1,5 @@
net_*
llama3-2.tiktoken
tiktoken.js
tiktoken_bg.wasm
transformer*

View File

@@ -0,0 +1,8 @@
# How to build and run tinychat in browser (WebGPU and WASM)
- `PYTHONPATH=. python examples/tinychat/tinychat-browser/compile.py`
- `./examples/tinychat/tinychat-browser/compile_wasm.sh`
- Prerequisite: [install emscripten](https://emscripten.org/docs/getting_started/downloads.html). This script looks for `~/emsdk/emsdk_env.sh`, adjust this based on your installation.
- `./examples/tinychat/tinychat-browser/make_tiktoken_js.sh`
- Prerequisite: install `npm`, `webpack`.
- `cd examples/tinychat && python -m http.server 7776`
- In browser: open either `localhost:7776/tinychat-browser` (WebGPU), or `localhost:7776/tinychat-browser/?backend=wasm` (WASM)

View File

@@ -0,0 +1,149 @@
import os, json, hashlib, math
from extra.export_model import export_model
from examples.llama3 import build_transformer, Tokenizer
from tinygrad.nn.state import get_state_dict, load_state_dict
from tinygrad import Device, Variable, Tensor, dtypes, TinyJit
from tinygrad.helpers import fetch, Context
from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe
def prepare_browser_chunks(model):
# split weights into browser-friendly chunks
state_dict = get_state_dict(model)
del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
metadata = {}
# We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
split_t_infos = []
for size, name, dtype in t_infos:
if size <= chunk_size:
split_t_infos.append((size, name, dtype, ()))
else: # split large weights into multiple parts
for i in range(0, size, chunk_size):
split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))
files = []
# pack weights into files with FFD bin packing
split_t_infos = sorted(split_t_infos, reverse=True)
for info in split_t_infos:
placed = False
for file in files:
if sum(i[0] for i in file) + info[0] <= chunk_size:
if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
file.append(info)
placed = True
break
if not placed:
files.append([info])
tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
for i, file in enumerate(files):
cursor = 0
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
for size, name, dtype, offsets in file:
name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1]))
default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]}
weight_metadata = metadata.get(name, default)
weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
metadata[name] = weight_metadata
data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
data = data if not offsets else data[offsets[0]:offsets[1]]
writer.write(data)
cursor += size
metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos})
for k in metadata:
metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])]
cursor = 0
for i, part in enumerate(metadata[k]["parts"]):
metadata[k]["parts"][i]["target_start_pos"] = cursor
cursor += part["size"]
metadata[k]["size"] = cursor
# compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues
state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []}
hashes = set()
for i in range(len(files)):
with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader:
hash = hashlib.sha256(reader.read()).hexdigest()
hashes.add(hash)
metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hash})
if len(hashes) != len(files): print(f"WARNING: {len(files)} files were exported, but only {len(hashes)} are unique: something may have gone wrong")
metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
metadata = {"metadata": metadata, "metadata_hash": metadata_hash}
with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4)
return metadata
def validate_model(model, tokenizer):
prompt = "yo"
toks = [tokenizer.bos_id]
toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("user") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
toks += tokenizer.encode(prompt) + [tokenizer.special_tokens["<|eot_id|>"]]
toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("assistant") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
start_pos = 0
run = TinyJit(model.forward)
for tok in toks[:-1]:
run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).realize()
start_pos += 1
tok = toks[-1]
result = ""
expected = "How's it going?"
while True:
tok = run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).item()
start_pos += 1
if tok in tokenizer.stop_tokens or len(result) > len(expected): break
result += tokenizer.decode([tok])
assert result == expected, f"Model validation failed, expected output: {expected}, actual output: {result}"
if __name__=="__main__":
# Export BPE data for use with tiktoken.js
tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path))
bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken")
dump_tiktoken_bpe(mergeable_ranks, bpe_path)
tokenizer = Tokenizer(str(tokenizer_path))
model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
Tensor.no_grad = True
max_context=1024
tok = 128000
TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0
start_pos = Variable("start_pos", 0, max_context).bind(0)
model_input = lambda: [Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P]
Device.DEFAULT="CPU"
model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context)
state_dict = get_state_dict(model)
validate_model(model, tokenizer)
model_name = "transformer"
with Context(BEAM=3):
cprog, js_wrapper = export_model(model, "wasm", *model_input(), model_name=model_name)
# ensure consistency with exported weights
js_wrapper = js_wrapper.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")
with open(os.path.join(os.path.dirname(__file__), f"{model_name}.c"), "w") as f: f.write(cprog)
with open(os.path.join(os.path.dirname(__file__), "net_clang.js"), "w") as f: f.write(js_wrapper)
Device.DEFAULT="WEBGPU"
# float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: https://issues.chromium.org/issues/42251215
# therefore for now, we used CLANG to quantize the float16 llama to int8 with float32 scales, then load to WEBGPU
model = build_transformer(model_path, model_size="1B", quantize="int8", max_context=max_context, load_weights=False)
load_state_dict(model, state_dict)
# these were the same before load_state_dict
model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale
validate_model(model, tokenizer)
metadata = prepare_browser_chunks(model) # export weights to disk
with Context(BEAM=3):
prg, input_sizes, output_sizes, state = export_model(model, "webgpu", *model_input(), model_name=model_name, stream_weights=True)
# ensure consistency with exported weights
prg = prg.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")
with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as f: f.write(prg)

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"
# prereq: install emscripten: https://emscripten.org/docs/getting_started/downloads.html
EMSCRIPTEN_PATH=~/emsdk/emsdk_env.sh
source $EMSCRIPTEN_PATH
step="transformer"
initial_memory=6553600
max_memory=1500053504
exported_functions='["_net", "_malloc", "_free", "_set_buf"]'
emcc "${step}.c" \
-O3 -msimd128 -ffast-math -flto \
-o "${step}.js" \
-s MODULARIZE=1 \
-s EXPORT_ES6=1 \
-s EXPORTED_FUNCTIONS="${exported_functions}" \
-s ENVIRONMENT='worker' \
-s FILESYSTEM=0 \
-s EVAL_CTORS \
-s ALLOW_MEMORY_GROWTH=1 \
-s INITIAL_MEMORY="$initial_memory" \
-s MAXIMUM_MEMORY="$max_memory"

View File

@@ -0,0 +1,322 @@
/* define colors */
:root {
--primary-color: #fff;
--secondary-color: #2a2a2a;
--secondary-color-transparent: #ffffff66;
--primary-bg-color: #1a1a1a;
--foreground-color: #f0f0f0;
}
main {
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
place-items: center;
}
.home {
width: 100%;
height: 90%;
margin-bottom: 10rem;
}
.title {
font-size: 3rem;
margin: 1rem 0;
margin-top: 3rem;
}
.histories-container-container {
width: 100%;
max-height: 75%;
position: relative;
}
.histories-container {
overflow-y: auto;
overflow-x: hidden;
width: 100%;
height: 100%;
display: flex;
flex-direction: column;
gap: 1rem;
align-items: center;
margin: 0;
padding: 3rem 1rem;
}
.histories-start {
height: 3rem;
width: 100%;
z-index: 999;
top: 0;
position: absolute;
background: linear-gradient(
180deg,
var(--primary-bg-color) 0%,
transparent 100%
);
}
.histories-end {
height: 3rem;
width: 100%;
z-index: 999;
bottom: 0;
position: absolute;
background: linear-gradient(
0deg,
var(--primary-bg-color) 0%,
transparent 100%
);
}
.history {
padding: 1rem;
width: 100%;
max-width: 40rem;
background-color: var(--secondary-color);
border-radius: 10px;
border-left: 2px solid var(--primary-color);
cursor: pointer;
transform: translateX(calc(1px * var(--tx, 0)));
opacity: var(--opacity, 1);
}
.history:hover {
background-color: var(--secondary-color);
}
.history-delete-button {
position: absolute;
top: 0;
right: 0;
padding: 0.5rem;
margin: 0;
outline: none;
border: none;
background-color: var(--secondary-color);
color: var(--foreground-color);
border-radius: 0 0 0 10px;
cursor: pointer;
transition: 0.2s;
}
.history-delete-button:hover {
background-color: var(--secondary-color);
padding: 0.75rem;
}
.messages {
overflow-y: auto;
height: 100%;
width: 100%;
max-width: 1200px;
display: flex;
flex-direction: column;
gap: 1rem;
align-items: center;
padding-top: 1rem;
padding-bottom: 11rem;
}
.message {
max-width: 75%;
padding: 0.5rem 1rem;
border-radius: 20px;
}
.message-role-assistant {
background-color: var(--secondary-color);
margin-right: auto;
color: #fff;
}
.message-role-user {
margin-left: auto;
background-color: var(--primary-color);
color: #000;
}
.message > pre {
white-space: pre-wrap;
}
.hljs {
width: 100%;
position: relative;
border-radius: 10px;
/* wrap code blocks */
white-space: pre-wrap;
}
/* put clipboard button in the top right corner of the code block */
.clipboard-button {
position: absolute;
top: 0;
right: 0;
padding: 0.5rem;
margin: 0;
outline: none;
border: none;
background-color: var(--secondary-color);
color: var(--foreground-color);
border-radius: 0 0 0 10px;
cursor: pointer;
transition: 0.2s;
}
.clipboard-button:hover {
background-color: var(--secondary-color);
padding: 0.75rem;
}
.input-container {
position: absolute;
bottom: 0;
/* linear gradient from background-color to transparent on the top */
background: linear-gradient(
0deg,
var(--primary-bg-color) 55%,
transparent 100%
);
width: 100%;
max-width: 1200px;
display: flex;
flex-direction: column;
justify-content: center;
align-items: center;
z-index: 999;
}
.input-performance {
margin-top: 4rem;
display: flex;
flex-direction: row;
gap: 1rem;
}
.input-performance-point {
display: flex;
flex-direction: row;
place-items: center;
gap: 0.5rem;
}
.input-performance-point > p {
height: 1rem;
line-height: normal;
}
.input {
width: 90%;
min-height: 3rem;
flex-shrink: 0;
display: flex;
flex-direction: row;
justify-content: center;
gap: 0.5rem;
align-items: flex-end;
margin-bottom: 2rem;
}
.input-form {
width: 100%;
padding: 1rem;
min-height: 3rem;
max-height: 8rem;
background-color: var(--secondary-color);
color: var(--foreground-color);
border-radius: 10px;
border: none;
resize: none;
outline: none;
}
.mobile .input-form { /* prevent auto-zoom on touching prompt box */
font-size: 16px;
}
.input-button {
height: 3rem;
width: 4rem;
background-color: var(--primary-color);
color: var(--secondary-color);
border-radius: 10px;
padding: 0.5rem;
cursor: pointer;
}
.input-button:hover {
background-color: var(--secondary-color-transparent);
}
.input-button:disabled {
background-color: var(--secondary-color);
cursor: not-allowed;
}
/* wrap text */
p {
white-space: pre-wrap;
}
/* fonts */
.megrim-regular {
font-family: monospace;
font-weight: 400;
font-style: normal;
}
.monospace {
font-family: monospace;
}
.loading-bar {
display: flex;
flex-direction: row;
align-items: center;
gap: 0.5rem;
width: 100%;
min-height: 3rem;
margin-bottom: 2rem;
}
.loading-text {
color: var(--foreground-color);
font-size: 1rem;
white-space: nowrap;
}
#progress-percentage {
color: var(--foreground-color);
font-size: 1rem;
white-space: nowrap;
}
.progress-bar {
flex-grow: 1;
height: 0.5rem;
background-color: var(--secondary-color);
border-radius: 5px;
overflow: hidden;
position: relative;
}
.progress {
width: 0%;
height: 100%;
background-color: var(--primary-color);
transition: width 0.2s ease-in-out;
}

View File

@@ -0,0 +1,182 @@
<!DOCTYPE html>
<head>
<title>tinychat</title>
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="icon" href="../favicon.svg" type="image/svg+xml">
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
<script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
<script defer src="../assets/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
<script defer src="../assets/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>
<script src="../assets/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
<script src="../assets/unpkg.com/marked@13.0.0/marked.min.js"></script>
<script src="../assets/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
<script src="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>
<script src="index.js"></script>
<link rel="stylesheet" href="../assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css">
<link rel="stylesheet" href="../assets/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"
integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A=="
crossorigin="anonymous" referrerpolicy="no-referrer" />
<link rel="stylesheet" href="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css">
<link rel="stylesheet" href="index.css">
<link rel="stylesheet" href="../common.css">
</head>
<body>
<main x-data="state" x-init="console.log(endpoint)">
<div class="home centered" x-show="home === 0" x-transition x-effect="
$refs.inputForm.focus();
if (home === 1) setTimeout(() => home = 2, 100);
if (home === -1) setTimeout(() => home = 0, 100);
" @popstate.window="
if (home === 2) {
cancelGeneration = true;
if (maxContextReached) generating = false;
if (!generating) cstate = { time: null, messages: [] };
home = -1;
time_till_first = 0;
tokens_per_second = 0;
total_tokens = 0;
}
">
<h1 class="title megrim-regular">tinychat</h1>
<div class="histories-container-container">
<template x-if="histories.length">
<div class="histories-start"></div>
</template>
<div class="histories-container" x-intersect="
$el.scrollTo({ top: 0, behavior: 'smooth' });
">
<template x-for="_state in histories.toSorted((a, b) => b.time - a.time)">
<div x-data="{ otx: 0, trigger: 75 }" class="history" @click="
cstate = _state;
updateTotalTokens(cstate.messages);
home = 1;
// ensure that going back in history will go back to home
window.history.pushState({}, '', window.TINYCHAT_ROOT || '/');
" @touchstart="
otx = $event.changedTouches[0].clientX;
" @touchmove="
$el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
$el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
" @touchend="
if (Math.abs($event.changedTouches[0].clientX - otx) > trigger) removeHistory(_state);
$el.style.setProperty('--tx', 0);
$el.style.setProperty('--opacity', 1);
">
<h3 x-text="new Date(_state.time).toLocaleString()"></h3>
<p x-text="$truncate(_state.messages[0].content, 80)"></p>
<!-- delete button -->
<button class="history-delete-button" @click.stop="removeHistory(_state);">
<i class=" fas fa-trash"></i>
</button>
</div>
</template>
</div>
<template x-if="histories.length">
<div class="histories-end"></div>
</template>
</div>
</div>
<div x-ref="messages" class="messages" x-init="
$watch('cstate', value => {
$el.innerHTML = '';
value.messages.forEach(({ role, content }) => {
const div = document.createElement('div');
div.className = `message message-role-${role}`;
try {
div.innerHTML = DOMPurify.sanitize(marked.parse(content));
} catch (e) {
console.log(content);
console.error(e);
}
// add a clipboard button to all code blocks
const codeBlocks = div.querySelectorAll('.hljs');
codeBlocks.forEach(codeBlock => {
const button = document.createElement('button');
button.className = 'clipboard-button';
button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
button.onclick = () => {
// navigator.clipboard.writeText(codeBlock.textContent);
const range = document.createRange();
range.setStartBefore(codeBlock);
range.setEndAfter(codeBlock);
window.getSelection()?.removeAllRanges();
window.getSelection()?.addRange(range);
document.execCommand('copy');
window.getSelection()?.removeAllRanges();
button.innerHTML = '<i class=\'fas fa-check\'></i>';
setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
};
codeBlock.appendChild(button);
});
$el.appendChild(div);
});
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
});
" x-intersect="
$el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
" x-show="home === 2" x-transition>
</div>
<div class="input-container">
<div class="input-performance">
<span class="input-performance-point">
<p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
<p class="megrim-regular">SEC TO FIRST TOKEN</p>
</span>
<span class="input-performance-point">
<p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
<p class="megrim-regular">TOKENS/SEC</p>
</span>
<span class="input-performance-point">
<p class="monospace" x-text="total_tokens"></p>
<p class="megrim-regular">TOKENS</p>
</span>
</div>
<div class="loading-bar" x-show="loadingMessage !== ''">
<p class="loading-text" id="loading-message">Loading:</p>
<span id="progress-percentage">0%</span>
<div class="progress-bar">
<div class="progress"></div>
</div>
</div>
<div class="input" x-show="loadingMessage === ''">
<textarea x-ref="inputForm" id="input-form" class="input-form" autofocus rows=1 x-autosize
:placeholder="generating ? placeholderText : 'Say something'" :disabled="generating" @input="
home = (home === 0) ? 1 : home
if (cstate.messages.length === 0 && $el.value === '') home = -1;
if ($el.value !== '') {
const messages = [...cstate.messages];
messages.push({ role: 'user', content: $el.value });
updateTotalTokens(messages);
} else {
if (cstate.messages.length === 0) total_tokens = 0;
else updateTotalTokens(cstate.messages);
}
" x-effect="
console.log(generating);
if (!generating) $nextTick(() => {
$el.focus();
setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
});
" @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)"></textarea>
<button class="input-button" :disabled="generating" @click="await handleSend()">
<i class="fas" :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'"></i>
</button>
</div>
</div>
</main>
</body>
</html>

View File

@@ -0,0 +1,927 @@
window.TINYCHAT_ROOT = "/tinychat-browser/";
const queryParams = new URLSearchParams(window.location.search);
const normalizedParams = Object.fromEntries([...queryParams].map(([key, value]) => [key.toUpperCase(), value.toUpperCase()]));
window.BACKEND = (normalizedParams["BACKEND"] === "WASM") ? "WASM" : "WebGPU";
const isMobileAgent = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
const hasTouchScreen = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
window.isMobile = isMobileAgent || hasTouchScreen;
if (window.isMobile) document.documentElement.classList.add('mobile'); // prevent annoying auto-zoom when entering prompt on mobile
// MODEL_BASE_URL is where the weights are hosted, WEBGPU_EXPORT is the JS-wrapped WebGPU code exported from tinygrad
window.PC_MODEL_BASE_URL = ".";
window.PC_WEBGPU_EXPORT = './net.js'
window.PC_MAX_CONTEXT = 1024;
window.MOBILE_MODEL_BASE_URL = ".";
window.MOBILE_WEBGPU_EXPORT = './net.js'
window.MOBILE_MAX_CONTEXT = 1024;
const tiktokenReady = (async () => {
const { init, get_encoding, Tiktoken, load } = await import('./tiktoken.js');
window.Tiktoken = Tiktoken;
window.tiktokenInit = init;
window.tiktokenGetEncoding = get_encoding;
window.tiktokenLoad = load;
})();
async function getDevice() {
let adapter;
try {
adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
throw new Error("No WebGPU adapter found");
}
} catch(error) {
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
throw error;
}
const requiredLimits = {};
const maxBufferSize = 322122544;
requiredLimits.maxStorageBufferBindingSize = maxBufferSize;
requiredLimits.maxBufferSize = maxBufferSize;
requiredLimits.maxComputeInvocationsPerWorkgroup = 512; // may need to vary based on what the WEBGPU backend produces
try {
return await adapter.requestDevice({ requiredLimits });
} catch(error) {
this.loadingMessage = "Loading WASM (WebGPU error):";
throw error;
}
};
// copied from examples/webgpu/stable_diffusion/index.html
function initDb() {
return new Promise((resolve, reject) => {
let db;
const request = indexedDB.open('tinydb', 1);
request.onerror = (event) => {
console.error('Database error:', event.target.error);
resolve(null);
};
request.onsuccess = (event) => {
db = event.target.result;
console.log("Db initialized.");
resolve(db);
};
request.onupgradeneeded = (event) => {
db = event.target.result;
if (!db.objectStoreNames.contains('tensors')) {
db.createObjectStore('tensors', { keyPath: 'id' });
}
};
});
}
// copied from examples/webgpu/stable_diffusion/index.html
function readTensorFromDb(db, id) {
return new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
}
const transaction = db.transaction(['tensors'], 'readonly');
const store = transaction.objectStore('tensors');
const request = store.get(id);
transaction.onabort = (event) => {
console.log("Transaction error while reading tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = (event) => {
const result = event.target.result;
if (result) {
resolve(result);
} else {
resolve(null);
}
};
request.onerror = (event) => {
console.error('Tensor retrieve failed: ', event.target.error);
resolve(null);
};
});
}
function getAllKeysFromDb(db) {
return new Promise((resolve, reject) => {
if (db == null) {resolve([]);}
const transaction = db.transaction(['tensors'], 'readonly');
const store = transaction.objectStore('tensors');
const request = store.getAllKeys();
transaction.onabort = (event) => {
console.log("Transaction error while reading IndexedDB keys: " + event.target.error);
resolve([]);
};
request.onsuccess = function (event) {resolve(event.target.result);};
request.onerror = (event) => {
console.error('Retrieval of IndexedDB keys failed: ', event.target.error);
resolve([]);
};
});
}
// modified from examples/webgpu/stable_diffusion/index.html
function saveTensorToDb(db, id, tensor) {
return readTensorFromDb(db, id).then((result) => {
if (!result) {
new Promise((resolve, reject) => {
if (db == null) {
resolve(null);
}
const transaction = db.transaction(['tensors'], 'readwrite');
const store = transaction.objectStore('tensors');
const request = store.put({ id: id, content: tensor });
transaction.onabort = (event) => {
console.log("Transaction error while saving tensor: " + event.target.error);
resolve(null);
};
request.onsuccess = () => {
console.log('Tensor saved successfully.');
resolve();
};
request.onerror = (event) => {
console.error('Tensor save failed:', event.target.error);
resolve(null);
};
});
} else {
return null;
}
}).catch(()=> null);
}
function deleteTensorFromDb(db, id) {
return new Promise((resolve, reject) => {
if (db == null) {
console.error("Database is not initialized.");
resolve(null);
return;
}
const transaction = db.transaction(['tensors'], 'readwrite');
const store = transaction.objectStore('tensors');
const request = store.delete(id);
transaction.oncomplete = () => {
console.log(`Tensor with ID '${id}' deleted successfully.`);
resolve();
};
transaction.onerror = (event) => {
console.error("Transaction error while deleting tensor:", event.target.error);
resolve(null);
};
request.onerror = (event) => {
console.error('Tensor deletion failed:', event.target.error);
resolve(null);
};
request.onsuccess = () => {
console.log(`Delete request for tensor with ID '${id}' succeeded.`);
};
});
}
function makeProgress(total) {
let acc = 0;
const ret = function progress(amount, message) {
if (amount >= 0) { // allow updating message only
acc += amount;
const percentage = total ? Math.trunc((acc / total) * 100) : 0;
document.querySelector('.progress').style.width = `${percentage}%`;
document.getElementById('progress-percentage').textContent = `${percentage}%`;
}
if (message) {
this.loadingMessage = message;
document.getElementById('loading-message').textContent = this.loadingMessage;
}
}.bind(this);
ret.total = total;
return ret;
}
function sendMessageToWorker(worker, message) {
return new Promise((resolve, reject) => {
const onMessage = (event) => {
resolve(event.data);
worker.removeEventListener('message', onMessage);
worker.removeEventListener('error', onError);
};
const onError = (error) => {
reject(error);
worker.removeEventListener('message', onMessage);
worker.removeEventListener('error', onError);
};
worker.addEventListener('message', onMessage);
worker.addEventListener('error', onError);
if (message.header === "token") worker.postMessage(message.data);
else if (message.header === "load_state_dict") {
if (message.data === "done") worker.postMessage(message.data);
else worker.postMessage(message.data, message.data.map(file => file.bytes.buffer));
}
else if (message.header === "init") worker.postMessage("init");
});
}
async function load_state_dict (data, device, progress) {
let state_dict = data.metadata.state_dict;
let completed = 0;
// modified from examples/webgpu/stable_diffusion/index.html getProgressDlForPart
const loadPart = async (part) => {
const response = await fetch(part);
const res = new Response(new ReadableStream({
async start(controller) {
const reader = response.body.getReader();
for (;;) {
const { done, value } = await reader.read();
if (done) break;
progress(value.byteLength);
controller.enqueue(value);
}
controller.close();
},
}));
return res.arrayBuffer();
};
let db = await initDb();
const getPart = async(filename, hash) => {
let part = await readTensorFromDb(db, hash);
if (part) {
console.log(`Cache hit: ${filename}, hash: ${hash}`);
progress(part.content.byteLength);
return Promise.resolve(part.content);
} else {
console.log(`Cache miss: ${filename}, hash: ${hash}`);
return loadPart(`${window.MODEL_BASE_URL}/${filename}`);
}
}
const correctHashes = data.metadata.files.map(file => file.hash)
// delete unused cached buffers to free disk space -- if we update weights, user will otherwise have obsolete cached buffers
const dbKeys = await getAllKeysFromDb(db);
const correctHashesSet = new Set(correctHashes);
const notInCorrectHashes = dbKeys.filter(key => !correctHashesSet.has(key));
// await these right before starting to save new stuff
const deletionPromises = notInCorrectHashes.map(async (hash) => deleteTensorFromDb(db, hash));
// instantiates empty weight buffers on WebGPU, attaches buffers to state_dict
let model;
if (window.BACKEND === "WebGPU") {
//model = await transformer().setup(device, state_dict, progress);
model = await transformer.setupNet(device, state_dict);
progress(0.15 * progress.total);
}
else if (window.BACKEND === "WASM") {
progress(0.02 * progress.total);
model = new Worker(`./worker.js?version=${Date.now()}`);
await sendMessageToWorker(model, {header: "init"});
progress(0.02 * progress.total);
progress(0.11 * progress.total);
}
const downloaded = [];
const triggerChainDownload = async (toDownload) => {
const numDownloaders = window.isMobile ? 4 : toDownload.length; // TODO: dynamically base this on DL file size? current assumption is 16 MiB chunks
const chainDownload = async() => {
const file = toDownload.shift();
loadPart(`${window.MODEL_BASE_URL}/${file.name}`) // triggers download
.then(async (arraybuf) => {
downloaded.push({ ...file, bytes: new Uint8Array(arraybuf)});
// pause downloads if further processing is a bottleneck
while (toDownload.length && downloaded.length >= numDownloaders) await new Promise(resolve => setTimeout(resolve, 5));
if (toDownload.length && downloaded.length < numDownloaders) chainDownload(); // start next download
})
}
for (let i=0; i<numDownloaders; i++) if (toDownload.length) chainDownload();
}
const loadFileToStateDict = async(file) => {
if (window.BACKEND === "WebGPU") {
for (const part of file.parts) {
if (part.empty) continue;
part.bytes = (part.size === file.bytes.length) ? file.bytes : file.bytes.slice(part.file_start_pos, part.file_start_pos + part.size);
device.queue.writeBuffer(state_dict[part.key].bytes, part.target_start_pos, part.bytes); // improves stability over mappedAtCreation writing
part.bytes = null;
}
}
else if (window.BACKEND === "WASM") {
await sendMessageToWorker(model, {header: "load_state_dict", data: [file]});
}
file.bytes = null;
}
if (window.BACKEND === "WebGPU") { // contiguous loading not needed for WebGPU stability
const files = data.tensor_file_groups.flatMap(obj => obj.files);
data.tensor_file_groups = [{contiguous: false, files: files}];
}
for (const group of data.tensor_file_groups) {
const contiguous = group.contiguous;
const files = group.files;
const tensor_file_indices = files.map(file => file.index);
const contiguousFiles = [];
const fileHashes = new Set(files.map(file => file.hash));
const cachedFileHashes = new Set(dbKeys.filter(key => fileHashes.has(key)));
const cachedFiles = files.filter(file => cachedFileHashes.has(file.hash));
const toDownload = files.filter(file => !cachedFileHashes.has(file.hash));
triggerChainDownload(toDownload);
const loadDelay = 5;
await Promise.all(deletionPromises);
while (completed < files.length) {
const start = performance.now();
// prioritize files from downloaded queue, so we can continue downloading more files
if (downloaded.length) {
const file = downloaded.shift();
await saveTensorToDb(db, file.hash, file.bytes); // for wasm, must await to prevent race between indexedDB and transfer to worker
if (!contiguous) await loadFileToStateDict(file);
else contiguousFiles.push(file);
completed += 1;
}
else if (!downloaded.length && cachedFiles.length) {
const file = cachedFiles.shift();
file.bytes = await getPart(file.name, file.hash); // reads data from IndexedDB
if (!contiguous) await loadFileToStateDict(file);
else contiguousFiles.push(file);
completed += 1;
}
const end = performance.now();
const elapsed = end - start;
if (elapsed < loadDelay) await new Promise(resolve => setTimeout(resolve, loadDelay - elapsed));
}
if (contiguous) {
const orderMap = tensor_file_indices.reduce((acc, id, index) => {acc[id] = index; return acc;}, {});
contiguousFiles.sort((a, b) => orderMap[a.index] - orderMap[b.index]); // glue files together in the right order
await sendMessageToWorker(model, {header: "load_state_dict", data: contiguousFiles});
}
completed = 0;
}
// initialize empty kv_caches, which were part of exported model's state_dict, but which we didn't want to package/download
if (window.BACKEND === "WASM") {
for (const [k, v] of Object.entries(state_dict).filter(([_, v]) => v.empty === true)) {
v.parts[0].file_start_pos = 0;
const file = { parts: v.parts, size: v.size, bytes: new Uint8Array(v.size).fill(0) };
await loadFileToStateDict(file);
}
}
return model;
};
document.addEventListener("alpine:init", () => {
Alpine.data("state", () => ({
// loadingMessage updates the user on page load progress, including weights download and decompression
// if loadingMessage is not '', then prompt box will be hidden: this is default behavior on page load
placeholderText: "Generating...",
loadingMessage: `Loading ${window.BACKEND} model:`,
// model
nets: {},
tokenizer: null,
max_context: 1024,
lastSeenToks: [],
progress: null,
async init() {
var device = null;
var webgpuErrorMessage = null;
if (window.BACKEND === "WebGPU") {
try {
device = await getDevice.call(this);
console.log("WebGPU device initialized");
} catch (error) {
window.BACKEND = "WASM";
console.log(`error: ${error}\nFailed to launch WebGPU. Loading WASM model instead...`); // return;
webgpuErrorMessage = this.loadingMessage;
}
}
window.MODEL_BASE_URL = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MODEL_BASE_URL : window.MOBILE_MODEL_BASE_URL;
this.max_context = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MAX_CONTEXT : window.MOBILE_MAX_CONTEXT;
const kernelsReady = (async () => {
if (window.BACKEND === "WASM") {var exports = await import(`./net_clang.js?version=${Date.now()}`);}
else if (window.BACKEND === "WebGPU" && !window.isMobile) {var exports = await import(`${PC_WEBGPU_EXPORT}?version=${Date.now()}`);}
else if (window.BACKEND === "WebGPU" && window.isMobile) {var exports = await import(`${MOBILE_WEBGPU_EXPORT}?version=${Date.now()}`);}
self.transformer = exports.default;
})();
const response = await fetch(`${window.MODEL_BASE_URL}/net_metadata.json`);
// TODO: cache metadata (and everything else, including tokenizer)
// TODO: use service worker to reload page when offline
const data = await response.json();
data.metadata.files = data.metadata.files.map((file, index) => ({...file, index}));
const state_dict = data.metadata.state_dict;
/*
- allocating memory to WASM on mobile has longstanding issues: https://github.com/WebAssembly/design/issues/1397
- the below pattern, while yielding a succesfully-functioning model when it doesn't crash, causes regular crashes on iOS Safari (iphone 15 iOS 18.3):
- call WASM malloc (to fit all tensors, or one per tensor) for all tensors up front, then load tensor byte chunks into the buffers in random order
- the below pattern has been stable on iOS Safari (iphone 15 iOS 18.3):
- call only one WASM malloc at a time before filling the allocated bytes, as small as possible (malloc up to 256 MiB has been tested)
- fill the malloc'd memory in linear order from start to end (what has been tested is calling wasm.HEAPU8.set on 16 MiB chunks from start to end)
- use ALLOW_MEMORY_GROWTH=1 in wasm compilation, minimize initial memory
- additional considerations affecting loading design, for WASM:
- it seems that copying bytes into wasm memory cannot be zero-copy without sharedarraybuffer, which isn't currently used due to increased hosting complexity
- non-zero copies create memory pressure, which is not reliably capped because of lack of control over garbage collection
- to minimize peak memory pressure if GC is delayed, we process (i.e. download + copy into WASM) large tensors (> 16 MiB) one at a time, in descending size order
*/
data.tensor_file_groups = []; // see above: for WASM, limit processing of multi-file Tensors to one at a time, in descending order based on Tensor size
const unsplit_tensors = [];
const sortedEntries = Object.entries(state_dict).sort(([, objA], [, objB]) => objB.size - objA.size);
let totalSize = 0;
const seen = new Set();
for (const [k,v] of sortedEntries) {
const files_in_tensor = [];
for (const part of v.parts) {
part.key = k;
if (part.empty) state_dict[k].empty = true; // assumes no other parts of this weight exist and are non-empty
else {
const file = data.metadata.files[part.file];
if (!seen.has(file.index)) {
seen.add(file.index);
files_in_tensor.push(file);
}
totalSize += part.size;
part.dtype = v.dtype;
if (!data.metadata.files[part.file].parts) data.metadata.files[part.file].parts = [];
data.metadata.files[part.file].size ??= 0;
data.metadata.files[part.file].size += part.size;
data.metadata.files[part.file].parts.push(part);
}
}
if (files_in_tensor.length > 1) data.tensor_file_groups.push({contiguous: true, files: files_in_tensor}); // [tensorN_file0, tensorN_file1, ...]
else if (files_in_tensor.length > 0) unsplit_tensors.push(files_in_tensor[0]);
}
data.tensor_file_groups.push({contiguous: false, files: unsplit_tensors});
data.totalSize = totalSize;
totalSize = totalSize / 0.8; // give space in progress bar for initializing model bufs, and tokenizer
this.progress = makeProgress.call(this, totalSize); // creates closure with totalSize
try {
this.progress(0.01 * totalSize, "Loading tokenizer:");
const wasmResponse = await fetch(`${window.MODEL_BASE_URL}/tiktoken_bg.wasm`);
this.progress(0.01 * totalSize);
const wasmBytes = await wasmResponse.arrayBuffer();
await tiktokenReady;
await window.tiktokenInit((imports) => WebAssembly.instantiate(wasmBytes, imports));
this.progress(0.01 * totalSize);
this.tokenizer = await createTokenizer(`${window.MODEL_BASE_URL}/llama3-2.tiktoken`);
const tokenizer_works = (new TextDecoder().decode(this.tokenizer.decode(this.tokenizer.encode("hello world"))) === "hello world");
console.log("tokenizer works:", tokenizer_works)
this.progress(0.01 * totalSize);
} catch (error) {this.progress(-1, `Error launching tokenizer: ${error}`); console.log(error); return;}
try {
const loadModelMessage = (webgpuErrorMessage) ? webgpuErrorMessage : `Loading ${window.BACKEND} model:`
this.progress(0, loadModelMessage);
await kernelsReady;
const model = await load_state_dict(data, device, this.progress);
if (window.BACKEND === "WebGPU") {
this.nets = {"transformer": model};
}
else if (window.BACKEND === "WASM") {
const msg = await sendMessageToWorker(model, {header: "load_state_dict", data: "done"});
this.nets = {"transformer": async (tok, start_pos) => sendMessageToWorker(model, {header: "token", data: [tok, start_pos]})};
}
this.progress(0.01 * totalSize, `Launching ${window.BACKEND} model:`);
this.loadingMessage = ""; // Triggers removal of loading bar, display of prompt box
} catch (error) {this.progress(-1, `Error launching model: ${error}`); console.log(error); return;}
},
// current state
cstate: {
time: null,
messages: [],
},
// historical state
histories: JSON.parse(localStorage.getItem("histories")) || [],
home: 0,
generating: false,
maxContextReached: false,
cancelGeneration: false,
endpoint: `${window.location.origin}/v1`,
// performance tracking
time_till_first: 0,
tokens_per_second: 0,
total_tokens: 0,
max_context: 0,
removeHistory(cstate) {
const index = this.histories.findIndex((state) => {
return state.time === cstate.time;
});
if (index !== -1) {
this.histories.splice(index, 1);
localStorage.setItem("histories", JSON.stringify(this.histories));
}
},
async handleSend() {
const el = document.getElementById("input-form");
const value = el.value.trim();
if (!value) return;
if (this.generating) return;
this.maxContextReached = false;
this.placeholderText = "Generating...";
this.generating = true;
this.cancelGeneration = false;
if (this.home === 0) this.home = 1;
// ensure that going back in history will go back to home
window.history.pushState({}, "", window.TINYCHAT_ROOT || "/");
// add message to list
this.cstate.messages.push({ role: "user", content: value });
// clear textarea
el.value = "";
el.style.height = "auto";
el.style.height = el.scrollHeight + "px";
// reset performance tracking
const prefill_start = Date.now();
let start_time = 0;
let tokens = 0;
this.tokens_per_second = 0;
let gottenFirstChunk = false;
try {
for await (
const chunk of this.openaiChatCompletion(this.cstate.messages)
) {
if (!gottenFirstChunk) {
this.cstate.messages.push({ role: "assistant", content: "" });
gottenFirstChunk = true;
}
// add chunk to the last message
// TODO: handle localStorage overflow
// possible example: this.cstate.messages[...] was undefined when trying to prompt within an old cstate (chat session)
this.cstate.messages[this.cstate.messages.length - 1].content += chunk;
// calculate performance tracking
tokens += 1;
this.total_tokens += 1;
if (start_time === 0) {
start_time = Date.now();
this.time_till_first = start_time - prefill_start;
} else {
const diff = Date.now() - start_time;
if (diff > 0) {
this.tokens_per_second = tokens / (diff / 1000);
}
}
this.checkMaxContext(this.total_tokens);
if (this.cancelGeneration) break;
}
} finally {
// update the state in histories or add it if it doesn't exist
const index = this.histories.findIndex((cstate) => {
return cstate.time === this.cstate.time;
});
this.cstate.time = Date.now();
if (index !== -1) {
// update the time
this.histories[index] = this.cstate;
} else {
this.histories.push(this.cstate);
}
// update in local storage
localStorage.setItem("histories", JSON.stringify(this.histories));
if (!this.maxContextReached) this.generating = false;
if (this.cancelGeneration && !this.maxContextReached) this.cstate = { time: null, messages: [] };
}
},
async handleEnter(event) {
// if shift is not pressed
if (!event.shiftKey) {
event.preventDefault();
await this.handleSend();
}
},
updateTotalTokens(messages) {
try {
let toks = [this.tokenizer.bos_id];
messages.forEach((message) => {
if (!message.role || !message.content) {
throw new Error("Each message must have a 'role' and 'content' property.");
}
toks = toks.concat(this.tokenizer.encodeMessage(message.role, message.content));
if (messages.length > 0 && messages[messages.length - 1].role === "user") {
toks = toks.concat(this.tokenizer.encodeRole("assistant"));
}
this.total_tokens = toks.length;
});
} catch (error) {
console.error("Error updating total tokens:", error);
}
},
checkMaxContext(num_tokens) {
if (num_tokens >= this.max_context) {
this.cancelGeneration = true;
this.maxContextReached = true;
this.placeholderText = `Max context reached: ${this.max_context} tokens`;
}
},
async *openaiChatCompletion(messages) {
let tokens = [this.tokenizer.bos_id];
for (const message of messages) {
tokens = tokens.concat(this.tokenizer.encodeMessage(message.role, message.content));
}
tokens = tokens.concat(this.tokenizer.encodeRole("assistant"));
this.checkMaxContext(tokens.length); // don't waste time prefilling if we know we're over the token limit
let startPos = 0
const prefillToks = tokens.slice(0, -1);
// Skip the largest possible sequence of tokens already represented at the beginning of the model's kv caches
for (let i=0; i <= prefillToks.length; i++) {
startPos = i;
if (i == prefillToks.length) break;
if (i == this.lastSeenToks.length) break;
if (prefillToks[i] !== this.lastSeenToks[i]) break;
}
//this.lastSeenToks = prefillToks;
//prefillToks = prefillToks.slice(startPos);
const unprocessedPrefillToks = prefillToks.slice(startPos);
this.lastSeenToks = prefillToks.slice(0, startPos);
this.progress = makeProgress(unprocessedPrefillToks.length);
this.loadingMessage = (window.BACKEND === "WebGPU") ? "Reading input:" : "Loading (enable WebGPU for speed):";
this.progress(0, this.loadingMessage);
for (const tok of unprocessedPrefillToks) {
if (this.cancelGeneration) {this.loadingMessage=""; return;}
if (window.BACKEND === "WebGPU") {await this.nets["transformer"](new Int32Array([tok]), new Int32Array([startPos]));}
else {await this.nets["transformer"](tok, startPos);}
this.lastSeenToks.push(tok)
startPos += 1;
this.progress(1);
}
this.loadingMessage = ""; // hides progress bar
let lastTok = tokens[tokens.length - 1];
while (true) {
if (window.BACKEND === "WebGPU") {var tok = await this.nets["transformer"](new Int32Array([lastTok]), new Int32Array([startPos])); tok = tok[0][0];}
else {var tok = await this.nets["transformer"](lastTok, startPos);}
this.lastSeenToks.push(lastTok); // lets us skip prefilling with these tokens at the next prompt in this chain
startPos += 1;
lastTok = tok;
if (this.tokenizer.stop_tokens.has(lastTok)) break;
yield new TextDecoder().decode(this.tokenizer.decode([lastTok]));
}
},
}));
});
const { markedHighlight } = globalThis.markedHighlight;
marked.use(markedHighlight({
langPrefix: "hljs language-",
highlight(code, lang, _info) {
const language = hljs.getLanguage(lang) ? lang : "plaintext";
return hljs.highlight(code, { language }).value;
},
}));
// **** eventsource-parser ****
class EventSourceParserStream extends TransformStream {
constructor() {
let parser;
super({
start(controller) {
parser = createParser((event) => {
if (event.type === "event") {
controller.enqueue(event);
}
});
},
transform(chunk) {
parser.feed(chunk);
},
});
}
}
function createParser(onParse) {
let isFirstChunk;
let buffer;
let startingPosition;
let startingFieldLength;
let eventId;
let eventName;
let data;
reset();
return {
feed,
reset,
};
function reset() {
isFirstChunk = true;
buffer = "";
startingPosition = 0;
startingFieldLength = -1;
eventId = void 0;
eventName = void 0;
data = "";
}
function feed(chunk) {
buffer = buffer ? buffer + chunk : chunk;
if (isFirstChunk && hasBom(buffer)) {
buffer = buffer.slice(BOM.length);
}
isFirstChunk = false;
const length = buffer.length;
let position = 0;
let discardTrailingNewline = false;
while (position < length) {
if (discardTrailingNewline) {
if (buffer[position] === "\n") {
++position;
}
discardTrailingNewline = false;
}
let lineLength = -1;
let fieldLength = startingFieldLength;
let character;
for (
let index = startingPosition;
lineLength < 0 && index < length;
++index
) {
character = buffer[index];
if (character === ":" && fieldLength < 0) {
fieldLength = index - position;
} else if (character === "\r") {
discardTrailingNewline = true;
lineLength = index - position;
} else if (character === "\n") {
lineLength = index - position;
}
}
if (lineLength < 0) {
startingPosition = length - position;
startingFieldLength = fieldLength;
break;
} else {
startingPosition = 0;
startingFieldLength = -1;
}
parseEventStreamLine(buffer, position, fieldLength, lineLength);
position += lineLength + 1;
}
if (position === length) {
buffer = "";
} else if (position > 0) {
buffer = buffer.slice(position);
}
}
function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
if (lineLength === 0) {
if (data.length > 0) {
onParse({
type: "event",
id: eventId,
event: eventName || void 0,
data: data.slice(0, -1),
// remove trailing newline
});
data = "";
eventId = void 0;
}
eventName = void 0;
return;
}
const noValue = fieldLength < 0;
const field = lineBuffer.slice(
index,
index + (noValue ? lineLength : fieldLength),
);
let step = 0;
if (noValue) {
step = lineLength;
} else if (lineBuffer[index + fieldLength + 1] === " ") {
step = fieldLength + 2;
} else {
step = fieldLength + 1;
}
const position = index + step;
const valueLength = lineLength - step;
const value = lineBuffer.slice(position, position + valueLength).toString();
if (field === "data") {
data += value ? "".concat(value, "\n") : "\n";
} else if (field === "event") {
eventName = value;
} else if (field === "id" && !value.includes("\0")) {
eventId = value;
} else if (field === "retry") {
const retry = parseInt(value, 10);
if (!Number.isNaN(retry)) {
onParse({
type: "reconnect-interval",
value: retry,
});
}
}
}
}
const BOM = [239, 187, 191];
function hasBom(buffer) {
return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
}
const PAT_STR = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
async function createTokenizer(bpeUrl) {
const num_base_tokens = 128000;
const special_tokens = {
"<|begin_of_text|>": 128000,
"<|end_of_text|>": 128001,
"<|start_header_id|>": 128006,
"<|end_header_id|>": 128007,
"<|eot_id|>": 128009
};
const model = await window.tiktokenLoad({
"load_tiktoken_bpe": bpeUrl,
"special_tokens": special_tokens,
"pat_str": PAT_STR
});
const tokenizer = new window.Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str)
return {
get bos_id() {
return special_tokens["<|begin_of_text|>"];
},
get stop_tokens() {
return new Set([
special_tokens["<|end_of_text|>"],
special_tokens["<|eot_id|>"],
]);
},
decode(toks) {
const filtered = toks.filter((t) => t < num_base_tokens);
return tokenizer.decode(filtered);
},
encode(text, allow_special = false) {
const allowedSpecial = allow_special ? "all" : new Set();
const disallowedSpecial = new Set();
return tokenizer.encode(text, allowedSpecial, disallowedSpecial);
},
encodeRole(role) {
const tokens = [];
tokens.push(special_tokens["<|start_header_id|>"]);
tokens.push(...this.encode(role));
tokens.push(special_tokens["<|end_header_id|>"]);
tokens.push(...this.encode("\n\n"));
return tokens;
},
encodeMessage(role, content) {
const roleTokens = this.encodeRole(role);
const contentTokens = this.encode(content.trim());
return [...roleTokens, ...contentTokens, special_tokens["<|eot_id|>"]];
},
};
}

View File

@@ -0,0 +1,11 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"
npm init -y && \
npm install --save-dev webpack webpack-cli && \
npm install tiktoken && \
jq '.scripts.build = "webpack"' package.json > package.tmp.json && \
mv package.tmp.json package.json && \
npm run build && \
mv dist/*.wasm ./tiktoken_bg.wasm && \
mv dist/* ./ && \
rm -rf dist node_modules package-lock.json package.json

View File

@@ -0,0 +1,5 @@
// Force Webpack to copy the WASM
import 'tiktoken/tiktoken_bg.wasm';
import { init, get_encoding, encoding_for_model, Tiktoken } from 'tiktoken/init';
import { load } from 'tiktoken/load';
export { init, get_encoding, encoding_for_model, Tiktoken, load };

View File

@@ -0,0 +1,25 @@
const path = require("path");
module.exports = {
mode: "production",
entry: "./tiktoken-export.js",
output: {
filename: "tiktoken.js",
path: path.resolve(__dirname, "dist"),
library: {
type: "module"
}
},
experiments: {
outputModule: true,
asyncWebAssembly: true
},
module: {
rules: [
{
test: /\.wasm$/,
type: "asset/resource",
}
]
}
};

View File

@@ -0,0 +1,62 @@
const kernelsReady = (async () => {
// can't get browser to use updated versions except with cache-busting query string
const exports = await import(`./net_clang.js?version=${Date.now()}`);
Object.assign(self, exports);
})();
async function init(event) {
await kernelsReady;
self.model = await self.transformer();
self.addEventListener("message", loadStateDict);
self.removeEventListener("message", init);
self.postMessage("success");
}
function loadStateDict(event) {
if (event.data === "done") {
self.addEventListener("message", inference);
self.removeEventListener("message", loadStateDict);
}
else {
if (event.data.length > 1) {
// the bytes from files are set contiguously in WASM memory
const malloc_size = event.data.reduce((sum, file) => sum + file.bytes.length, 0);
const malloc_ptr = self.model.wasm._malloc(malloc_size);
let cursor = 0;
for (const file of event.data) {
self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr + cursor);
for (const part of file.parts) {
if (part.target_start_pos === 0) {
// tell WASM code where the tensor is in memory
self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + cursor);
}
cursor += part.size;
}
file.bytes = null;
}
}
else {
// the bytes from files are not guaranteed to be set contiguously in WASM memory
const file = event.data[0];
const malloc_ptr = self.model.wasm._malloc(file.size);
self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr);
for (const part of file.parts) {
if (part.target_start_pos === 0) {
self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + part.file_start_pos);
}
}
file.bytes = null;
}
}
self.postMessage("success");
}
function inference(event) {
const [tok, start_pos] = event.data;
const int32tok = new Int32Array([tok]);
const model_out = self.model.run(new Uint8Array(int32tok.buffer), start_pos);
const int32nextTok = new Int32Array(model_out[0].buffer);
self.postMessage(int32nextTok[0]);
}
self.addEventListener("message", init);

View File

@@ -0,0 +1,209 @@
/*
* Copyright 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE COPYRIGHT HOLDER(S) OR AUTHOR(S) BE LIABLE FOR ANY CLAIM, DAMAGES OR
* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*
*/
#ifndef _hdp_6_0_0_OFFSET_HEADER
#define _hdp_6_0_0_OFFSET_HEADER
// addressBlock: hdp_hdpdec
// base address: 0x3c80
#define regHDP_NONSURFACE_BASE 0x0040
#define regHDP_NONSURFACE_BASE_BASE_IDX 0
#define regHDP_NONSURFACE_INFO 0x0041
#define regHDP_NONSURFACE_INFO_BASE_IDX 0
#define regHDP_NONSURFACE_BASE_HI 0x0042
#define regHDP_NONSURFACE_BASE_HI_BASE_IDX 0
#define regHDP_SURFACE_WRITE_FLAGS 0x00c4
#define regHDP_SURFACE_WRITE_FLAGS_BASE_IDX 0
#define regHDP_SURFACE_READ_FLAGS 0x00c5
#define regHDP_SURFACE_READ_FLAGS_BASE_IDX 0
#define regHDP_SURFACE_WRITE_FLAGS_CLR 0x00c6
#define regHDP_SURFACE_WRITE_FLAGS_CLR_BASE_IDX 0
#define regHDP_SURFACE_READ_FLAGS_CLR 0x00c7
#define regHDP_SURFACE_READ_FLAGS_CLR_BASE_IDX 0
#define regHDP_NONSURF_FLAGS 0x00c8
#define regHDP_NONSURF_FLAGS_BASE_IDX 0
#define regHDP_NONSURF_FLAGS_CLR 0x00c9
#define regHDP_NONSURF_FLAGS_CLR_BASE_IDX 0
#define regHDP_HOST_PATH_CNTL 0x00cc
#define regHDP_HOST_PATH_CNTL_BASE_IDX 0
#define regHDP_SW_SEMAPHORE 0x00cd
#define regHDP_SW_SEMAPHORE_BASE_IDX 0
#define regHDP_DEBUG0 0x00ce
#define regHDP_DEBUG0_BASE_IDX 0
#define regHDP_LAST_SURFACE_HIT 0x00d0
#define regHDP_LAST_SURFACE_HIT_BASE_IDX 0
#define regHDP_OUTSTANDING_REQ 0x00d2
#define regHDP_OUTSTANDING_REQ_BASE_IDX 0
#define regHDP_MISC_CNTL 0x00d3
#define regHDP_MISC_CNTL_BASE_IDX 0
#define regHDP_MEM_POWER_CTRL 0x00d4
#define regHDP_MEM_POWER_CTRL_BASE_IDX 0
#define regHDP_MMHUB_CNTL 0x00d5
#define regHDP_MMHUB_CNTL_BASE_IDX 0
#define regHDP_VERSION 0x00d7
#define regHDP_VERSION_BASE_IDX 0
#define regHDP_CLK_CNTL 0x00d8
#define regHDP_CLK_CNTL_BASE_IDX 0
#define regHDP_MEMIO_CNTL 0x00f6
#define regHDP_MEMIO_CNTL_BASE_IDX 0
#define regHDP_MEMIO_ADDR 0x00f7
#define regHDP_MEMIO_ADDR_BASE_IDX 0
#define regHDP_MEMIO_STATUS 0x00f8
#define regHDP_MEMIO_STATUS_BASE_IDX 0
#define regHDP_MEMIO_WR_DATA 0x00f9
#define regHDP_MEMIO_WR_DATA_BASE_IDX 0
#define regHDP_MEMIO_RD_DATA 0x00fa
#define regHDP_MEMIO_RD_DATA_BASE_IDX 0
#define regHDP_XDP_DIRECT2HDP_FIRST 0x0100
#define regHDP_XDP_DIRECT2HDP_FIRST_BASE_IDX 0
#define regHDP_XDP_D2H_FLUSH 0x0101
#define regHDP_XDP_D2H_FLUSH_BASE_IDX 0
#define regHDP_XDP_D2H_BAR_UPDATE 0x0102
#define regHDP_XDP_D2H_BAR_UPDATE_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_3 0x0103
#define regHDP_XDP_D2H_RSVD_3_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_4 0x0104
#define regHDP_XDP_D2H_RSVD_4_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_5 0x0105
#define regHDP_XDP_D2H_RSVD_5_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_6 0x0106
#define regHDP_XDP_D2H_RSVD_6_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_7 0x0107
#define regHDP_XDP_D2H_RSVD_7_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_8 0x0108
#define regHDP_XDP_D2H_RSVD_8_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_9 0x0109
#define regHDP_XDP_D2H_RSVD_9_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_10 0x010a
#define regHDP_XDP_D2H_RSVD_10_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_11 0x010b
#define regHDP_XDP_D2H_RSVD_11_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_12 0x010c
#define regHDP_XDP_D2H_RSVD_12_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_13 0x010d
#define regHDP_XDP_D2H_RSVD_13_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_14 0x010e
#define regHDP_XDP_D2H_RSVD_14_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_15 0x010f
#define regHDP_XDP_D2H_RSVD_15_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_16 0x0110
#define regHDP_XDP_D2H_RSVD_16_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_17 0x0111
#define regHDP_XDP_D2H_RSVD_17_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_18 0x0112
#define regHDP_XDP_D2H_RSVD_18_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_19 0x0113
#define regHDP_XDP_D2H_RSVD_19_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_20 0x0114
#define regHDP_XDP_D2H_RSVD_20_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_21 0x0115
#define regHDP_XDP_D2H_RSVD_21_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_22 0x0116
#define regHDP_XDP_D2H_RSVD_22_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_23 0x0117
#define regHDP_XDP_D2H_RSVD_23_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_24 0x0118
#define regHDP_XDP_D2H_RSVD_24_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_25 0x0119
#define regHDP_XDP_D2H_RSVD_25_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_26 0x011a
#define regHDP_XDP_D2H_RSVD_26_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_27 0x011b
#define regHDP_XDP_D2H_RSVD_27_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_28 0x011c
#define regHDP_XDP_D2H_RSVD_28_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_29 0x011d
#define regHDP_XDP_D2H_RSVD_29_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_30 0x011e
#define regHDP_XDP_D2H_RSVD_30_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_31 0x011f
#define regHDP_XDP_D2H_RSVD_31_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_32 0x0120
#define regHDP_XDP_D2H_RSVD_32_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_33 0x0121
#define regHDP_XDP_D2H_RSVD_33_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_34 0x0122
#define regHDP_XDP_D2H_RSVD_34_BASE_IDX 0
#define regHDP_XDP_DIRECT2HDP_LAST 0x0123
#define regHDP_XDP_DIRECT2HDP_LAST_BASE_IDX 0
#define regHDP_XDP_P2P_BAR_CFG 0x0124
#define regHDP_XDP_P2P_BAR_CFG_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_OFFSET 0x0125
#define regHDP_XDP_P2P_MBX_OFFSET_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR0 0x0126
#define regHDP_XDP_P2P_MBX_ADDR0_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR1 0x0127
#define regHDP_XDP_P2P_MBX_ADDR1_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR2 0x0128
#define regHDP_XDP_P2P_MBX_ADDR2_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR3 0x0129
#define regHDP_XDP_P2P_MBX_ADDR3_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR4 0x012a
#define regHDP_XDP_P2P_MBX_ADDR4_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR5 0x012b
#define regHDP_XDP_P2P_MBX_ADDR5_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR6 0x012c
#define regHDP_XDP_P2P_MBX_ADDR6_BASE_IDX 0
#define regHDP_XDP_HDP_MBX_MC_CFG 0x012d
#define regHDP_XDP_HDP_MBX_MC_CFG_BASE_IDX 0
#define regHDP_XDP_HDP_MC_CFG 0x012e
#define regHDP_XDP_HDP_MC_CFG_BASE_IDX 0
#define regHDP_XDP_HST_CFG 0x012f
#define regHDP_XDP_HST_CFG_BASE_IDX 0
#define regHDP_XDP_HDP_IPH_CFG 0x0131
#define regHDP_XDP_HDP_IPH_CFG_BASE_IDX 0
#define regHDP_XDP_P2P_BAR0 0x0134
#define regHDP_XDP_P2P_BAR0_BASE_IDX 0
#define regHDP_XDP_P2P_BAR1 0x0135
#define regHDP_XDP_P2P_BAR1_BASE_IDX 0
#define regHDP_XDP_P2P_BAR2 0x0136
#define regHDP_XDP_P2P_BAR2_BASE_IDX 0
#define regHDP_XDP_P2P_BAR3 0x0137
#define regHDP_XDP_P2P_BAR3_BASE_IDX 0
#define regHDP_XDP_P2P_BAR4 0x0138
#define regHDP_XDP_P2P_BAR4_BASE_IDX 0
#define regHDP_XDP_P2P_BAR5 0x0139
#define regHDP_XDP_P2P_BAR5_BASE_IDX 0
#define regHDP_XDP_P2P_BAR6 0x013a
#define regHDP_XDP_P2P_BAR6_BASE_IDX 0
#define regHDP_XDP_P2P_BAR7 0x013b
#define regHDP_XDP_P2P_BAR7_BASE_IDX 0
#define regHDP_XDP_FLUSH_ARMED_STS 0x013c
#define regHDP_XDP_FLUSH_ARMED_STS_BASE_IDX 0
#define regHDP_XDP_FLUSH_CNTR0_STS 0x013d
#define regHDP_XDP_FLUSH_CNTR0_STS_BASE_IDX 0
#define regHDP_XDP_BUSY_STS 0x013e
#define regHDP_XDP_BUSY_STS_BASE_IDX 0
#define regHDP_XDP_STICKY 0x013f
#define regHDP_XDP_STICKY_BASE_IDX 0
#define regHDP_XDP_CHKN 0x0140
#define regHDP_XDP_CHKN_BASE_IDX 0
#define regHDP_XDP_BARS_ADDR_39_36 0x0144
#define regHDP_XDP_BARS_ADDR_39_36_BASE_IDX 0
#define regHDP_XDP_MC_VM_FB_LOCATION_BASE 0x0145
#define regHDP_XDP_MC_VM_FB_LOCATION_BASE_BASE_IDX 0
#define regHDP_XDP_MMHUB_ERROR 0x014a
#define regHDP_XDP_MMHUB_ERROR_BASE_IDX 0
#endif

View File

@@ -0,0 +1,646 @@
/*
* Copyright 2021 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE COPYRIGHT HOLDER(S) OR AUTHOR(S) BE LIABLE FOR ANY CLAIM, DAMAGES OR
* OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
* ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*
*/
#ifndef _hdp_6_0_0_SH_MASK_HEADER
#define _hdp_6_0_0_SH_MASK_HEADER
// addressBlock: hdp_hdpdec
//HDP_NONSURFACE_BASE
#define HDP_NONSURFACE_BASE__NONSURF_BASE_39_8__SHIFT 0x0
#define HDP_NONSURFACE_BASE__NONSURF_BASE_39_8_MASK 0xFFFFFFFFL
//HDP_NONSURFACE_INFO
#define HDP_NONSURFACE_INFO__NONSURF_SWAP__SHIFT 0x4
#define HDP_NONSURFACE_INFO__NONSURF_VMID__SHIFT 0x8
#define HDP_NONSURFACE_INFO__NONSURF_SWAP_MASK 0x00000030L
#define HDP_NONSURFACE_INFO__NONSURF_VMID_MASK 0x00000F00L
//HDP_NONSURFACE_BASE_HI
#define HDP_NONSURFACE_BASE_HI__NONSURF_BASE_47_40__SHIFT 0x0
#define HDP_NONSURFACE_BASE_HI__NONSURF_BASE_47_40_MASK 0x000000FFL
//HDP_SURFACE_WRITE_FLAGS
#define HDP_SURFACE_WRITE_FLAGS__SURF0_WRITE_FLAG__SHIFT 0x0
#define HDP_SURFACE_WRITE_FLAGS__SURF1_WRITE_FLAG__SHIFT 0x1
#define HDP_SURFACE_WRITE_FLAGS__SURF0_WRITE_FLAG_MASK 0x00000001L
#define HDP_SURFACE_WRITE_FLAGS__SURF1_WRITE_FLAG_MASK 0x00000002L
//HDP_SURFACE_READ_FLAGS
#define HDP_SURFACE_READ_FLAGS__SURF0_READ_FLAG__SHIFT 0x0
#define HDP_SURFACE_READ_FLAGS__SURF1_READ_FLAG__SHIFT 0x1
#define HDP_SURFACE_READ_FLAGS__SURF0_READ_FLAG_MASK 0x00000001L
#define HDP_SURFACE_READ_FLAGS__SURF1_READ_FLAG_MASK 0x00000002L
//HDP_SURFACE_WRITE_FLAGS_CLR
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF0_WRITE_FLAG_CLR__SHIFT 0x0
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF1_WRITE_FLAG_CLR__SHIFT 0x1
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF0_WRITE_FLAG_CLR_MASK 0x00000001L
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF1_WRITE_FLAG_CLR_MASK 0x00000002L
//HDP_SURFACE_READ_FLAGS_CLR
#define HDP_SURFACE_READ_FLAGS_CLR__SURF0_READ_FLAG_CLR__SHIFT 0x0
#define HDP_SURFACE_READ_FLAGS_CLR__SURF1_READ_FLAG_CLR__SHIFT 0x1
#define HDP_SURFACE_READ_FLAGS_CLR__SURF0_READ_FLAG_CLR_MASK 0x00000001L
#define HDP_SURFACE_READ_FLAGS_CLR__SURF1_READ_FLAG_CLR_MASK 0x00000002L
//HDP_NONSURF_FLAGS
#define HDP_NONSURF_FLAGS__NONSURF_WRITE_FLAG__SHIFT 0x0
#define HDP_NONSURF_FLAGS__NONSURF_READ_FLAG__SHIFT 0x1
#define HDP_NONSURF_FLAGS__NONSURF_WRITE_FLAG_MASK 0x00000001L
#define HDP_NONSURF_FLAGS__NONSURF_READ_FLAG_MASK 0x00000002L
//HDP_NONSURF_FLAGS_CLR
#define HDP_NONSURF_FLAGS_CLR__NONSURF_WRITE_FLAG_CLR__SHIFT 0x0
#define HDP_NONSURF_FLAGS_CLR__NONSURF_READ_FLAG_CLR__SHIFT 0x1
#define HDP_NONSURF_FLAGS_CLR__NONSURF_WRITE_FLAG_CLR_MASK 0x00000001L
#define HDP_NONSURF_FLAGS_CLR__NONSURF_READ_FLAG_CLR_MASK 0x00000002L
//HDP_HOST_PATH_CNTL
#define HDP_HOST_PATH_CNTL__WR_STALL_TIMER__SHIFT 0x9
#define HDP_HOST_PATH_CNTL__RD_STALL_TIMER__SHIFT 0xb
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_PRELOAD_CFG__SHIFT 0x12
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER__SHIFT 0x13
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_EN__SHIFT 0x15
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_64B_EN__SHIFT 0x16
#define HDP_HOST_PATH_CNTL__ALL_SURFACES_DIS__SHIFT 0x1d
#define HDP_HOST_PATH_CNTL__WR_STALL_TIMER_MASK 0x00000600L
#define HDP_HOST_PATH_CNTL__RD_STALL_TIMER_MASK 0x00001800L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_PRELOAD_CFG_MASK 0x00040000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_MASK 0x00180000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_EN_MASK 0x00200000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_64B_EN_MASK 0x00400000L
#define HDP_HOST_PATH_CNTL__ALL_SURFACES_DIS_MASK 0x20000000L
//HDP_SW_SEMAPHORE
#define HDP_SW_SEMAPHORE__SW_SEMAPHORE__SHIFT 0x0
#define HDP_SW_SEMAPHORE__SW_SEMAPHORE_MASK 0xFFFFFFFFL
//HDP_DEBUG0
#define HDP_DEBUG0__HDP_DEBUG__SHIFT 0x0
#define HDP_DEBUG0__HDP_DEBUG_MASK 0xFFFFFFFFL
//HDP_LAST_SURFACE_HIT
#define HDP_LAST_SURFACE_HIT__LAST_SURFACE_HIT__SHIFT 0x0
#define HDP_LAST_SURFACE_HIT__LAST_SURFACE_HIT_MASK 0x00000003L
//HDP_OUTSTANDING_REQ
#define HDP_OUTSTANDING_REQ__WRITE_REQ__SHIFT 0x0
#define HDP_OUTSTANDING_REQ__READ_REQ__SHIFT 0x8
#define HDP_OUTSTANDING_REQ__WRITE_REQ_MASK 0x000000FFL
#define HDP_OUTSTANDING_REQ__READ_REQ_MASK 0x0000FF00L
//HDP_MISC_CNTL
#define HDP_MISC_CNTL__IDLE_HYSTERESIS_CNTL__SHIFT 0x2
#define HDP_MISC_CNTL__OUTSTANDING_WRITE_COUNT_1024__SHIFT 0x5
#define HDP_MISC_CNTL__MMHUB_EARLY_WRACK_ENABLE__SHIFT 0x8
#define HDP_MISC_CNTL__SIMULTANEOUS_READS_WRITES__SHIFT 0xb
#define HDP_MISC_CNTL__READ_BUFFER_WATERMARK__SHIFT 0xe
#define HDP_MISC_CNTL__NACK_ENABLE__SHIFT 0x13
#define HDP_MISC_CNTL__ATOMIC_NACK_ENABLE__SHIFT 0x14
#define HDP_MISC_CNTL__FED_ENABLE__SHIFT 0x15
#define HDP_MISC_CNTL__ATOMIC_FED_ENABLE__SHIFT 0x16
#define HDP_MISC_CNTL__MMHUB_WRBURST_ENABLE__SHIFT 0x18
#define HDP_MISC_CNTL__MMHUB_WRBURST_SIZE__SHIFT 0x1e
#define HDP_MISC_CNTL__IDLE_HYSTERESIS_CNTL_MASK 0x0000000CL
#define HDP_MISC_CNTL__OUTSTANDING_WRITE_COUNT_1024_MASK 0x00000020L
#define HDP_MISC_CNTL__MMHUB_EARLY_WRACK_ENABLE_MASK 0x00000100L
#define HDP_MISC_CNTL__SIMULTANEOUS_READS_WRITES_MASK 0x00000800L
#define HDP_MISC_CNTL__READ_BUFFER_WATERMARK_MASK 0x0000C000L
#define HDP_MISC_CNTL__NACK_ENABLE_MASK 0x00080000L
#define HDP_MISC_CNTL__ATOMIC_NACK_ENABLE_MASK 0x00100000L
#define HDP_MISC_CNTL__FED_ENABLE_MASK 0x00200000L
#define HDP_MISC_CNTL__ATOMIC_FED_ENABLE_MASK 0x00400000L
#define HDP_MISC_CNTL__MMHUB_WRBURST_ENABLE_MASK 0x01000000L
#define HDP_MISC_CNTL__MMHUB_WRBURST_SIZE_MASK 0x40000000L
//HDP_MEM_POWER_CTRL
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_CTRL_EN__SHIFT 0x0
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_LS_EN__SHIFT 0x1
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DS_EN__SHIFT 0x2
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_SD_EN__SHIFT 0x3
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_IDLE_HYSTERESIS__SHIFT 0x4
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_UP_RECOVER_DELAY__SHIFT 0x8
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DOWN_ENTER_DELAY__SHIFT 0xe
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_CTRL_EN__SHIFT 0x10
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_LS_EN__SHIFT 0x11
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DS_EN__SHIFT 0x12
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_SD_EN__SHIFT 0x13
#define HDP_MEM_POWER_CTRL__RC_MEM_IDLE_HYSTERESIS__SHIFT 0x14
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_UP_RECOVER_DELAY__SHIFT 0x18
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DOWN_ENTER_DELAY__SHIFT 0x1e
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_CTRL_EN_MASK 0x00000001L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_LS_EN_MASK 0x00000002L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DS_EN_MASK 0x00000004L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_SD_EN_MASK 0x00000008L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_IDLE_HYSTERESIS_MASK 0x00000070L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_UP_RECOVER_DELAY_MASK 0x00003F00L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DOWN_ENTER_DELAY_MASK 0x0000C000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_CTRL_EN_MASK 0x00010000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_LS_EN_MASK 0x00020000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DS_EN_MASK 0x00040000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_SD_EN_MASK 0x00080000L
#define HDP_MEM_POWER_CTRL__RC_MEM_IDLE_HYSTERESIS_MASK 0x00700000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_UP_RECOVER_DELAY_MASK 0x3F000000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DOWN_ENTER_DELAY_MASK 0xC0000000L
//HDP_MMHUB_CNTL
#define HDP_MMHUB_CNTL__HDP_MMHUB_RO__SHIFT 0x0
#define HDP_MMHUB_CNTL__HDP_MMHUB_GCC__SHIFT 0x1
#define HDP_MMHUB_CNTL__HDP_MMHUB_SNOOP__SHIFT 0x2
#define HDP_MMHUB_CNTL__HDP_MMHUB_RO_MASK 0x00000001L
#define HDP_MMHUB_CNTL__HDP_MMHUB_GCC_MASK 0x00000002L
#define HDP_MMHUB_CNTL__HDP_MMHUB_SNOOP_MASK 0x00000004L
//HDP_VERSION
#define HDP_VERSION__MINVER__SHIFT 0x0
#define HDP_VERSION__MAJVER__SHIFT 0x8
#define HDP_VERSION__REV__SHIFT 0x10
#define HDP_VERSION__MINVER_MASK 0x000000FFL
#define HDP_VERSION__MAJVER_MASK 0x0000FF00L
#define HDP_VERSION__REV_MASK 0x00FF0000L
//HDP_CLK_CNTL
#define HDP_CLK_CNTL__REG_CLK_ENABLE_COUNT__SHIFT 0x0
#define HDP_CLK_CNTL__ATOMIC_MEM_CLK_SOFT_OVERRIDE__SHIFT 0x1a
#define HDP_CLK_CNTL__RC_MEM_CLK_SOFT_OVERRIDE__SHIFT 0x1b
#define HDP_CLK_CNTL__DBUS_CLK_SOFT_OVERRIDE__SHIFT 0x1c
#define HDP_CLK_CNTL__DYN_CLK_SOFT_OVERRIDE__SHIFT 0x1d
#define HDP_CLK_CNTL__XDP_REG_CLK_SOFT_OVERRIDE__SHIFT 0x1e
#define HDP_CLK_CNTL__HDP_REG_CLK_SOFT_OVERRIDE__SHIFT 0x1f
#define HDP_CLK_CNTL__REG_CLK_ENABLE_COUNT_MASK 0x0000000FL
#define HDP_CLK_CNTL__ATOMIC_MEM_CLK_SOFT_OVERRIDE_MASK 0x04000000L
#define HDP_CLK_CNTL__RC_MEM_CLK_SOFT_OVERRIDE_MASK 0x08000000L
#define HDP_CLK_CNTL__DBUS_CLK_SOFT_OVERRIDE_MASK 0x10000000L
#define HDP_CLK_CNTL__DYN_CLK_SOFT_OVERRIDE_MASK 0x20000000L
#define HDP_CLK_CNTL__XDP_REG_CLK_SOFT_OVERRIDE_MASK 0x40000000L
#define HDP_CLK_CNTL__HDP_REG_CLK_SOFT_OVERRIDE_MASK 0x80000000L
//HDP_MEMIO_CNTL
#define HDP_MEMIO_CNTL__MEMIO_SEND__SHIFT 0x0
#define HDP_MEMIO_CNTL__MEMIO_OP__SHIFT 0x1
#define HDP_MEMIO_CNTL__MEMIO_BE__SHIFT 0x2
#define HDP_MEMIO_CNTL__MEMIO_WR_STROBE__SHIFT 0x6
#define HDP_MEMIO_CNTL__MEMIO_RD_STROBE__SHIFT 0x7
#define HDP_MEMIO_CNTL__MEMIO_ADDR_UPPER__SHIFT 0x8
#define HDP_MEMIO_CNTL__MEMIO_CLR_WR_ERROR__SHIFT 0xe
#define HDP_MEMIO_CNTL__MEMIO_CLR_RD_ERROR__SHIFT 0xf
#define HDP_MEMIO_CNTL__MEMIO_VF__SHIFT 0x10
#define HDP_MEMIO_CNTL__MEMIO_VFID__SHIFT 0x11
#define HDP_MEMIO_CNTL__MEMIO_SEND_MASK 0x00000001L
#define HDP_MEMIO_CNTL__MEMIO_OP_MASK 0x00000002L
#define HDP_MEMIO_CNTL__MEMIO_BE_MASK 0x0000003CL
#define HDP_MEMIO_CNTL__MEMIO_WR_STROBE_MASK 0x00000040L
#define HDP_MEMIO_CNTL__MEMIO_RD_STROBE_MASK 0x00000080L
#define HDP_MEMIO_CNTL__MEMIO_ADDR_UPPER_MASK 0x00003F00L
#define HDP_MEMIO_CNTL__MEMIO_CLR_WR_ERROR_MASK 0x00004000L
#define HDP_MEMIO_CNTL__MEMIO_CLR_RD_ERROR_MASK 0x00008000L
#define HDP_MEMIO_CNTL__MEMIO_VF_MASK 0x00010000L
#define HDP_MEMIO_CNTL__MEMIO_VFID_MASK 0x003E0000L
//HDP_MEMIO_ADDR
#define HDP_MEMIO_ADDR__MEMIO_ADDR_LOWER__SHIFT 0x0
#define HDP_MEMIO_ADDR__MEMIO_ADDR_LOWER_MASK 0xFFFFFFFFL
//HDP_MEMIO_STATUS
#define HDP_MEMIO_STATUS__MEMIO_WR_STATUS__SHIFT 0x0
#define HDP_MEMIO_STATUS__MEMIO_RD_STATUS__SHIFT 0x1
#define HDP_MEMIO_STATUS__MEMIO_WR_ERROR__SHIFT 0x2
#define HDP_MEMIO_STATUS__MEMIO_RD_ERROR__SHIFT 0x3
#define HDP_MEMIO_STATUS__MEMIO_WR_STATUS_MASK 0x00000001L
#define HDP_MEMIO_STATUS__MEMIO_RD_STATUS_MASK 0x00000002L
#define HDP_MEMIO_STATUS__MEMIO_WR_ERROR_MASK 0x00000004L
#define HDP_MEMIO_STATUS__MEMIO_RD_ERROR_MASK 0x00000008L
//HDP_MEMIO_WR_DATA
#define HDP_MEMIO_WR_DATA__MEMIO_WR_DATA__SHIFT 0x0
#define HDP_MEMIO_WR_DATA__MEMIO_WR_DATA_MASK 0xFFFFFFFFL
//HDP_MEMIO_RD_DATA
#define HDP_MEMIO_RD_DATA__MEMIO_RD_DATA__SHIFT 0x0
#define HDP_MEMIO_RD_DATA__MEMIO_RD_DATA_MASK 0xFFFFFFFFL
//HDP_XDP_DIRECT2HDP_FIRST
#define HDP_XDP_DIRECT2HDP_FIRST__RESERVED__SHIFT 0x0
#define HDP_XDP_DIRECT2HDP_FIRST__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_FLUSH
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_FLUSH_NUM__SHIFT 0x0
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ENC_DATA__SHIFT 0x4
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ADDR_SEL__SHIFT 0x8
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_XPB_CLG__SHIFT 0xb
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_SEND_HOST__SHIFT 0x10
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_ALTER_FLUSH_NUM__SHIFT 0x12
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_0__SHIFT 0x13
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_1__SHIFT 0x14
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_FLUSH_NUM_MASK 0x0000000FL
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ENC_DATA_MASK 0x000000F0L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ADDR_SEL_MASK 0x00000700L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_XPB_CLG_MASK 0x0000F800L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_SEND_HOST_MASK 0x00010000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_ALTER_FLUSH_NUM_MASK 0x00040000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_0_MASK 0x00080000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_1_MASK 0x00100000L
//HDP_XDP_D2H_BAR_UPDATE
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_ADDR__SHIFT 0x0
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_FLUSH_NUM__SHIFT 0x10
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_BAR_NUM__SHIFT 0x14
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_ADDR_MASK 0x0000FFFFL
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_FLUSH_NUM_MASK 0x000F0000L
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_BAR_NUM_MASK 0x00700000L
//HDP_XDP_D2H_RSVD_3
#define HDP_XDP_D2H_RSVD_3__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_3__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_4
#define HDP_XDP_D2H_RSVD_4__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_4__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_5
#define HDP_XDP_D2H_RSVD_5__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_5__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_6
#define HDP_XDP_D2H_RSVD_6__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_6__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_7
#define HDP_XDP_D2H_RSVD_7__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_7__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_8
#define HDP_XDP_D2H_RSVD_8__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_8__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_9
#define HDP_XDP_D2H_RSVD_9__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_9__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_10
#define HDP_XDP_D2H_RSVD_10__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_10__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_11
#define HDP_XDP_D2H_RSVD_11__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_11__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_12
#define HDP_XDP_D2H_RSVD_12__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_12__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_13
#define HDP_XDP_D2H_RSVD_13__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_13__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_14
#define HDP_XDP_D2H_RSVD_14__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_14__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_15
#define HDP_XDP_D2H_RSVD_15__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_15__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_16
#define HDP_XDP_D2H_RSVD_16__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_16__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_17
#define HDP_XDP_D2H_RSVD_17__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_17__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_18
#define HDP_XDP_D2H_RSVD_18__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_18__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_19
#define HDP_XDP_D2H_RSVD_19__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_19__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_20
#define HDP_XDP_D2H_RSVD_20__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_20__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_21
#define HDP_XDP_D2H_RSVD_21__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_21__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_22
#define HDP_XDP_D2H_RSVD_22__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_22__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_23
#define HDP_XDP_D2H_RSVD_23__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_23__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_24
#define HDP_XDP_D2H_RSVD_24__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_24__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_25
#define HDP_XDP_D2H_RSVD_25__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_25__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_26
#define HDP_XDP_D2H_RSVD_26__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_26__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_27
#define HDP_XDP_D2H_RSVD_27__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_27__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_28
#define HDP_XDP_D2H_RSVD_28__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_28__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_29
#define HDP_XDP_D2H_RSVD_29__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_29__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_30
#define HDP_XDP_D2H_RSVD_30__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_30__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_31
#define HDP_XDP_D2H_RSVD_31__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_31__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_32
#define HDP_XDP_D2H_RSVD_32__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_32__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_33
#define HDP_XDP_D2H_RSVD_33__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_33__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_34
#define HDP_XDP_D2H_RSVD_34__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_34__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_DIRECT2HDP_LAST
#define HDP_XDP_DIRECT2HDP_LAST__RESERVED__SHIFT 0x0
#define HDP_XDP_DIRECT2HDP_LAST__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_P2P_BAR_CFG
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_ADDR_SIZE__SHIFT 0x0
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_BAR_FROM__SHIFT 0x4
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_ADDR_SIZE_MASK 0x0000000FL
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_BAR_FROM_MASK 0x00000030L
//HDP_XDP_P2P_MBX_OFFSET
#define HDP_XDP_P2P_MBX_OFFSET__P2P_MBX_OFFSET__SHIFT 0x0
#define HDP_XDP_P2P_MBX_OFFSET__P2P_MBX_OFFSET_MASK 0x0001FFFFL
//HDP_XDP_P2P_MBX_ADDR0
#define HDP_XDP_P2P_MBX_ADDR0__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR0__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR1
#define HDP_XDP_P2P_MBX_ADDR1__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR1__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR2
#define HDP_XDP_P2P_MBX_ADDR2__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR2__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR3
#define HDP_XDP_P2P_MBX_ADDR3__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR3__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR4
#define HDP_XDP_P2P_MBX_ADDR4__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR4__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR5
#define HDP_XDP_P2P_MBX_ADDR5__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR5__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR6
#define HDP_XDP_P2P_MBX_ADDR6__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR6__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_HDP_MBX_MC_CFG
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_QOS__SHIFT 0x0
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SWAP__SHIFT 0x4
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_VMID__SHIFT 0x8
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_RO__SHIFT 0xc
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_GCC__SHIFT 0xd
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SNOOP__SHIFT 0xe
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_QOS_MASK 0x0000000FL
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SWAP_MASK 0x00000030L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_VMID_MASK 0x00000F00L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_RO_MASK 0x00001000L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_GCC_MASK 0x00002000L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SNOOP_MASK 0x00004000L
//HDP_XDP_HDP_MC_CFG
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SNOOP__SHIFT 0x3
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SWAP__SHIFT 0x4
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_VMID__SHIFT 0x8
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_RO__SHIFT 0xc
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_GCC__SHIFT 0xd
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_XDP_HIGHER_PRI_THRESH__SHIFT 0xe
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SNOOP_MASK 0x00000008L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SWAP_MASK 0x00000030L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_VMID_MASK 0x00000F00L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_RO_MASK 0x00001000L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_GCC_MASK 0x00002000L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_XDP_HIGHER_PRI_THRESH_MASK 0x000FC000L
//HDP_XDP_HST_CFG
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_EN__SHIFT 0x0
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER__SHIFT 0x1
#define HDP_XDP_HST_CFG__HST_CFG_WR_BURST_EN__SHIFT 0x3
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_64B_EN__SHIFT 0x4
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_PRELOAD_CFG__SHIFT 0x5
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_EN_MASK 0x00000001L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_MASK 0x00000006L
#define HDP_XDP_HST_CFG__HST_CFG_WR_BURST_EN_MASK 0x00000008L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_64B_EN_MASK 0x00000010L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_PRELOAD_CFG_MASK 0x00000020L
//HDP_XDP_HDP_IPH_CFG
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_INVERSE_PEER_TAG_MATCHING__SHIFT 0xc
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_P2P_RD_EN__SHIFT 0xd
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_INVERSE_PEER_TAG_MATCHING_MASK 0x00001000L
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_P2P_RD_EN_MASK 0x00002000L
//HDP_XDP_P2P_BAR0
#define HDP_XDP_P2P_BAR0__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR0__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR0__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR0__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR0__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR0__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR1
#define HDP_XDP_P2P_BAR1__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR1__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR1__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR1__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR1__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR1__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR2
#define HDP_XDP_P2P_BAR2__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR2__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR2__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR2__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR2__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR2__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR3
#define HDP_XDP_P2P_BAR3__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR3__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR3__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR3__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR3__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR3__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR4
#define HDP_XDP_P2P_BAR4__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR4__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR4__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR4__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR4__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR4__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR5
#define HDP_XDP_P2P_BAR5__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR5__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR5__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR5__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR5__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR5__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR6
#define HDP_XDP_P2P_BAR6__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR6__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR6__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR6__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR6__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR6__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR7
#define HDP_XDP_P2P_BAR7__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR7__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR7__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR7__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR7__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR7__VALID_MASK 0x00100000L
//HDP_XDP_FLUSH_ARMED_STS
#define HDP_XDP_FLUSH_ARMED_STS__FLUSH_ARMED_STS__SHIFT 0x0
#define HDP_XDP_FLUSH_ARMED_STS__FLUSH_ARMED_STS_MASK 0xFFFFFFFFL
//HDP_XDP_FLUSH_CNTR0_STS
#define HDP_XDP_FLUSH_CNTR0_STS__FLUSH_CNTR0_STS__SHIFT 0x0
#define HDP_XDP_FLUSH_CNTR0_STS__FLUSH_CNTR0_STS_MASK 0x03FFFFFFL
//HDP_XDP_BUSY_STS
#define HDP_XDP_BUSY_STS__BUSY_BITS_0__SHIFT 0x0
#define HDP_XDP_BUSY_STS__BUSY_BITS_1__SHIFT 0x1
#define HDP_XDP_BUSY_STS__BUSY_BITS_2__SHIFT 0x2
#define HDP_XDP_BUSY_STS__BUSY_BITS_3__SHIFT 0x3
#define HDP_XDP_BUSY_STS__BUSY_BITS_4__SHIFT 0x4
#define HDP_XDP_BUSY_STS__BUSY_BITS_5__SHIFT 0x5
#define HDP_XDP_BUSY_STS__BUSY_BITS_6__SHIFT 0x6
#define HDP_XDP_BUSY_STS__BUSY_BITS_7__SHIFT 0x7
#define HDP_XDP_BUSY_STS__BUSY_BITS_8__SHIFT 0x8
#define HDP_XDP_BUSY_STS__BUSY_BITS_9__SHIFT 0x9
#define HDP_XDP_BUSY_STS__BUSY_BITS_10__SHIFT 0xa
#define HDP_XDP_BUSY_STS__BUSY_BITS_11__SHIFT 0xb
#define HDP_XDP_BUSY_STS__BUSY_BITS_12__SHIFT 0xc
#define HDP_XDP_BUSY_STS__BUSY_BITS_13__SHIFT 0xd
#define HDP_XDP_BUSY_STS__BUSY_BITS_14__SHIFT 0xe
#define HDP_XDP_BUSY_STS__BUSY_BITS_15__SHIFT 0xf
#define HDP_XDP_BUSY_STS__BUSY_BITS_16__SHIFT 0x10
#define HDP_XDP_BUSY_STS__BUSY_BITS_17__SHIFT 0x11
#define HDP_XDP_BUSY_STS__BUSY_BITS_18__SHIFT 0x12
#define HDP_XDP_BUSY_STS__BUSY_BITS_19__SHIFT 0x13
#define HDP_XDP_BUSY_STS__BUSY_BITS_20__SHIFT 0x14
#define HDP_XDP_BUSY_STS__BUSY_BITS_21__SHIFT 0x15
#define HDP_XDP_BUSY_STS__BUSY_BITS_22__SHIFT 0x16
#define HDP_XDP_BUSY_STS__BUSY_BITS_23__SHIFT 0x17
#define HDP_XDP_BUSY_STS__Z_FENCE_BIT__SHIFT 0x18
#define HDP_XDP_BUSY_STS__BUSY_BITS_0_MASK 0x00000001L
#define HDP_XDP_BUSY_STS__BUSY_BITS_1_MASK 0x00000002L
#define HDP_XDP_BUSY_STS__BUSY_BITS_2_MASK 0x00000004L
#define HDP_XDP_BUSY_STS__BUSY_BITS_3_MASK 0x00000008L
#define HDP_XDP_BUSY_STS__BUSY_BITS_4_MASK 0x00000010L
#define HDP_XDP_BUSY_STS__BUSY_BITS_5_MASK 0x00000020L
#define HDP_XDP_BUSY_STS__BUSY_BITS_6_MASK 0x00000040L
#define HDP_XDP_BUSY_STS__BUSY_BITS_7_MASK 0x00000080L
#define HDP_XDP_BUSY_STS__BUSY_BITS_8_MASK 0x00000100L
#define HDP_XDP_BUSY_STS__BUSY_BITS_9_MASK 0x00000200L
#define HDP_XDP_BUSY_STS__BUSY_BITS_10_MASK 0x00000400L
#define HDP_XDP_BUSY_STS__BUSY_BITS_11_MASK 0x00000800L
#define HDP_XDP_BUSY_STS__BUSY_BITS_12_MASK 0x00001000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_13_MASK 0x00002000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_14_MASK 0x00004000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_15_MASK 0x00008000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_16_MASK 0x00010000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_17_MASK 0x00020000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_18_MASK 0x00040000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_19_MASK 0x00080000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_20_MASK 0x00100000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_21_MASK 0x00200000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_22_MASK 0x00400000L
#define HDP_XDP_BUSY_STS__BUSY_BITS_23_MASK 0x00800000L
#define HDP_XDP_BUSY_STS__Z_FENCE_BIT_MASK 0x01000000L
//HDP_XDP_STICKY
#define HDP_XDP_STICKY__STICKY_STS__SHIFT 0x0
#define HDP_XDP_STICKY__STICKY_W1C__SHIFT 0x10
#define HDP_XDP_STICKY__STICKY_STS_MASK 0x0000FFFFL
#define HDP_XDP_STICKY__STICKY_W1C_MASK 0xFFFF0000L
//HDP_XDP_CHKN
#define HDP_XDP_CHKN__CHKN_0_RSVD__SHIFT 0x0
#define HDP_XDP_CHKN__CHKN_1_RSVD__SHIFT 0x8
#define HDP_XDP_CHKN__CHKN_2_RSVD__SHIFT 0x10
#define HDP_XDP_CHKN__CHKN_3_RSVD__SHIFT 0x18
#define HDP_XDP_CHKN__CHKN_0_RSVD_MASK 0x000000FFL
#define HDP_XDP_CHKN__CHKN_1_RSVD_MASK 0x0000FF00L
#define HDP_XDP_CHKN__CHKN_2_RSVD_MASK 0x00FF0000L
#define HDP_XDP_CHKN__CHKN_3_RSVD_MASK 0xFF000000L
//HDP_XDP_BARS_ADDR_39_36
#define HDP_XDP_BARS_ADDR_39_36__BAR0_ADDR_39_36__SHIFT 0x0
#define HDP_XDP_BARS_ADDR_39_36__BAR1_ADDR_39_36__SHIFT 0x4
#define HDP_XDP_BARS_ADDR_39_36__BAR2_ADDR_39_36__SHIFT 0x8
#define HDP_XDP_BARS_ADDR_39_36__BAR3_ADDR_39_36__SHIFT 0xc
#define HDP_XDP_BARS_ADDR_39_36__BAR4_ADDR_39_36__SHIFT 0x10
#define HDP_XDP_BARS_ADDR_39_36__BAR5_ADDR_39_36__SHIFT 0x14
#define HDP_XDP_BARS_ADDR_39_36__BAR6_ADDR_39_36__SHIFT 0x18
#define HDP_XDP_BARS_ADDR_39_36__BAR7_ADDR_39_36__SHIFT 0x1c
#define HDP_XDP_BARS_ADDR_39_36__BAR0_ADDR_39_36_MASK 0x0000000FL
#define HDP_XDP_BARS_ADDR_39_36__BAR1_ADDR_39_36_MASK 0x000000F0L
#define HDP_XDP_BARS_ADDR_39_36__BAR2_ADDR_39_36_MASK 0x00000F00L
#define HDP_XDP_BARS_ADDR_39_36__BAR3_ADDR_39_36_MASK 0x0000F000L
#define HDP_XDP_BARS_ADDR_39_36__BAR4_ADDR_39_36_MASK 0x000F0000L
#define HDP_XDP_BARS_ADDR_39_36__BAR5_ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_BARS_ADDR_39_36__BAR6_ADDR_39_36_MASK 0x0F000000L
#define HDP_XDP_BARS_ADDR_39_36__BAR7_ADDR_39_36_MASK 0xF0000000L
//HDP_XDP_MC_VM_FB_LOCATION_BASE
#define HDP_XDP_MC_VM_FB_LOCATION_BASE__FB_BASE__SHIFT 0x0
#define HDP_XDP_MC_VM_FB_LOCATION_BASE__FB_BASE_MASK 0x03FFFFFFL
//HDP_XDP_MMHUB_ERROR
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_01__SHIFT 0x1
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_10__SHIFT 0x2
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_11__SHIFT 0x3
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_FED__SHIFT 0x4
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_01__SHIFT 0x5
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_10__SHIFT 0x6
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_11__SHIFT 0x7
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_01__SHIFT 0x9
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_10__SHIFT 0xa
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_11__SHIFT 0xb
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_FED__SHIFT 0xc
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_01__SHIFT 0xd
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_10__SHIFT 0xe
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_11__SHIFT 0xf
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_01__SHIFT 0x11
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_10__SHIFT 0x12
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_11__SHIFT 0x13
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_01__SHIFT 0x15
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_10__SHIFT 0x16
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_11__SHIFT 0x17
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_01_MASK 0x00000002L
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_10_MASK 0x00000004L
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_11_MASK 0x00000008L
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_FED_MASK 0x00000010L
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_01_MASK 0x00000020L
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_10_MASK 0x00000040L
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_11_MASK 0x00000080L
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_01_MASK 0x00000200L
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_10_MASK 0x00000400L
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_11_MASK 0x00000800L
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_FED_MASK 0x00001000L
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_01_MASK 0x00002000L
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_10_MASK 0x00004000L
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_11_MASK 0x00008000L
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_01_MASK 0x00020000L
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_10_MASK 0x00040000L
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_11_MASK 0x00080000L
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_01_MASK 0x00200000L
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_10_MASK 0x00400000L
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_11_MASK 0x00800000L
#endif

View File

@@ -6,7 +6,9 @@ from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops
import json
from collections import OrderedDict
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]
@@ -26,6 +28,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))
return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
@@ -54,60 +57,105 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
special_names[id(output.lazydata.base.realized)] = f'output{i}'
return run, special_names
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
cprog = ["#include <tgmath.h>"]
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:
headers = ["#include <tgmath.h>"]
cprog = list(functions.values())
dtype_map = {dtypes.int: "int", dtypes.float: "float", dtypes.uchar: "unsigned char", dtypes.char: "signed char", dtypes.half: "__fp16", dtypes.uint: "unsigned int"}
inputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in input_names + list(symbolic_vars.values())]
outputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in output_names]
forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
if not wasm:
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(headers + cprog)
else:
if bufs_to_save:
headers += ["#include <stddef.h>"]
bufs_to_save = {k:v for k,v in bufs.items() if v[2] in weight_names} # causes random seeds to be set as zeroes, not exported as a model weight
buf_to_name = OrderedDict((buf_name, {"name": weight_names[data[2]], "idx": i}) for i, (buf_name, data) in enumerate(bufs_to_save.items()))
cprog.append(f"void* bufs[{len(buf_to_name)}];")
cprog.append(f"""void set_buf(size_t index, void* ptr) {{\n bufs[index] = ptr;\n}}""")
inputs = ", ".join([f'float* {input}' for input in input_names])
outputs = ", ".join([f'float* {output}' for output in output_names])
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
cprog += list(functions.values())
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(cprog)
for name in set(bufs.keys()) - set(bufs_to_save.keys()) - set(input_names + output_names):
n_bytes, dtype, _ = bufs[name]
cprog += [f"{dtype_map[dtype]} {name}[{n_bytes // dtype.itemsize}];"]
cprog += [f"void net({forward_args})"] + ["{"]
get_weight_ptr = lambda x: f"({dtype_map[bufs_to_save[x][1]]} *)bufs[{buf_to_name[x]['idx']}]" if x in bufs_to_save else x
cprog += [f" {name}({', '.join(map(get_weight_ptr, args))});" for (name, args, _global_size, _local_size) in statements] + ["}"]
weightMapping = "" if not bufs_to_save else f"""\nconst weightNames = [{", ".join([f'"{weight_name}"' for weight_name in [v["name"] for v in buf_to_name.values()]])}];
const {model_name}_name_to_id = Object.fromEntries(weightNames.map((name, index) => [name, index]));\n"""
top = f"""import {model_name}Module from './{model_name}.js'{weightMapping}"""
whitespace = "\n "
js_wrapper = f"""{top}\nvar {model_name} = async function() {{
const wasm = await {model_name}Module();
{whitespace.join(f"const {name}Ptr = wasm._malloc({n_bytes});" for name, _, n_bytes in outputs+inputs if name not in symbolic_vars.values())}
return {{
run: ({",".join(name for name,_,_ in inputs)}) => {{
{(whitespace + " ").join(f"wasm.HEAPU8.set({name}, {name}Ptr);" for name,_,_ in inputs if name not in symbolic_vars.values())}
wasm._net({", ".join(f"{name}{'Ptr' if name not in symbolic_vars.values() else ''}" for name,_,_ in outputs+inputs)});
{(whitespace + " ").join(f"const {name} = wasm.HEAPU8.slice({name}Ptr, {name}Ptr + {n_bytes});" for name,_,n_bytes in outputs)}
return [{", ".join(f"{name}" for name,_,_ in outputs)}];
}},
wasm: wasm
}}
}}\nexport {{ {model_name}, {model_name}_name_to_id }};"""
return '\n'.join(headers + cprog), js_wrapper
def dtype_to_js_type(dtype: DType) -> str:
return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
exported_name = "model" if model_name == None else model_name
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]:
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names += list(symbolic_vars.values())
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
create_bind_group_layouts = ",".join([
"device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
)
for _, (_, args, _, _) in enumerate(statements)
])
layouts = f"const layouts=[{create_bind_group_layouts}]"
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
_bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
const {exported_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};
const getTensorMetadata = (safetensorBuffer) => {{
getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};\n""" if not stream_weights else ""
return f"""
const {model_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};
{getTensorMetadata}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createUniformBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
}}
const createInfinityUniformBuf = (device) => {{
const size = 4;
const buf = device.createBuffer({{
@@ -121,9 +169,8 @@ const createInfinityUniformBuf = (device) => {{
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
const buf = device.createBuffer({{ size, usage: GPUBufferUsage.STORAGE{" | GPUBufferUsage.COPY_DST" if stream_weights else ", mappedAtCreation: true"} }});
{"data.bytes = buf;" if stream_weights else "new Uint8Array(buf.getMappedRange()).set(data); buf.unmap();"}
return buf;
}};
@@ -145,8 +192,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor
{kernel_code}
const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor);
const setupNet = async (device, {"state_dict" if stream_weights else "safetensor"}) => {{
{"const metadata = getTensorMetadata(safetensor);" if not stream_weights else ""}
const infinityBuf = createInfinityUniformBuf(device);
{layouts}
@@ -185,12 +232,12 @@ const setupNet = async (device, safetensor) => {{
}}
}}
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
return {{ load }};
return {{ load, setupNet }};
}})();
export default {exported_name};
export default {model_name};
"""
def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -198,11 +245,30 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]
# handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]
if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"
prg = ""
if target == "clang":
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
elif target == "wasm":
return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
elif target == "webgpu":
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
else:
prg = json.dumps({
"backend": Device.DEFAULT,

File diff suppressed because it is too large Load Diff

View File

@@ -195,6 +195,7 @@ def get_onnx_ops():
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]
def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
if auto_pad == "VALID": return [0]*(len(k_)*2)
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
@@ -673,7 +674,8 @@ def get_onnx_ops():
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
indices = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
indices = [x_sh[axis]+x if x<0 else x for x in indices]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot

View File

@@ -174,6 +174,7 @@ decomps = [
aten.threshold_backward,
aten.softplus_backward,
aten.elu, # elu has a scale + input_scale param
aten.elu_backward,
aten.softplus,
aten.threshold,
aten.nll_loss_forward,
@@ -270,8 +271,8 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
# TODO: support this in tinygrad
"aten.bitwise_left_shift.Tensor_out": lambda input,other: Tensor(input << other.numpy()),
"aten.bitwise_right_shift.Tensor_out": lambda input,other: Tensor(input >> other.numpy()),
"aten.bitwise_left_shift.Tensor_out": lambda x,y: x*(2**y),
"aten.bitwise_right_shift.Tensor_out": lambda x,y: x//(2**y),
# not in tinygrad. are there decomps for these?
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
"aten.log1p.out": lambda self: (self+1).log(),
@@ -326,11 +327,12 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.var.correction": Tensor.var,
"aten.var_mean.correction": Tensor.var_mean,
# NOTE: axis=[] in torch means all, change tinygrad?
"aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
"aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None,out=None:
out.replace(self.sum(axis if axis is None or len(axis) else None, keepdim,
acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None), allow_shape_mismatch=True),
"aten.scatter.value": Tensor.scatter,
"aten.scatter.value_reduce": Tensor.scatter,
"aten.gather": Tensor.gather,
"aten.gather": lambda self, dim, index: self.gather(dim, index.cast(dtypes.int)),
"aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
@@ -343,10 +345,11 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
# these don't work in out form, they have size 0
"aten.abs": Tensor.abs,
"aten.logical_not": Tensor.logical_not,
"aten.logical_or_": lambda x, y: x.assign(x | y),
"aten.multinomial": Tensor.multinomial,
"aten.pad": Tensor.pad,
"aten.reflection_pad2d": functools.partial(Tensor.pad, mode="reflect"),
"aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
"aten.masked_fill_.Scalar": lambda self, mask, value: self.assign(self.masked_fill(mask, value)),
"aten.masked_fill.Scalar": Tensor.masked_fill,
"aten.masked_fill.Tensor": Tensor.masked_fill,
"aten.all": Tensor.all,
@@ -361,6 +364,14 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.atanh": Tensor.atanh,
"aten.fill_.Tensor": Tensor.full,
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True),
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
"aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)),
"aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value),
"aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.roll": Tensor.roll,
@@ -368,6 +379,9 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.repeat": Tensor.repeat,
"aten.lerp.Tensor": Tensor.lerp,
"aten.expand": Tensor.expand,
"aten.t": Tensor.transpose,
"aten.detach": Tensor.detach,
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
}}
def wrap_fxn(k,f):
@@ -388,11 +402,11 @@ for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::")
if TORCH_DEBUG:
from torch.utils._python_dispatch import TorchDispatchMode
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
print(f"Dispatch Log: {func}")
return func(*args, **(kwargs or {}))
DispatchLog().__enter__()
(_dispatch_log:=DispatchLog()).__enter__() # NOTE: must be kept alive
# NOTE: patch torch optimizer step to avoid continously growing the computation graph
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):

View File

@@ -75,7 +75,7 @@ class TestTorchBackend(unittest.TestCase):
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
c = a == b
print(c.cpu().numpy())
print(c.cpu())
def test_maxpool2d_backward(self):
x = torch.arange(3*3, device=device).reshape(1, 1, 3, 3).requires_grad_(True)
@@ -86,10 +86,10 @@ class TestTorchBackend(unittest.TestCase):
x = torch.zeros(4, device=device, dtype=torch.int64)
y = torch.ones(4, device=device, dtype=torch.float32).to(dtype=torch.int64)
res1 = x ^ y # an operation that only works on int types
print(res1.cpu().numpy())
print(res1.cpu())
y = y.cpu().float().to(device=device, dtype=torch.int64)
res2 = x ^ y
print(res2.cpu().numpy())
print(res2.cpu())
@unittest.skip("meh")
def test_str(self):

View File

@@ -30,7 +30,9 @@ def run(sz, n_gpus=6, iters=10, use_ring=False):
with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)
def main():
ONLY_RING = getenv("ONLY_RING", 0)
n_gpus = getenv("GPUS", 6)
iters = getenv("ITERS", 10)
if getenv("BENCHMARK_SPLIT"):
l, r = 0, 512
@@ -44,10 +46,10 @@ def main():
else:
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus)
(naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus)
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
if __name__ == "__main__":
main()

View File

@@ -72,7 +72,7 @@ class AMPTFuzzer:
pattern = self.generate_pattern(ptr, size)
pages = self.fill_memory(ptr, size, pattern)
self.allocations[ptr] = (size, pattern, pages)
self.allocations[ptr.va_addr] = (size, pattern, pages, ptr)
self.alloc_payload += size
print(f"Allocated {size} bytes at {ptr.va_addr:x}, pattern: {pattern:02x}")
return ptr
@@ -81,15 +81,15 @@ class AMPTFuzzer:
if not self.allocations: return False
ptr = random.choice(list(self.allocations.keys()))
size, pattern, pages = self.allocations[ptr]
size, pattern, pages, vm = self.allocations[ptr]
# Verify pattern before freeing
if not self.verify_memory(pages, pattern):
raise RuntimeError(f"Memory corruption detected at {ptr.va_addr:x}!")
raise RuntimeError(f"Memory corruption detected at {vm.va_addr:x}!")
print(f"Freeing {size} bytes at {ptr.va_addr:x}, pattern verified: {pattern:02x}")
print(f"Freeing {size} bytes at {vm.va_addr:x}, pattern verified: {pattern:02x}")
self.alloc_payload -= size
self.d.mm.vfree(ptr)
self.d.mm.vfree(vm)
del self.allocations[ptr]
return True

View File

@@ -153,5 +153,21 @@ class TestAMPageTable(unittest.TestCase):
mm0.map_range(0x1000000, 2 << 20, paddrs=[(0x10000, 2 << 20)])
mm0.unmap_range(0x1000000, 2 << 20)
def test_frag_size(self):
mm0 = self.d[0].mm
def must_cover_checker(va, sz):
ans = (1 << (mm0._frag_size(va=va, sz=sz, must_cover=True) + 12))
assert va % ans == 0 and sz % ans == 0 and (va % (2 * ans) != 0 or sz % (2 * ans) != 0), f"va {va:#x} sz {sz:#x} ans {ans:#x}"
def not_cover_checker(va, sz):
ans = (1 << (mm0._frag_size(va=va, sz=sz, must_cover=False) + 12))
assert va % ans == 0 and ans <= sz and (va % (2 * ans) != 0 or (2 * ans) > sz), f"va {va:#x} sz {sz:#x} ans {ans:#x}"
for va, sz in [(0x0, 0x1000), (0x1000, 0x2000), (0x1000, 0x3000), (0x2000, 0x2000), (0x4000, 0x8000), (0x8000, 0x4000), (0x10000, 0x4000),
(0x0, 0x4000), (0x10000, 0x4000), (0x10000, 0x40000), (0x10001000, 0x40000), (0x100001000, 0x3000)]:
must_cover_checker(va, sz)
not_cover_checker(va, sz)
if __name__ == "__main__":
unittest.main()

View File

@@ -27,6 +27,27 @@ class TestMainOnnxOps(TestOnnxOps):
outputs = ["out"]
self.helper_test_single_op("Reshape", inputs, attributes, outputs)
def test_conv(self):
# test VALID auto_pad
inputs = {
"x": np.random.randn(1, 3, 384, 384).astype(np.float32),
"w": np.random.randn(1152, 3, 14, 14).astype(np.float32),
"b": np.random.randn(1152).astype(np.float32)
}
attributes = {'auto_pad': 'VALID', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (14, 14), 'strides': (14, 14)}
outputs = ["y"]
self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4)
def test_gather(self):
# test const negative indices
inputs = {
"input": np.random.randn(1, 3, 3).astype(np.float32),
"indices": np.array(-2, dtype=np.int64),
}
attributes = {'axis': 1}
outputs = ["y"]
self.helper_test_single_op("Gather", inputs, attributes, outputs)
def test_quantize_linear(self):
test_cases = [
{"test_case": "round_half_to_even", "qdtype": np.int8, "qzero_point": 0, "x": [-1.5, -0.5, 0.5, 1.5], "scale": 1.0},

View File

@@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase):
# theoretical is nv_tflops=165, amd_tflops=123
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=80)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=65)
# theoretical is nv_gbs=1008, amd_gbs=960
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=760)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750)
if __name__ == '__main__':
unittest.main()

View File

@@ -5,9 +5,8 @@ import tinygrad.runtime.autogen.amd_gpu as amd_gpu
SDMA_MAX_COPY_SIZE = 0x400000
BASE_ADDR = 0x00001260
PACKET3_SET_SH_REG_START = 0x2c00
SUB = PACKET3_SET_SH_REG_START - BASE_ADDR
SUB = PACKET3_SET_SH_REG_START - amd_gpu.GC_BASE__INST0_SEG0
regCOMPUTE_PGM_LO = 0x1bac - SUB
regCOMPUTE_USER_DATA_0 = 0x1be0 - SUB

View File

@@ -343,6 +343,11 @@ class TestDiskTensor(unittest.TestCase):
on_dev = t.to(Device.DEFAULT).realize()
np.testing.assert_equal(on_dev.numpy(), t.numpy())
@unittest.skipUnless(OSX, "seems to only be an issue on macOS with file size >2 GiB")
def test_copy_to_cpu_not_truncated(self):
with open((fn:=temp("dt_copy_to_cpu_not_truncated")), "wb") as f: f.write(b'\x01' * (size := int(2 * 1024**3)) + (test := b"test"))
x = Tensor.empty(size + len(test), dtype=dtypes.uint8, device=f"disk:{fn}").to("CPU").realize()
assert x[size:].data().tobytes() == test
class TestPathTensor(unittest.TestCase):
def setUp(self):

View File

@@ -359,33 +359,28 @@ fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
# TODO: replace this with the KERNEL UOp
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
# substitute kernel sources for the target buffer
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
# create subbuffer
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
return k.replace(arg=Kernel(ast, k.arg.metadata))
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -395,6 +390,12 @@ if CAPTURE_PROCESS_REPLAY:
# **** schedule creation and toposort
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
# merge_views + sym
@@ -408,7 +409,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
realize_map = group_realizes(sink)
# map tensor metadata to simplified ops
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
# create kernels
# create_kernels
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)
@@ -459,7 +460,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
schedule.append(schedule_uop(u, var_vals))
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
# increment the refcount of the target buf (this is required by the JIT and memory planner)
u.buf_uop.buffer.ref(1)
for x in children.get(u, []):

View File

@@ -82,7 +82,7 @@ class LARS(Optimizer):
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
else: r = 1.0
g = g + self.wd * t.detach()
# classic momentum does post learning rate update
@@ -141,7 +141,7 @@ class LAMB(Optimizer):
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
t.assign((t.detach() - self.lr * r * up).cast(t.dtype))

View File

@@ -513,6 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
@property

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -22,8 +22,8 @@ WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=
COMPUTE_SHADER_EN, FORCE_START_AT_000, CS_W32_EN = (1 << 0), (1 << 2), (1 << 15)
def gfxreg(reg): return reg + 0x00001260 - amd_gpu.PACKET3_SET_SH_REG_START
def nbioreg(reg): return reg + 0x00000d20 # NBIO_BASE__INST0_SEG2
def gfxreg(reg): return reg + amd_gpu.GC_BASE__INST0_SEG0 - amd_gpu.PACKET3_SET_SH_REG_START
def nbioreg(reg): return reg + amd_gpu.NBIO_BASE__INST0_SEG2
class AMDSignal(HCQSignal):
def __init__(self, base_addr:int|None=None, **kwargs):
@@ -148,7 +148,7 @@ class AMDComputeQueue(HWQueue):
for i, value in enumerate(cmds): dev.compute_queue.ring[(dev.compute_queue.put_value + i) % len(dev.compute_queue.ring)] = value
dev.compute_queue.put_value += len(cmds)
dev.compute_queue.signal_doorbell()
dev.compute_queue.signal_doorbell(dev)
class AMDCopyQueue(HWQueue):
def __init__(self, max_copy_size=0x40000000):
@@ -232,7 +232,7 @@ class AMDCopyQueue(HWQueue):
dev.sdma_queue.ring[0:rem_packet_cnt] = array.array('I', cmds[tail_blit_dword:])
dev.sdma_queue.put_value += rem_packet_cnt * 4
dev.sdma_queue.signal_doorbell()
dev.sdma_queue.signal_doorbell(dev)
class AMDProgram(HCQProgram):
def __init__(self, dev:AMDDevice, name:str, lib:bytes):
@@ -242,26 +242,26 @@ class AMDProgram(HCQProgram):
image, sections, _ = elf_loader(self.lib)
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferSpec(cpu_access=True, nolru=True))
ctypes.memmove(self.lib_gpu.va_addr, mv_address(image), image.nbytes)
entry_point = min(sh.header.sh_addr for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS and sh.header.sh_flags & libc.SHF_ALLOC)
self.group_segment_size = image[entry_point:entry_point+4].cast("I")[0]
self.private_segment_size = image[entry_point+4:entry_point+8].cast("I")[0]
self.kernargs_segment_size = image[entry_point+8:entry_point+12].cast("I")[0]
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
text_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".text"), -1)
assert rodata_entry >= 0 and text_entry >= 0, ".text or .rodata section not found"
self.group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
self.private_segment_size = image[rodata_entry+4:rodata_entry+8].cast("I")[0]
self.kernargs_segment_size = image[rodata_entry+8:rodata_entry+12].cast("I")[0]
lds_size = ((self.group_segment_size + 511) // 512) & 0x1FF
if lds_size > (self.dev.dev_iface.props['lds_size_in_kb'] * 1024) // 512: raise RuntimeError("Too many resources requested: group_segment_size")
# Ensure scratch size
self.dev._ensure_has_local_memory(self.private_segment_size)
code = hsa.amd_kernel_code_t.from_address(self.lib_gpu.va_addr + entry_point) # NOTE: this is wrong, it's not this object
code = hsa.amd_kernel_code_t.from_address(self.lib_gpu.va_addr + rodata_entry) # NOTE: this is wrong, it's not this object
assert code.kernel_code_properties & 0x400 == 0x400 # ENABLE_WAVEFRONT_SIZE32
# Set rsrc1.priv=1 on gfx11 to workaround cwsr.
self.rsrc1: int = code.compute_pgm_rsrc1 | ((1 << 20) if 110000 <= self.dev.target < 120000 else 0)
self.rsrc2: int = code.compute_pgm_rsrc2 | (lds_size << 15)
self.prog_addr: int = self.lib_gpu.va_addr + entry_point + code.kernel_code_entry_byte_offset
self.prog_addr: int = self.lib_gpu.va_addr + rodata_entry + code.kernel_code_entry_byte_offset
if code.kernel_code_entry_byte_offset == 0: self.prog_addr = self.lib_gpu.va_addr + text_entry
# Some programs use hsa_kernel_dispatch_packet_t to read workgroup sizes during execution.
# The packet is represented as a pointer and set up in SGPRs. Space for the packet is allocated as part of the kernel arguments.
self.enable_dispatch_ptr: int = code.kernel_code_properties & hsa.AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_DISPATCH_PTR
@@ -293,11 +293,14 @@ class AMDQueueDesc:
doorbell: memoryview
put_value: int = 0
def signal_doorbell(self):
def signal_doorbell(self, dev):
self.write_ptr[0] = self.put_value
# Ensure all prior writes are visible to the GPU.
if CPUProgram.atomic_lib is not None: CPUProgram.atomic_lib.atomic_thread_fence(__ATOMIC_SEQ_CST:=5)
# Flush hdp if queue is in dev mem.
if dev.driverless and getenv("AMD_ALLOC_QUEUE_DEV_MEM", 1): dev.dev_iface.adev.gmc.flush_hdp()
self.doorbell[0] = self.put_value
class KFDIface:
@@ -450,7 +453,8 @@ class PCIIface:
HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver/unbind", os.O_WRONLY).write(self.pcibus)
supported_sizes = int(HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDONLY).read(), 16)
HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDWR).write(str(supported_sizes.bit_length() - 1))
try: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDWR).write(str(supported_sizes.bit_length() - 1))
except OSError as e: raise RuntimeError(f"Cannot resize BAR: {e}. Ensure the resizable BAR option is enabled on your system.") from e
# Try to init vfio. Use it if success.
if PCIIface.vfio:
@@ -580,7 +584,7 @@ class AMDDevice(HCQCompiled):
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
vgpr_size_per_cu = 0x60000 if self.target in {110000, 110001, 120000, 120001} else 0x40000
wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * (self.max_cu_id + 1), mmap.PAGESIZE)
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE)
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE) if self.target//10000 != 10 else 0x7000
debug_memory_size = round_up((self.max_cu_id + 1) * (self.max_wave_id + 1) * 32, 64)
self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x800000, ctx_save_restore_size=wg_data_size + ctl_stack_size,

View File

@@ -84,7 +84,8 @@ class DiskAllocator(Allocator):
# OSX doesn't seem great at mmap, this is faster
with io.FileIO(self.dev.fd, "a+b", closefd=False) as fo:
fo.seek(src.offset)
fo.readinto(dest)
bytes_read = 0
while (n := fo.readinto(dest[bytes_read:])) is not None and n > 0: bytes_read += n
else:
dest[:] = src._buf()

View File

@@ -175,6 +175,15 @@ class AMMemoryManager:
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)
def _frag_size(self, va, sz, must_cover=True):
"""
Calculate the tlb fragment size for a given virtual address and size.
If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
Fragment 0 is 4KB, 1 is 8KB and so on.
"""
va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")
@@ -185,8 +194,8 @@ class AMMemoryManager:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
for pte_off in range(pte_cnt):
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)
# Invalidate TLB after mappings.
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
@@ -212,12 +221,21 @@ class AMMemoryManager:
if contigous: paddrs = [(self.palloc(size, zero=True), size)]
else:
paddrs = []
try:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
except MemoryError:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for off, _, _, seg_cnt, seg_size in ctx.next(size):
while seg_cnt > 0:
# Try to allocate as long segment (power of 2) as possible
cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, seg_cnt*seg_size) + 12), None
while cont_seg_sz >= seg_size:
try: paddr = self.palloc(cont_seg_sz, zero=True)
except MemoryError: cont_seg_sz //= 2
else: break
if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
else:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise MemoryError(f"Failed to allocate contigous {cont_seg_sz=:#x} bytes (size={size:#x})")
seg_cnt, off = seg_cnt - cont_seg_sz // seg_size, off + cont_seg_sz
return self.map_range(va, size, paddrs, uncached=uncached)
@@ -295,7 +313,7 @@ class AMDev:
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
self.smu.set_clocks(level=-1) # last level, max perf.
self.gfx.set_clockgating_state()
for ip in [self.soc21, self.gfx]: ip.set_clockgating_state()
self.reg("regSCRATCH_REG7").write(am_version)
if DEBUG >= 2: print(f"am {self.devfmt}: boot done")
@@ -382,7 +400,9 @@ class AMDev:
def _build_regs(self):
mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
("HDP", self._ip_module("hdp", am.HDP_HWIP))]
for base, module in mods:
rpref = "mm" if base == "MP1" else "reg" # MP1 regs starts with mm
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))

View File

@@ -7,11 +7,13 @@ class AM_IP:
def __init__(self, adev): self.adev = adev
def init(self): raise NotImplementedError("IP block init must be implemeted")
def fini(self): pass
def set_clockgating_state(self): pass
class AM_SOC21(AM_IP):
def init(self):
self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
def set_clockgating_state(self): self.adev.regHDP_MEM_POWER_CTRL.update(atomic_mem_power_ctrl_en=1, atomic_mem_power_ds_en=1)
class AM_GMC(AM_IP):
def __init__(self, adev):
@@ -34,7 +36,7 @@ class AM_GMC(AM_IP):
def init(self): self.init_hub("MM")
def flush_hdp(self): self.adev.regBIF_BX_PF0_GPU_HDP_FLUSH_REQ.write(0xffffffff)
def flush_hdp(self): self.adev.wreg(self.adev.reg("regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL").read() // 4, 0x0)
def flush_tlb(self, ip:Literal["MM", "GC"], vmid, flush_type=0):
self.flush_hdp()

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -46,11 +46,10 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:
# **** Tensor helper functions ****
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp:
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)
def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
@@ -97,7 +96,7 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
values = values * mask
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
@@ -293,8 +292,8 @@ class Tensor(SimpleMathTrait):
if 0 in self.shape: return memoryview(bytearray(0))
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
buf = cast(UOp, cpu.lazydata).base.realized
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
buf = cpu.lazydata.base.realized
assert buf is not None, f"{cpu.lazydata.base} was not realized"
if self.device != "CPU": buf.options = BufferSpec(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
@@ -310,7 +309,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
def item(self) -> ConstType:
"""
@@ -376,7 +375,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.to(device)
return ret
def to_(self, device:str|tuple[str, ...]|None):
def to_(self, device:str|tuple[str, ...]|None) -> Tensor:
"""
Moves the tensor to the given device in place.
"""
@@ -398,7 +397,7 @@ class Tensor(SimpleMathTrait):
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:int|None=None):
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
Shards the tensor across the given devices in place.
"""
@@ -415,7 +414,7 @@ class Tensor(SimpleMathTrait):
# ***** creation entrypoint *****
@staticmethod
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs) -> Tensor:
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -423,7 +422,7 @@ class Tensor(SimpleMathTrait):
return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)
@staticmethod
def empty(*shape, **kwargs):
def empty(*shape, **kwargs) -> Tensor:
"""
Creates an empty tensor with the given shape.
@@ -468,7 +467,7 @@ class Tensor(SimpleMathTrait):
_device_seeds: dict[str, Tensor] = {}
_device_rng_counters: dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
def manual_seed(seed=0) -> None:
"""
Sets the seed for random operations.
@@ -486,7 +485,7 @@ class Tensor(SimpleMathTrait):
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor:
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
@@ -1069,7 +1068,7 @@ class Tensor(SimpleMathTrait):
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x:Tensor,px,v):
def _constant(x:Tensor,px,v) -> Tensor:
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
@@ -1464,7 +1463,7 @@ class Tensor(SimpleMathTrait):
order[dim0], order[dim1] = order[dim1], order[dim0]
return self.permute(order)
def flatten(self, start_dim=0, end_dim=-1):
def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
@@ -1480,7 +1479,7 @@ class Tensor(SimpleMathTrait):
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
def unflatten(self, dim:int, sizes:tuple[int,...]):
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
@@ -1565,7 +1564,7 @@ class Tensor(SimpleMathTrait):
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the sum of the elements of the tensor along the specified axis or axes.
@@ -1592,7 +1591,7 @@ class Tensor(SimpleMathTrait):
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the product of the elements of the tensor along the specified axis or axes.
@@ -1618,7 +1617,7 @@ class Tensor(SimpleMathTrait):
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)
def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the maximum value of the tensor along the specified axis or axes.
@@ -1641,9 +1640,9 @@ class Tensor(SimpleMathTrait):
"""
return self._reduce(Ops.MAX, axis, keepdim)
def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the minimum value of the tensor along the specified axis or axes.
@@ -1666,7 +1665,7 @@ class Tensor(SimpleMathTrait):
"""
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()
def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if any element evaluates to `True` along the specified axis or axes.
@@ -1688,7 +1687,7 @@ class Tensor(SimpleMathTrait):
"""
return self.bool().max(axis, keepdim)
def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if all element evaluates to `True` along the specified axis or axes.
@@ -1730,7 +1729,7 @@ class Tensor(SimpleMathTrait):
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the mean value of the tensor along the specified axis or axes.
@@ -1754,9 +1753,10 @@ class Tensor(SimpleMathTrait):
"""
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \
.cast(output_dtype)
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the variance of the tensor along the specified axis or axes.
@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the variance and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the standard deviation of the tensor along the specified axis or axes.
@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction).sqrt()
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the standard deviation and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
"""
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
def _softmax(self, axis, dtype:DTypeLike|None=None):
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
def softmax(self, axis=-1, dtype:DTypeLike|None=None):
def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the softmax function to the tensor along the specified axis.
@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
_, e, ss = self._softmax(axis, dtype)
return e.div(ss)
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the log-softmax function to the tensor along the specified axis.
@@ -1892,7 +1892,7 @@ class Tensor(SimpleMathTrait):
m, _, ss = self._softmax(axis, dtype)
return m - ss.log()
def logsumexp(self, axis=None, keepdim=False):
def logsumexp(self, axis=None, keepdim=False) -> Tensor:
"""
Computes the log-sum-exp of the tensor along the specified axis or axes.
@@ -1919,7 +1919,7 @@ class Tensor(SimpleMathTrait):
m = self.max(axis=axis, keepdim=True)
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)
def logcumsumexp(self, axis=0):
def logcumsumexp(self, axis=0) -> Tensor:
"""
Computes the log-cumsum-exp of the tensor along the specified axis or axes.
@@ -1954,7 +1954,7 @@ class Tensor(SimpleMathTrait):
ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
return ret.reshape(*x.shape).transpose(-1, axis)
def argmax(self, axis=None, keepdim=False):
def argmax(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the maximum value of the tensor along the specified axis.
@@ -1981,7 +1981,7 @@ class Tensor(SimpleMathTrait):
idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)
def argmin(self, axis=None, keepdim=False):
def argmin(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the minimum value of the tensor along the specified axis.
@@ -2353,7 +2353,7 @@ class Tensor(SimpleMathTrait):
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))
def cumsum(self, axis:int=0) -> Tensor:
@@ -2565,7 +2565,7 @@ class Tensor(SimpleMathTrait):
# ***** unary ops *****
def logical_not(self):
def logical_not(self) -> Tensor:
"""
Computes the logical NOT of the tensor element-wise.
@@ -2574,7 +2574,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
def neg(self):
def neg(self) -> Tensor:
"""
Negates the tensor element-wise.
@@ -2583,17 +2584,20 @@ class Tensor(SimpleMathTrait):
```
"""
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self):
def contiguous(self) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous)
def contiguous_backward(self):
def contiguous_backward(self) -> Tensor:
"""
Inserts a contiguous operation in the backward pass.
"""
return self._apply_uop(UOp.contiguous_backward)
def log(self):
def log(self) -> Tensor:
"""
Computes the natural logarithm element-wise.
@@ -2604,7 +2608,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.log2()*math.log(2)
def log2(self):
def log2(self) -> Tensor:
"""
Computes the base-2 logarithm element-wise.
@@ -2615,7 +2620,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self):
def exp(self) -> Tensor:
"""
Computes the exponential function element-wise.
@@ -2626,7 +2632,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.mul(1/math.log(2)).exp2()
def exp2(self):
def exp2(self) -> Tensor:
"""
Computes the base-2 exponential function element-wise.
@@ -2637,7 +2644,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self):
def relu(self) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.
@@ -2649,7 +2657,7 @@ class Tensor(SimpleMathTrait):
"""
return (self>0).where(self, 0)
def sigmoid(self):
def sigmoid(self) -> Tensor:
"""
Applies the Sigmoid function element-wise.
@@ -2661,7 +2669,7 @@ class Tensor(SimpleMathTrait):
"""
return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
"""
Applies the Hardsigmoid function element-wise.
NOTE: default `alpha` and `beta` values is taken from torch
@@ -2675,7 +2683,7 @@ class Tensor(SimpleMathTrait):
"""
return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()
def sqrt(self):
def sqrt(self) -> Tensor:
"""
Computes the square root of the tensor element-wise.
@@ -2684,7 +2692,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self):
def rsqrt(self) -> Tensor:
"""
Computes the reciprocal of the square root of the tensor element-wise.
@@ -2693,7 +2702,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.sqrt().reciprocal()
def sin(self):
def sin(self) -> Tensor:
"""
Computes the sine of the tensor element-wise.
@@ -2702,7 +2712,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self):
def cos(self) -> Tensor:
"""
Computes the cosine of the tensor element-wise.
@@ -2711,7 +2722,8 @@ class Tensor(SimpleMathTrait):
```
"""
return ((math.pi/2)-self).sin()
def tan(self):
def tan(self) -> Tensor:
"""
Computes the tangent of the tensor element-wise.
@@ -2721,7 +2733,7 @@ class Tensor(SimpleMathTrait):
"""
return self.sin() / self.cos()
def asin(self):
def asin(self) -> Tensor:
"""
Computes the inverse sine (arcsine) of the tensor element-wise.
@@ -2734,7 +2746,7 @@ class Tensor(SimpleMathTrait):
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
return self.sign() * x
def acos(self):
def acos(self) -> Tensor:
"""
Computes the inverse cosine (arccosine) of the tensor element-wise.
@@ -2744,7 +2756,7 @@ class Tensor(SimpleMathTrait):
"""
return math.pi / 2 - self.asin()
def atan(self):
def atan(self) -> Tensor:
"""
Computes the inverse tangent (arctan) of the tensor element-wise.
@@ -2765,6 +2777,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.int32).cast(self.dtype)
def ceil(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards positive infinity.
@@ -2774,6 +2787,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self > (b := self.trunc())).where(b+1, b)
def floor(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards negative infinity.
@@ -2783,6 +2797,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self < (b := self.trunc())).where(b-1, b)
def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise with rounding half to even.
@@ -2802,6 +2817,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative
def isnan(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is NaN, otherwise returns False
@@ -2811,6 +2827,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self != self
def isfinite(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is finite, otherwise returns False
@@ -2834,7 +2851,7 @@ class Tensor(SimpleMathTrait):
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight
def square(self):
def square(self) -> Tensor:
"""
Squares the tensor element-wise.
Equivalent to `self*self`.
@@ -2844,7 +2861,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self*self
def clamp(self, min_=None, max_=None):
def clamp(self, min_=None, max_=None) -> Tensor:
"""
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
@@ -2856,12 +2874,14 @@ class Tensor(SimpleMathTrait):
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
ret = self.maximum(min_) if min_ is not None else self
return ret.minimum(max_) if max_ is not None else ret
def clip(self, min_=None, max_=None):
def clip(self, min_=None, max_=None) -> Tensor:
"""
Alias for `Tensor.clamp`.
"""
return self.clamp(min_, max_)
def sign(self):
def sign(self) -> Tensor:
"""
Returns the sign of the tensor element-wise.
@@ -2870,7 +2890,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
def abs(self):
def abs(self) -> Tensor:
"""
Computes the absolute value of the tensor element-wise.
@@ -2879,7 +2900,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self * self.sign()
def reciprocal(self):
def reciprocal(self) -> Tensor:
"""
Compute `1/x` element-wise.
@@ -2891,7 +2913,7 @@ class Tensor(SimpleMathTrait):
# ***** activation functions *****
def elu(self, alpha=1.0):
def elu(self, alpha=1.0) -> Tensor:
"""
Applies the Exponential Linear Unit (ELU) function element-wise.
@@ -2904,7 +2926,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - alpha*(1-self.exp()).relu()
def celu(self, alpha=1.0):
def celu(self, alpha=1.0) -> Tensor:
"""
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.
@@ -2917,7 +2939,7 @@ class Tensor(SimpleMathTrait):
"""
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)
def selu(self, alpha=1.67326, gamma=1.0507):
def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
"""
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.
@@ -2930,7 +2952,7 @@ class Tensor(SimpleMathTrait):
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))
def swish(self):
def swish(self) -> Tensor:
"""
See `.silu()`
@@ -2942,7 +2964,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.sigmoid()
def silu(self):
def silu(self) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function element-wise.
@@ -2955,7 +2977,7 @@ class Tensor(SimpleMathTrait):
"""
return self.swish() # The SiLU function is also known as the swish function.
def relu6(self):
def relu6(self) -> Tensor:
"""
Applies the ReLU6 function element-wise.
@@ -2968,7 +2990,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - (self-6).relu()
def hardswish(self):
def hardswish(self) -> Tensor:
"""
Applies the Hardswish function element-wise.
@@ -2981,7 +3003,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self+3).relu6() * (1/6)
def tanh(self):
def tanh(self) -> Tensor:
"""
Applies the Hyperbolic Tangent (tanh) function element-wise.
@@ -2993,7 +3015,7 @@ class Tensor(SimpleMathTrait):
"""
return 2.0 * ((2.0 * self).sigmoid()) - 1.0
def sinh(self):
def sinh(self) -> Tensor:
"""
Applies the Hyperbolic Sine (sinh) function element-wise.
@@ -3005,7 +3027,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() - self.neg().exp()) / 2
def cosh(self):
def cosh(self) -> Tensor:
"""
Applies the Hyperbolic Cosine (cosh) function element-wise.
@@ -3017,7 +3039,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() + self.neg().exp()) / 2
def atanh(self):
def atanh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.
@@ -3029,7 +3051,7 @@ class Tensor(SimpleMathTrait):
"""
return ((1 + self)/(1 - self)).log() / 2
def asinh(self):
def asinh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.
@@ -3041,7 +3063,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() + 1).sqrt()).log()
def acosh(self):
def acosh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.
@@ -3053,7 +3075,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() - 1).sqrt()).log()
def hardtanh(self, min_val=-1, max_val=1):
def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
"""
Applies the Hardtanh function element-wise.
@@ -3065,7 +3087,7 @@ class Tensor(SimpleMathTrait):
"""
return self.clip(min_val, max_val)
def erf(self):
def erf(self) -> Tensor:
"""
Applies error function element-wise.
@@ -3079,7 +3101,7 @@ class Tensor(SimpleMathTrait):
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())
def gelu(self):
def gelu(self) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function element-wise.
@@ -3092,7 +3114,7 @@ class Tensor(SimpleMathTrait):
"""
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())
def quick_gelu(self):
def quick_gelu(self) -> Tensor:
"""
Applies the Sigmoid GELU approximation element-wise.
@@ -3104,7 +3126,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self * 1.702).sigmoid()
def leaky_relu(self, neg_slope=0.01):
def leaky_relu(self, neg_slope=0.01) -> Tensor:
"""
Applies the Leaky ReLU function element-wise.
@@ -3119,7 +3141,7 @@ class Tensor(SimpleMathTrait):
"""
return (self<0).where(neg_slope*self, self)
def mish(self):
def mish(self) -> Tensor:
"""
Applies the Mish function element-wise.
@@ -3132,7 +3154,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.softplus().tanh()
def softplus(self, beta=1):
def softplus(self, beta=1) -> Tensor:
"""
Applies the Softplus function element-wise.
@@ -3144,7 +3166,7 @@ class Tensor(SimpleMathTrait):
"""
return (1/beta) * (1 + (self*beta).exp()).log()
def softsign(self):
def softsign(self) -> Tensor:
"""
Applies the Softsign function element-wise.
@@ -3360,7 +3382,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
def lshift(self, x:int):
def lshift(self, x:int) -> Tensor:
"""
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self << x`.
@@ -3372,7 +3394,7 @@ class Tensor(SimpleMathTrait):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.mul(2 ** x)
def rshift(self, x:int):
def rshift(self, x:int) -> Tensor:
"""
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self >> x`.
@@ -3434,7 +3456,7 @@ class Tensor(SimpleMathTrait):
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
`output_i = x_i if self_i else y_i`.
@@ -3458,7 +3480,7 @@ class Tensor(SimpleMathTrait):
cond, y = cond._broadcasted(y, match_dtype=False)
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self)
def copysign(self, other) -> Tensor:
"""
@@ -3503,7 +3525,7 @@ class Tensor(SimpleMathTrait):
# ***** functional nn ops *****
def linear(self, weight:Tensor, bias:Tensor|None=None):
def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor:
"""
Applies a linear transformation to `self` using `weight` and `bias`.
@@ -3519,7 +3541,7 @@ class Tensor(SimpleMathTrait):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x
def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor:
"""
Applies a sequence of functions to `self` chaining the output of each function to the input of the next.
@@ -3592,7 +3614,7 @@ class Tensor(SimpleMathTrait):
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3821,7 +3843,7 @@ class Tensor(SimpleMathTrait):
# ***** cast ops *****
def llvm_bf16_cast(self, dtype:DTypeLike):
def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").cast(dtype)
@@ -4011,8 +4033,10 @@ class Tensor(SimpleMathTrait):
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
P = ParamSpec("P")
T = TypeVar("T")
def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if _METADATA.get() is not None: return fn(*args, **kwargs)
if TRACEMETA >= 2: