Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)

Commit: Merge branch 'master' into retinanet_mlperf

.github/workflows/benchmark.yml: 8 changes (vendored)
@@ -455,8 +455,8 @@ jobs:
    steps:
    - name: Checkout Code
      uses: actions/checkout@v4
    - name: Insert amdgpu
      run: sudo modprobe amdgpu
    - name: Remove amdgpu
      run: sudo rmmod amdgpu || true
    - name: Symlink models and datasets
      run: |
        mkdir -p weights
@@ -474,10 +474,6 @@ jobs:
        rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
    - name: reset process replay
      run: test/external/process_replay/reset.py
    - name: setup perflevel
      run: |
        examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
        rocm-smi
    - name: Train MNIST
      run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
    - name: Run 10 CIFAR training steps

.github/workflows/test.yml: 13 changes (vendored)
@@ -323,8 +323,8 @@ jobs:
      run: awk '/```python/{flag=1;next}/```/{flag=0}flag' README.md > README.py && PYTHONPATH=. python README.py
    - name: Run unit tests
      run: PYTHONPATH="." python -m pytest -n=auto test/unit/
    - name: Repo line count < 11300 lines
      run: MAX_LINE_COUNT=11300 python sz.py
    - name: Repo line count < 11500 lines
      run: MAX_LINE_COUNT=11500 python sz.py

  fuzzing:
    name: Fuzzing
@@ -347,7 +347,7 @@ jobs:

  testgpuimage:
    name: 'GPU IMAGE Tests'
    runs-on: ubuntu-20.04
    runs-on: ubuntu-22.04
    timeout-minutes: 10
    steps:
    - name: Checkout Code
@@ -371,7 +371,7 @@ jobs:

  testopenpilot:
    name: 'openpilot Compile Tests'
    runs-on: ubuntu-20.04
    runs-on: ubuntu-22.04
    timeout-minutes: 10
    steps:
    - name: Checkout Code
@@ -644,7 +644,7 @@ jobs:
        backend: [metal, llvm, cpu]
    name: MacOS (${{ matrix.backend }})
    runs-on: macos-15
    timeout-minutes: 10
    timeout-minutes: 20
    steps:
    - name: Checkout Code
      uses: actions/checkout@v4
@@ -664,6 +664,9 @@ jobs:
      run: python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
    - name: Run process replay tests
      uses: ./.github/actions/process-replay
    - name: Run macOS-specific unit test
      if: matrix.backend == 'cpu'
      run: python3 -m pytest test/unit/test_disk_tensor.py::TestDiskTensor::test_copy_to_cpu_not_truncated

    # ****** Windows Tests ******

.gitignore: 1 change (vendored)
@@ -10,6 +10,7 @@ notebooks
*.so
*.txt
build
!examples/tinychat/assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/
/dist
*.egg-info
/env

autogen_stubs.sh

@@ -171,6 +171,7 @@ generate_amd() {
    extra/hip_gpu_driver/sdma_v6_0_0_pkt_open.h \
    extra/hip_gpu_driver/gc_11_0_0_offset.h \
    extra/hip_gpu_driver/gc_10_3_0_offset.h \
    extra/hip_gpu_driver/sienna_cichlid_ip_offset.h \
    --clang-args="-I/opt/rocm/include -x c++" \
    -o $BASE/amd_gpu.py

@@ -353,6 +354,12 @@ generate_am() {
    extra/amdpci/headers/amdgpu_smu.h \
    -o $BASE/am/smu_v13_0_0.py
  fixup $BASE/am/smu_v13_0_0.py

  clang2py -k cdefstum \
    extra/amdpci/headers/hdp_6_0_0_offset.h \
    extra/amdpci/headers/hdp_6_0_0_sh_mask.h \
    -o $BASE/am/hdp_6_0_0.py
  fixup $BASE/am/hdp_6_0_0.py
}

generate_webgpu() {

examples/mlperf/model_train.py

@@ -851,7 +851,9 @@ def train_step_bert(model, optimizer, scheduler, loss_scaler:float, input_ids:Te

  optimizer.step()
  scheduler.step()
  return loss.realize(), global_norm.realize()
  # TODO: no to("CPU") here because it blocks and messes the python time
  Tensor.realize(loss, global_norm, optimizer.optimizers[0].lr)
  return loss, global_norm, optimizer.optimizers[0].lr

@TinyJit
def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:Tensor, masked_positions:Tensor, masked_lm_ids:Tensor,
@@ -862,7 +864,10 @@ def eval_step_bert(model, input_ids:Tensor, segment_ids:Tensor, attention_mask:T
  lm_logits, seq_relationship_logits = model(input_ids, attention_mask, masked_positions, segment_ids)
  masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss = \
    model.accuracy(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
  return masked_lm_accuracy.realize(), seq_relationship_accuracy.realize(), masked_lm_loss.realize(), next_sentence_loss.realize()
  for t in [masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss]:
    t.to_("CPU")
  Tensor.realize(masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss)
  return masked_lm_accuracy, seq_relationship_accuracy, masked_lm_loss, next_sentence_loss

def train_bert():
  # NOTE: pip install tensorflow, wandb required
@@ -1031,47 +1036,49 @@ def train_bert():
  MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})

  while train_data is not None and i < train_steps and not achieved:
    Tensor.training = True
    BEAM.value = TRAIN_BEAM
    st = time.perf_counter()
    GlobalCounters.reset()
    loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
    if getenv("TRAIN", 1):
      Tensor.training = True
      BEAM.value = TRAIN_BEAM
      st = time.perf_counter()
      GlobalCounters.reset()
      loss, global_norm, lr = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
        train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
        train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)

    pt = time.perf_counter()
      pt = time.perf_counter()

    try:
      next_data = next(train_it)
    except StopIteration:
      next_data = None
      try:
        next_data = next(train_it)
      except StopIteration:
        next_data = None

    dt = time.perf_counter()
      dt = time.perf_counter()

    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
    loss = loss.item()
      device_str = parameters[0].device if isinstance(parameters[0].device, str) else f"{parameters[0].device[0]} * {len(parameters[0].device)}"
      loss = loss.item()
      lr = lr.item()

    cl = time.perf_counter()
    if BENCHMARK: step_times.append(cl - st)
      cl = time.perf_counter()
      if BENCHMARK: step_times.append(cl - st)

    tqdm.write(
      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
    if WANDB:
      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
        "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
        "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
      tqdm.write(
        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {lr:.6f} LR, "
        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
      if WANDB:
        wandb.log({"lr": lr, "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
          "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
          "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})

    train_data, next_data = next_data, None
    i += 1
      train_data, next_data = next_data, None
      i += 1

    if i == BENCHMARK:
      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
      estimated_total_minutes = int(median_step_time * train_steps / 60)
      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
        f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
      if i == BENCHMARK:
        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
        estimated_total_minutes = int(median_step_time * train_steps / 60)
        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
        print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
          f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")

    # ** eval loop **
    if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
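The train and eval step changes above share one idea: keep device syncs out of the timed Python section by realizing all results in a single batched `Tensor.realize(...)` call, and only calling `.item()` later (for eval, after moving results to CPU with `to_`). A minimal sketch of the pattern, with toy tensors standing in for the training state:

```python
from tinygrad import Tensor

# toy stand-ins for loss/global_norm; the real code returns these from a TinyJit step
loss, global_norm = Tensor([3.0]).sum(), Tensor([1.5]).sum()

# one batched realize schedules both results together (a single sync point)...
Tensor.realize(loss, global_norm)

# ...so later .item() calls just read already-computed buffers instead of
# each triggering its own blocking device round-trip
print(loss.item(), global_norm.item())
```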

examples/tinychat/assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css: 11 lines (new file, vendored)

@@ -0,0 +1,11 @@
/*!
Pure v3.0.0
Copyright 2013 Yahoo!
Licensed under the BSD License.
https://github.com/pure-css/pure/blob/master/LICENSE
*/
/*!
normalize.css v8.0.1 | MIT License | https://necolas.github.io/normalize.css/
Copyright (c) Nicolas Gallagher and Jonathan Neal
*/
/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}

examples/tinychat/tinychat-browser/.gitignore: 5 lines (new file, vendored)

@@ -0,0 +1,5 @@
net_*
llama3-2.tiktoken
tiktoken.js
tiktoken_bg.wasm
transformer*

examples/tinychat/tinychat-browser/README.md: 8 lines (new file)

@@ -0,0 +1,8 @@
# How to build and run tinychat in browser (WebGPU and WASM)
- `PYTHONPATH=. python examples/tinychat/tinychat-browser/compile.py`
- `./examples/tinychat/tinychat-browser/compile_wasm.sh`
  - Prerequisite: [install emscripten](https://emscripten.org/docs/getting_started/downloads.html). This script looks for `~/emsdk/emsdk_env.sh`; adjust this based on your installation.
- `./examples/tinychat/tinychat-browser/make_tiktoken_js.sh`
  - Prerequisite: install `npm`, `webpack`.
- `cd examples/tinychat && python -m http.server 7776`
- In browser: open either `localhost:7776/tinychat-browser` (WebGPU) or `localhost:7776/tinychat-browser/?backend=wasm` (WASM)
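Taken together, the steps above can be run as one sequence from the repo root (assuming emsdk and npm/webpack are installed as noted); the build artifacts (`net_part*.chunk`, `net_metadata.json`, `tiktoken.js`, `tiktoken_bg.wasm`, `transformer*`) land in `tinychat-browser/` and match the entries in its `.gitignore`:

```sh
PYTHONPATH=. python examples/tinychat/tinychat-browser/compile.py
./examples/tinychat/tinychat-browser/compile_wasm.sh
./examples/tinychat/tinychat-browser/make_tiktoken_js.sh
(cd examples/tinychat && python -m http.server 7776)
```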

examples/tinychat/tinychat-browser/compile.py: 149 lines (new file)

@@ -0,0 +1,149 @@
import os, json, hashlib, math
from extra.export_model import export_model
from examples.llama3 import build_transformer, Tokenizer
from tinygrad.nn.state import get_state_dict, load_state_dict
from tinygrad import Device, Variable, Tensor, dtypes, TinyJit
from tinygrad.helpers import fetch, Context
from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe

def prepare_browser_chunks(model):
  # split weights into browser-friendly chunks
  state_dict = get_state_dict(model)
  del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
  chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
  metadata = {}
  # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
  t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
  empty_t_infos = [(v.lazydata.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]

  split_t_infos = []
  for size, name, dtype in t_infos:
    if size <= chunk_size:
      split_t_infos.append((size, name, dtype, ()))
    else: # split large weights into multiple parts
      for i in range(0, size, chunk_size):
        split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))

  files = []
  # pack weights into files with FFD bin packing
  split_t_infos = sorted(split_t_infos, reverse=True)
  for info in split_t_infos:
    placed = False
    for file in files:
      if sum(i[0] for i in file) + info[0] <= chunk_size:
        if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
        file.append(info)
        placed = True
        break
    if not placed:
      files.append([info])

  tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
  for i, file in enumerate(files):
    cursor = 0
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
      for size, name, dtype, offsets in file:
        name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1]))
        default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]}
        weight_metadata = metadata.get(name, default)
        weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
        metadata[name] = weight_metadata
        data = bytes(state_dict[name].lazydata.base.realized.as_buffer())
        data = data if not offsets else data[offsets[0]:offsets[1]]
        writer.write(data)
        cursor += size

  metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos})

  for k in metadata:
    metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])]
    cursor = 0
    for i, part in enumerate(metadata[k]["parts"]):
      metadata[k]["parts"][i]["target_start_pos"] = cursor
      cursor += part["size"]
    metadata[k]["size"] = cursor

  # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues
  state_dict_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"state_dict": metadata, "state_dict_hash": state_dict_hash, "files": []}
  hashes = set()
  for i in range(len(files)):
    with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "rb") as reader:
      hash = hashlib.sha256(reader.read()).hexdigest()
      hashes.add(hash)
      metadata["files"].append({"name": f'net_part{i}.chunk', "hash": hash})
  if len(hashes) != len(files): print(f"WARNING: {len(files)} files were exported, but only {len(hashes)} are unique: something may have gone wrong")
  metadata_hash = hashlib.sha256(json.dumps(metadata, sort_keys=True).encode("utf-8")).hexdigest()
  metadata = {"metadata": metadata, "metadata_hash": metadata_hash}

  with open(os.path.join(os.path.dirname(__file__), f'./net_metadata.json'), "w") as writer: json.dump(metadata, writer, indent=4)
  return metadata

def validate_model(model, tokenizer):
  prompt = "yo"
  toks = [tokenizer.bos_id]
  toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("user") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
  toks += tokenizer.encode(prompt) + [tokenizer.special_tokens["<|eot_id|>"]]
  toks += [tokenizer.special_tokens["<|start_header_id|>"]] + tokenizer.encode("assistant") + [tokenizer.special_tokens["<|end_header_id|>"]] + tokenizer.encode("\n\n")
  start_pos = 0
  run = TinyJit(model.forward)
  for tok in toks[:-1]:
    run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).realize()
    start_pos += 1
  tok = toks[-1]
  result = ""
  expected = "How's it going?"
  while True:
    tok = run(Tensor([[tok]]), Variable("start_pos", 0, model.max_context).bind(start_pos), 0.0, 0, 0.0, 0.0, 0.0).item()
    start_pos += 1
    if tok in tokenizer.stop_tokens or len(result) > len(expected): break
    result += tokenizer.decode([tok])
  assert result == expected, f"Model validation failed, expected output: {expected}, actual output: {result}"

if __name__=="__main__":
  # Export BPE data for use with tiktoken.js
  tokenizer_path = fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
  mergeable_ranks = load_tiktoken_bpe(str(tokenizer_path))
  bpe_path = os.path.join(os.path.dirname(__file__), "llama3-2.tiktoken")
  dump_tiktoken_bpe(mergeable_ranks, bpe_path)
  tokenizer = Tokenizer(str(tokenizer_path))

  model_path = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-f16.gguf", "Llama-3.2-1B-Instruct-f16.gguf", subdir="llama3-1b-instruct")
  Tensor.no_grad = True
  max_context=1024
  tok = 128000
  TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P = 0.95, 0, 0.0, 0.0, 0.0
  start_pos = Variable("start_pos", 0, max_context).bind(0)
  model_input = lambda: [Tensor([[tok]]), start_pos, TEMPERATURE, TOP_K, TOP_P, ALPHA_F, ALPHA_P]

  Device.DEFAULT="CPU"
  model = build_transformer(model_path, model_size="1B", quantize="int8", scale_dtype=dtypes.float32, device=Device.DEFAULT, max_context=max_context)
  state_dict = get_state_dict(model)
  validate_model(model, tokenizer)
  model_name = "transformer"

  with Context(BEAM=3):
    cprog, js_wrapper = export_model(model, "wasm", *model_input(), model_name=model_name)
  # ensure consistency with exported weights
  js_wrapper = js_wrapper.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")

  with open(os.path.join(os.path.dirname(__file__), f"{model_name}.c"), "w") as f: f.write(cprog)
  with open(os.path.join(os.path.dirname(__file__), "net_clang.js"), "w") as f: f.write(js_wrapper)

  Device.DEFAULT="WEBGPU"
  # float16 is not yet supported for dawn/Vulkan/NVIDIA stack, see: https://issues.chromium.org/issues/42251215
  # therefore for now, we used CLANG to quantize the float16 llama to int8 with float32 scales, then load to WEBGPU
  model = build_transformer(model_path, model_size="1B", quantize="int8", max_context=max_context, load_weights=False)
  load_state_dict(model, state_dict)
  # these were the same before load_state_dict
  model.output.weight, model.output.scale = model.tok_embeddings.weight, model.tok_embeddings.scale

  validate_model(model, tokenizer)
  metadata = prepare_browser_chunks(model) # export weights to disk

  with Context(BEAM=3):
    prg, input_sizes, output_sizes, state = export_model(model, "webgpu", *model_input(), model_name=model_name, stream_weights=True)
  # ensure consistency with exported weights
  prg = prg.replace("output.weight", "tok_embeddings.weight").replace("output.scale", "tok_embeddings.scale")

  with open(os.path.join(os.path.dirname(__file__), "net.js"), "w") as f: f.write(prg)
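The packing loop in `prepare_browser_chunks` is first-fit decreasing: sort pieces largest-first, then place each into the first file with room (plus the extra rule that two split-tensor parts never share a file). A standalone sketch of just the FFD part, with plain sizes instead of tensor metadata:

```python
# first-fit-decreasing bin packing, reduced to sizes (illustrative, not the export code)
def ffd(sizes: list[int], cap: int) -> list[list[int]]:
  bins: list[list[int]] = []
  for s in sorted(sizes, reverse=True):
    for b in bins:
      if sum(b) + s <= cap:
        b.append(s)
        break
    else:  # no existing bin fits: open a new one
      bins.append([s])
  return bins

print(ffd([10, 9, 4, 3, 2, 2], cap=12))  # [[10, 2], [9, 3], [4, 2]]
```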

examples/tinychat/tinychat-browser/compile_wasm.sh: 23 lines (new executable file)

@@ -0,0 +1,23 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"

# prereq: install emscripten: https://emscripten.org/docs/getting_started/downloads.html
EMSCRIPTEN_PATH=~/emsdk/emsdk_env.sh
source $EMSCRIPTEN_PATH
step="transformer"
initial_memory=6553600
max_memory=1500053504
exported_functions='["_net", "_malloc", "_free", "_set_buf"]'

emcc "${step}.c" \
  -O3 -msimd128 -ffast-math -flto \
  -o "${step}.js" \
  -s MODULARIZE=1 \
  -s EXPORT_ES6=1 \
  -s EXPORTED_FUNCTIONS="${exported_functions}" \
  -s ENVIRONMENT='worker' \
  -s FILESYSTEM=0 \
  -s EVAL_CTORS \
  -s ALLOW_MEMORY_GROWTH=1 \
  -s INITIAL_MEMORY="$initial_memory" \
  -s MAXIMUM_MEMORY="$max_memory"
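Since the script cd's to its own directory first, it can be invoked from anywhere; with `MODULARIZE=1` and `EXPORT_ES6=1`, emcc emits `transformer.js` as an ES6 module factory alongside `transformer.wasm`, so the browser worker can import it:

```sh
./examples/tinychat/tinychat-browser/compile_wasm.sh
ls examples/tinychat/tinychat-browser/transformer.*
# transformer.c  transformer.js  transformer.wasm
```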

examples/tinychat/tinychat-browser/index.css: 322 lines (new file)

@@ -0,0 +1,322 @@
/* define colors */
:root {
  --primary-color: #fff;
  --secondary-color: #2a2a2a;
  --secondary-color-transparent: #ffffff66;
  --primary-bg-color: #1a1a1a;
  --foreground-color: #f0f0f0;
}

main {
  width: 100%;
  height: 100%;

  display: flex;
  flex-direction: column;

  place-items: center;
}

.home {
  width: 100%;
  height: 90%;

  margin-bottom: 10rem;
}

.title {
  font-size: 3rem;
  margin: 1rem 0;
  margin-top: 3rem;
}

.histories-container-container {
  width: 100%;
  max-height: 75%;

  position: relative;
}

.histories-container {
  overflow-y: auto;
  overflow-x: hidden;
  width: 100%;
  height: 100%;

  display: flex;
  flex-direction: column;
  gap: 1rem;
  align-items: center;

  margin: 0;
  padding: 3rem 1rem;
}

.histories-start {
  height: 3rem;
  width: 100%;

  z-index: 999;
  top: 0;
  position: absolute;

  background: linear-gradient(
    180deg,
    var(--primary-bg-color) 0%,
    transparent 100%
  );
}
.histories-end {
  height: 3rem;
  width: 100%;

  z-index: 999;
  bottom: 0;
  position: absolute;

  background: linear-gradient(
    0deg,
    var(--primary-bg-color) 0%,
    transparent 100%
  );
}

.history {
  padding: 1rem;
  width: 100%;
  max-width: 40rem;

  background-color: var(--secondary-color);
  border-radius: 10px;
  border-left: 2px solid var(--primary-color);

  cursor: pointer;

  transform: translateX(calc(1px * var(--tx, 0)));
  opacity: var(--opacity, 1);
}
.history:hover {
  background-color: var(--secondary-color);
}

.history-delete-button {
  position: absolute;
  top: 0;
  right: 0;
  padding: 0.5rem;
  margin: 0;
  outline: none;
  border: none;
  background-color: var(--secondary-color);
  color: var(--foreground-color);
  border-radius: 0 0 0 10px;
  cursor: pointer;
  transition: 0.2s;
}
.history-delete-button:hover {
  background-color: var(--secondary-color);
  padding: 0.75rem;
}

.messages {
  overflow-y: auto;
  height: 100%;
  width: 100%;
  max-width: 1200px;

  display: flex;
  flex-direction: column;
  gap: 1rem;
  align-items: center;
  padding-top: 1rem;
  padding-bottom: 11rem;
}

.message {
  max-width: 75%;
  padding: 0.5rem 1rem;
  border-radius: 20px;
}
.message-role-assistant {
  background-color: var(--secondary-color);
  margin-right: auto;
  color: #fff;
}
.message-role-user {
  margin-left: auto;
  background-color: var(--primary-color);
  color: #000;
}

.message > pre {
  white-space: pre-wrap;
}

.hljs {
  width: 100%;
  position: relative;
  border-radius: 10px;
  /* wrap code blocks */
  white-space: pre-wrap;
}
/* put clipboard button in the top right corner of the code block */
.clipboard-button {
  position: absolute;
  top: 0;
  right: 0;
  padding: 0.5rem;
  margin: 0;
  outline: none;
  border: none;
  background-color: var(--secondary-color);
  color: var(--foreground-color);
  border-radius: 0 0 0 10px;
  cursor: pointer;
  transition: 0.2s;
}
.clipboard-button:hover {
  background-color: var(--secondary-color);
  padding: 0.75rem;
}

.input-container {
  position: absolute;
  bottom: 0;

  /* linear gradient from background-color to transparent on the top */
  background: linear-gradient(
    0deg,
    var(--primary-bg-color) 55%,
    transparent 100%
  );

  width: 100%;
  max-width: 1200px;
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
  z-index: 999;
}

.input-performance {
  margin-top: 4rem;

  display: flex;
  flex-direction: row;
  gap: 1rem;
}

.input-performance-point {
  display: flex;
  flex-direction: row;
  place-items: center;
  gap: 0.5rem;
}
.input-performance-point > p {
  height: 1rem;
  line-height: normal;
}

.input {
  width: 90%;
  min-height: 3rem;
  flex-shrink: 0;

  display: flex;
  flex-direction: row;
  justify-content: center;
  gap: 0.5rem;

  align-items: flex-end;
  margin-bottom: 2rem;
}

.input-form {
  width: 100%;
  padding: 1rem;
  min-height: 3rem;
  max-height: 8rem;

  background-color: var(--secondary-color);
  color: var(--foreground-color);
  border-radius: 10px;
  border: none;
  resize: none;
  outline: none;
}
.mobile .input-form { /* prevent auto-zoom on touching prompt box */
  font-size: 16px;
}

.input-button {
  height: 3rem;
  width: 4rem;

  background-color: var(--primary-color);
  color: var(--secondary-color);
  border-radius: 10px;
  padding: 0.5rem;
  cursor: pointer;
}
.input-button:hover {
  background-color: var(--secondary-color-transparent);
}
.input-button:disabled {
  background-color: var(--secondary-color);
  cursor: not-allowed;
}

/* wrap text */
p {
  white-space: pre-wrap;
}

/* fonts */
.megrim-regular {
  font-family: monospace;
  font-weight: 400;
  font-style: normal;
}

.monospace {
  font-family: monospace;
}

.loading-bar {
  display: flex;
  flex-direction: row;
  align-items: center;
  gap: 0.5rem;
  width: 100%;
  min-height: 3rem;
  margin-bottom: 2rem;
}

.loading-text {
  color: var(--foreground-color);
  font-size: 1rem;
  white-space: nowrap;
}

#progress-percentage {
  color: var(--foreground-color);
  font-size: 1rem;
  white-space: nowrap;
}

.progress-bar {
  flex-grow: 1;
  height: 0.5rem;
  background-color: var(--secondary-color);
  border-radius: 5px;
  overflow: hidden;
  position: relative;
}

.progress {
  width: 0%;
  height: 100%;
  background-color: var(--primary-color);
  transition: width 0.2s ease-in-out;
}

examples/tinychat/tinychat-browser/index.html: 182 lines (new file)

@@ -0,0 +1,182 @@
<!DOCTYPE html>

<head>
  <title>tinychat</title>
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <link rel="icon" href="../favicon.svg" type="image/svg+xml">

  <script defer src="../assets/cdn.jsdelivr.net/npm/@alpine-collective/toolkit@1.0.2/dist/cdn.min.js"></script>
  <script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/intersect@3.x.x/dist/cdn.min.js"></script>
  <script defer src="../assets/cdn.jsdelivr.net/npm/@alpinejs/focus@3.x.x/dist/cdn.min.js"></script>
  <script defer src="../assets/unpkg.com/@marcreichel/alpine-autosize@1.3.x/dist/alpine-autosize.min.js"></script>
  <script defer src="../assets/unpkg.com/alpinejs@3.x.x/dist/cdn.min.js"></script>

  <script src="../assets/unpkg.com/dompurify@3.1.5/dist/purify.min.js"></script>
  <script src="../assets/unpkg.com/marked@13.0.0/marked.min.js"></script>
  <script src="../assets/unpkg.com/marked-highlight@2.1.2/lib/index.umd.js"></script>
  <script src="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/highlight.min.js"></script>

  <script src="index.js"></script>

  <link rel="stylesheet" href="../assets/cdn.jsdelivr.net/npm/purecss@3.0.0/build/base-min.css">
  <link rel="stylesheet" href="../assets/cdnjs.cloudflare.com/ajax/libs/font-awesome/6.5.2/css/all.min.css"
    integrity="sha512-SnH5WK+bZxgPHs44uWIX+LLJAJ9/2PkPKZ5QiAj6Ta86w+fsb2TkcmfRyVX3pBnMFcV7oQPJkl9QevSCWr3W6A=="
    crossorigin="anonymous" referrerpolicy="no-referrer" />
  <link rel="stylesheet" href="../assets/unpkg.com/@highlightjs/cdn-assets@11.9.0/styles/vs2015.min.css">

  <link rel="stylesheet" href="index.css">
  <link rel="stylesheet" href="../common.css">
</head>

<body>
  <main x-data="state" x-init="console.log(endpoint)">
    <div class="home centered" x-show="home === 0" x-transition x-effect="
      $refs.inputForm.focus();
      if (home === 1) setTimeout(() => home = 2, 100);
      if (home === -1) setTimeout(() => home = 0, 100);
    " @popstate.window="
      if (home === 2) {
        cancelGeneration = true;
        if (maxContextReached) generating = false;
        if (!generating) cstate = { time: null, messages: [] };
        home = -1;
        time_till_first = 0;
        tokens_per_second = 0;
        total_tokens = 0;
      }
    ">
      <h1 class="title megrim-regular">tinychat</h1>
      <div class="histories-container-container">
        <template x-if="histories.length">
          <div class="histories-start"></div>
        </template>
        <div class="histories-container" x-intersect="
          $el.scrollTo({ top: 0, behavior: 'smooth' });
        ">
          <template x-for="_state in histories.toSorted((a, b) => b.time - a.time)">
            <div x-data="{ otx: 0, trigger: 75 }" class="history" @click="
              cstate = _state;
              updateTotalTokens(cstate.messages);
              home = 1;
              // ensure that going back in history will go back to home
              window.history.pushState({}, '', window.TINYCHAT_ROOT || '/');
            " @touchstart="
              otx = $event.changedTouches[0].clientX;
            " @touchmove="
              $el.style.setProperty('--tx', $event.changedTouches[0].clientX - otx);
              $el.style.setProperty('--opacity', 1 - (Math.abs($event.changedTouches[0].clientX - otx) / trigger));
            " @touchend="
              if (Math.abs($event.changedTouches[0].clientX - otx) > trigger) removeHistory(_state);
              $el.style.setProperty('--tx', 0);
              $el.style.setProperty('--opacity', 1);
            ">
              <h3 x-text="new Date(_state.time).toLocaleString()"></h3>
              <p x-text="$truncate(_state.messages[0].content, 80)"></p>
              <!-- delete button -->
              <button class="history-delete-button" @click.stop="removeHistory(_state);">
                <i class=" fas fa-trash"></i>
              </button>
            </div>
          </template>
        </div>
        <template x-if="histories.length">
          <div class="histories-end"></div>
        </template>
      </div>
    </div>
    <div x-ref="messages" class="messages" x-init="
      $watch('cstate', value => {
        $el.innerHTML = '';
        value.messages.forEach(({ role, content }) => {
          const div = document.createElement('div');
          div.className = `message message-role-${role}`;
          try {
            div.innerHTML = DOMPurify.sanitize(marked.parse(content));
          } catch (e) {
            console.log(content);
            console.error(e);
          }

          // add a clipboard button to all code blocks
          const codeBlocks = div.querySelectorAll('.hljs');
          codeBlocks.forEach(codeBlock => {
            const button = document.createElement('button');
            button.className = 'clipboard-button';
            button.innerHTML = '<i class=\'fas fa-clipboard\'></i>';
            button.onclick = () => {
              // navigator.clipboard.writeText(codeBlock.textContent);
              const range = document.createRange();
              range.setStartBefore(codeBlock);
              range.setEndAfter(codeBlock);
              window.getSelection()?.removeAllRanges();
              window.getSelection()?.addRange(range);
              document.execCommand('copy');
              window.getSelection()?.removeAllRanges();

              button.innerHTML = '<i class=\'fas fa-check\'></i>';
              setTimeout(() => button.innerHTML = '<i class=\'fas fa-clipboard\'></i>', 1000);
            };
            codeBlock.appendChild(button);
          });

          $el.appendChild(div);
        });

        $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
      });
    " x-intersect="
      $el.scrollTo({ top: $el.scrollHeight, behavior: 'smooth' });
    " x-show="home === 2" x-transition>
    </div>
    <div class="input-container">
      <div class="input-performance">
        <span class="input-performance-point">
          <p class="monospace" x-text="(time_till_first / 1000).toFixed(2)"></p>
          <p class="megrim-regular">SEC TO FIRST TOKEN</p>
        </span>
        <span class="input-performance-point">
          <p class="monospace" x-text="tokens_per_second.toFixed(1)"></p>
          <p class="megrim-regular">TOKENS/SEC</p>
        </span>
        <span class="input-performance-point">
          <p class="monospace" x-text="total_tokens"></p>
          <p class="megrim-regular">TOKENS</p>
        </span>
      </div>
      <div class="loading-bar" x-show="loadingMessage !== ''">
        <p class="loading-text" id="loading-message">Loading:</p>
        <span id="progress-percentage">0%</span>
        <div class="progress-bar">
          <div class="progress"></div>
        </div>
      </div>
      <div class="input" x-show="loadingMessage === ''">
        <textarea x-ref="inputForm" id="input-form" class="input-form" autofocus rows=1 x-autosize
          :placeholder="generating ? placeholderText : 'Say something'" :disabled="generating" @input="
          home = (home === 0) ? 1 : home
          if (cstate.messages.length === 0 && $el.value === '') home = -1;

          if ($el.value !== '') {
            const messages = [...cstate.messages];
            messages.push({ role: 'user', content: $el.value });
            updateTotalTokens(messages);
          } else {
            if (cstate.messages.length === 0) total_tokens = 0;
            else updateTotalTokens(cstate.messages);
          }
        " x-effect="
          console.log(generating);
          if (!generating) $nextTick(() => {
            $el.focus();
            setTimeout(() => $refs.messages.scrollTo({ top: $refs.messages.scrollHeight, behavior: 'smooth' }), 100);
          });
        " @keydown.enter="await handleEnter($event)" @keydown.escape.window="$focus.focus($el)"></textarea>
        <button class="input-button" :disabled="generating" @click="await handleSend()">
          <i class="fas" :class="generating ? 'fa-spinner fa-spin' : 'fa-paper-plane'"></i>
        </button>
      </div>
    </div>
  </main>
</body>

</html>
927
examples/tinychat/tinychat-browser/index.js
Normal file
927
examples/tinychat/tinychat-browser/index.js
Normal file
@@ -0,0 +1,927 @@
|
||||
window.TINYCHAT_ROOT = "/tinychat-browser/";
|
||||
const queryParams = new URLSearchParams(window.location.search);
|
||||
const normalizedParams = Object.fromEntries([...queryParams].map(([key, value]) => [key.toUpperCase(), value.toUpperCase()]));
|
||||
window.BACKEND = (normalizedParams["BACKEND"] === "WASM") ? "WASM" : "WebGPU";
|
||||
const isMobileAgent = /Mobi|Android|iPhone|iPad|iPod/i.test(navigator.userAgent);
|
||||
const hasTouchScreen = 'ontouchstart' in window || navigator.maxTouchPoints > 0;
|
||||
window.isMobile = isMobileAgent || hasTouchScreen;
|
||||
if (window.isMobile) document.documentElement.classList.add('mobile'); // prevent annoying auto-zoom when entering prompt on mobile
|
||||
// MODEL_BASE_URL is where the weights are hosted, WEBGPU_EXPORT is the JS-wrapped WebGPU code exported from tinygrad
|
||||
window.PC_MODEL_BASE_URL = ".";
|
||||
window.PC_WEBGPU_EXPORT = './net.js'
|
||||
window.PC_MAX_CONTEXT = 1024;
|
||||
window.MOBILE_MODEL_BASE_URL = ".";
|
||||
window.MOBILE_WEBGPU_EXPORT = './net.js'
|
||||
window.MOBILE_MAX_CONTEXT = 1024;
|
||||
|
||||
const tiktokenReady = (async () => {
|
||||
const { init, get_encoding, Tiktoken, load } = await import('./tiktoken.js');
|
||||
window.Tiktoken = Tiktoken;
|
||||
window.tiktokenInit = init;
|
||||
window.tiktokenGetEncoding = get_encoding;
|
||||
window.tiktokenLoad = load;
|
||||
})();
|
||||
|
||||
async function getDevice() {
|
||||
let adapter;
|
||||
try {
|
||||
adapter = await navigator.gpu.requestAdapter();
|
||||
if (!adapter) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
|
||||
throw new Error("No WebGPU adapter found");
|
||||
}
|
||||
} catch(error) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU not enabled):";
|
||||
throw error;
|
||||
}
|
||||
const requiredLimits = {};
|
||||
const maxBufferSize = 322122544;
|
||||
requiredLimits.maxStorageBufferBindingSize = maxBufferSize;
|
||||
requiredLimits.maxBufferSize = maxBufferSize;
|
||||
requiredLimits.maxComputeInvocationsPerWorkgroup = 512; // may need to vary based on what the WEBGPU backend produces
|
||||
|
||||
try {
|
||||
return await adapter.requestDevice({ requiredLimits });
|
||||
} catch(error) {
|
||||
this.loadingMessage = "Loading WASM (WebGPU error):";
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
// copied from examples/webgpu/stable_diffusion/index.html
|
||||
function initDb() {
|
||||
return new Promise((resolve, reject) => {
|
||||
let db;
|
||||
const request = indexedDB.open('tinydb', 1);
|
||||
request.onerror = (event) => {
|
||||
console.error('Database error:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = (event) => {
|
||||
db = event.target.result;
|
||||
console.log("Db initialized.");
|
||||
resolve(db);
|
||||
};
|
||||
|
||||
request.onupgradeneeded = (event) => {
|
||||
db = event.target.result;
|
||||
if (!db.objectStoreNames.contains('tensors')) {
|
||||
db.createObjectStore('tensors', { keyPath: 'id' });
|
||||
}
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// copied from examples/webgpu/stable_diffusion/index.html
|
||||
function readTensorFromDb(db, id) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {
|
||||
resolve(null);
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readonly');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.get(id);
|
||||
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while reading tensor: " + event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = (event) => {
|
||||
const result = event.target.result;
|
||||
if (result) {
|
||||
resolve(result);
|
||||
} else {
|
||||
resolve(null);
|
||||
}
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor retrieve failed: ', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function getAllKeysFromDb(db) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {resolve([]);}
|
||||
const transaction = db.transaction(['tensors'], 'readonly');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.getAllKeys();
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while reading IndexedDB keys: " + event.target.error);
|
||||
resolve([]);
|
||||
};
|
||||
request.onsuccess = function (event) {resolve(event.target.result);};
|
||||
request.onerror = (event) => {
|
||||
console.error('Retrieval of IndexedDB keys failed: ', event.target.error);
|
||||
resolve([]);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// modified from examples/webgpu/stable_diffusion/index.html
|
||||
function saveTensorToDb(db, id, tensor) {
|
||||
return readTensorFromDb(db, id).then((result) => {
|
||||
if (!result) {
|
||||
new Promise((resolve, reject) => {
|
||||
if (db == null) {
|
||||
resolve(null);
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readwrite');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.put({ id: id, content: tensor });
|
||||
|
||||
transaction.onabort = (event) => {
|
||||
console.log("Transaction error while saving tensor: " + event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = () => {
|
||||
console.log('Tensor saved successfully.');
|
||||
resolve();
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor save failed:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}).catch(()=> null);
|
||||
}
|
||||
|
||||
function deleteTensorFromDb(db, id) {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (db == null) {
|
||||
console.error("Database is not initialized.");
|
||||
resolve(null);
|
||||
return;
|
||||
}
|
||||
|
||||
const transaction = db.transaction(['tensors'], 'readwrite');
|
||||
const store = transaction.objectStore('tensors');
|
||||
const request = store.delete(id);
|
||||
|
||||
transaction.oncomplete = () => {
|
||||
console.log(`Tensor with ID '${id}' deleted successfully.`);
|
||||
resolve();
|
||||
};
|
||||
|
||||
transaction.onerror = (event) => {
|
||||
console.error("Transaction error while deleting tensor:", event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onerror = (event) => {
|
||||
console.error('Tensor deletion failed:', event.target.error);
|
||||
resolve(null);
|
||||
};
|
||||
|
||||
request.onsuccess = () => {
|
||||
console.log(`Delete request for tensor with ID '${id}' succeeded.`);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function makeProgress(total) {
|
||||
let acc = 0;
|
||||
const ret = function progress(amount, message) {
|
||||
if (amount >= 0) { // allow updating message only
|
||||
acc += amount;
|
||||
const percentage = total ? Math.trunc((acc / total) * 100) : 0;
|
||||
document.querySelector('.progress').style.width = `${percentage}%`;
|
||||
document.getElementById('progress-percentage').textContent = `${percentage}%`;
|
||||
}
|
||||
if (message) {
|
||||
this.loadingMessage = message;
|
||||
document.getElementById('loading-message').textContent = this.loadingMessage;
|
||||
}
|
||||
}.bind(this);
|
||||
ret.total = total;
|
||||
return ret;
|
||||
}
|
||||
|
||||
function sendMessageToWorker(worker, message) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const onMessage = (event) => {
|
||||
resolve(event.data);
|
||||
worker.removeEventListener('message', onMessage);
|
||||
worker.removeEventListener('error', onError);
|
||||
};
|
||||
|
||||
const onError = (error) => {
|
||||
reject(error);
|
||||
worker.removeEventListener('message', onMessage);
|
||||
worker.removeEventListener('error', onError);
|
||||
};
|
||||
|
||||
worker.addEventListener('message', onMessage);
|
||||
worker.addEventListener('error', onError);
|
||||
|
||||
if (message.header === "token") worker.postMessage(message.data);
|
||||
else if (message.header === "load_state_dict") {
|
||||
if (message.data === "done") worker.postMessage(message.data);
|
||||
else worker.postMessage(message.data, message.data.map(file => file.bytes.buffer));
|
||||
}
|
||||
else if (message.header === "init") worker.postMessage("init");
|
||||
});
|
||||
}
|
||||
|
||||
async function load_state_dict (data, device, progress) {
|
||||
let state_dict = data.metadata.state_dict;
|
||||
let completed = 0;
|
||||
|
||||
// modified from examples/webgpu/stable_diffusion/index.html getProgressDlForPart
|
||||
const loadPart = async (part) => {
|
||||
const response = await fetch(part);
|
||||
const res = new Response(new ReadableStream({
|
||||
async start(controller) {
|
||||
const reader = response.body.getReader();
|
||||
for (;;) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
progress(value.byteLength);
|
||||
controller.enqueue(value);
|
||||
}
|
||||
controller.close();
|
||||
},
|
||||
}));
|
||||
|
||||
return res.arrayBuffer();
|
||||
};
|
||||
|
||||
let db = await initDb();
|
||||
|
||||
const getPart = async(filename, hash) => {
|
||||
let part = await readTensorFromDb(db, hash);
|
||||
|
||||
if (part) {
|
||||
console.log(`Cache hit: ${filename}, hash: ${hash}`);
|
||||
progress(part.content.byteLength);
|
||||
return Promise.resolve(part.content);
|
||||
} else {
|
||||
console.log(`Cache miss: ${filename}, hash: ${hash}`);
|
||||
return loadPart(`${window.MODEL_BASE_URL}/${filename}`);
|
||||
}
|
||||
}
|
||||
|
||||
const correctHashes = data.metadata.files.map(file => file.hash)
|
||||
// delete unused cached buffers to free disk space -- if we update weights, user will otherwise have obsolete cached buffers
|
||||
const dbKeys = await getAllKeysFromDb(db);
|
||||
const correctHashesSet = new Set(correctHashes);
|
||||
const notInCorrectHashes = dbKeys.filter(key => !correctHashesSet.has(key));
|
||||
// await these right before starting to save new stuff
|
||||
const deletionPromises = notInCorrectHashes.map(async (hash) => deleteTensorFromDb(db, hash));
|
||||
|
||||
// instantiates empty weight buffers on WebGPU, attaches buffers to state_dict
|
||||
let model;
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
//model = await transformer().setup(device, state_dict, progress);
|
||||
model = await transformer.setupNet(device, state_dict);
|
||||
progress(0.15 * progress.total);
|
||||
|
||||
}
|
||||
else if (window.BACKEND === "WASM") {
|
||||
progress(0.02 * progress.total);
|
||||
model = new Worker(`./worker.js?version=${Date.now()}`);
|
||||
await sendMessageToWorker(model, {header: "init"});
|
||||
progress(0.02 * progress.total);
|
||||
progress(0.11 * progress.total);
|
||||
}
|
||||
|
||||
const downloaded = [];
|
||||
const triggerChainDownload = async (toDownload) => {
|
||||
const numDownloaders = window.isMobile ? 4 : toDownload.length; // TODO: dynamically base this on DL file size? current assumption is 16 MiB chunks
|
||||
|
||||
const chainDownload = async() => {
|
||||
const file = toDownload.shift();
|
||||
loadPart(`${window.MODEL_BASE_URL}/${file.name}`) // triggers download
|
||||
.then(async (arraybuf) => {
|
||||
downloaded.push({ ...file, bytes: new Uint8Array(arraybuf)});
|
||||
// pause downloads if further processing is a bottleneck
|
||||
while (toDownload.length && downloaded.length >= numDownloaders) await new Promise(resolve => setTimeout(resolve, 5));
|
||||
if (toDownload.length && downloaded.length < numDownloaders) chainDownload(); // start next download
|
||||
})
|
||||
}
|
||||
for (let i=0; i<numDownloaders; i++) if (toDownload.length) chainDownload();
|
||||
}
|
||||
|
||||
const loadFileToStateDict = async(file) => {
|
||||
if (window.BACKEND === "WebGPU") {
|
||||
for (const part of file.parts) {
|
||||
if (part.empty) continue;
|
||||
part.bytes = (part.size === file.bytes.length) ? file.bytes : file.bytes.slice(part.file_start_pos, part.file_start_pos + part.size);
|
||||
device.queue.writeBuffer(state_dict[part.key].bytes, part.target_start_pos, part.bytes); // improves stability over mappedAtCreation writing
|
||||
part.bytes = null;
|
||||
}
|
||||
}
|
||||
else if (window.BACKEND === "WASM") {
|
||||
await sendMessageToWorker(model, {header: "load_state_dict", data: [file]});
|
||||
}
|
||||
file.bytes = null;
|
||||
}
|
||||
|
||||
if (window.BACKEND === "WebGPU") { // contiguous loading not needed for WebGPU stability
|
||||
const files = data.tensor_file_groups.flatMap(obj => obj.files);
|
||||
data.tensor_file_groups = [{contiguous: false, files: files}];
|
||||
}
|
||||
|
||||
for (const group of data.tensor_file_groups) {
|
||||
const contiguous = group.contiguous;
|
||||
const files = group.files;
|
||||
const tensor_file_indices = files.map(file => file.index);
|
||||
const contiguousFiles = [];
|
||||
const fileHashes = new Set(files.map(file => file.hash));
|
||||
const cachedFileHashes = new Set(dbKeys.filter(key => fileHashes.has(key)));
|
||||
const cachedFiles = files.filter(file => cachedFileHashes.has(file.hash));
|
||||
const toDownload = files.filter(file => !cachedFileHashes.has(file.hash));
|
||||
triggerChainDownload(toDownload);
|
||||
|
||||
const loadDelay = 5;
|
||||
await Promise.all(deletionPromises);
|
||||
|
||||
while (completed < files.length) {
|
||||
const start = performance.now();
|
||||
// prioritize files from downloaded queue, so we can continue downloading more files
|
||||
if (downloaded.length) {
|
||||
const file = downloaded.shift();
|
||||
await saveTensorToDb(db, file.hash, file.bytes); // for wasm, must await to prevent race between indexedDB and transfer to worker
|
||||
if (!contiguous) await loadFileToStateDict(file);
|
||||
else contiguousFiles.push(file);
|
||||
completed += 1;
|
||||
}
|
||||
else if (!downloaded.length && cachedFiles.length) {
|
||||
const file = cachedFiles.shift();
|
||||
file.bytes = await getPart(file.name, file.hash); // reads data from IndexedDB
|
||||
if (!contiguous) await loadFileToStateDict(file);
|
||||
else contiguousFiles.push(file);
|
||||
completed += 1;
|
||||
}
|
||||
const end = performance.now();
|
||||
const elapsed = end - start;
|
||||
if (elapsed < loadDelay) await new Promise(resolve => setTimeout(resolve, loadDelay - elapsed));
|
||||
}
    if (contiguous) {
      const orderMap = tensor_file_indices.reduce((acc, id, index) => {acc[id] = index; return acc;}, {});
      contiguousFiles.sort((a, b) => orderMap[a.index] - orderMap[b.index]); // glue files together in the right order
      await sendMessageToWorker(model, {header: "load_state_dict", data: contiguousFiles});
    }
    completed = 0;
  }

  // initialize empty kv_caches, which were part of exported model's state_dict, but which we didn't want to package/download
  if (window.BACKEND === "WASM") {
    for (const [k, v] of Object.entries(state_dict).filter(([_, v]) => v.empty === true)) {
      v.parts[0].file_start_pos = 0;
      const file = { parts: v.parts, size: v.size, bytes: new Uint8Array(v.size).fill(0) };
      await loadFileToStateDict(file);
    }
  }

  return model;
};

document.addEventListener("alpine:init", () => {
  Alpine.data("state", () => ({
    // loadingMessage updates the user on page load progress, including weights download and decompression
    // if loadingMessage is not '', then prompt box will be hidden: this is default behavior on page load
    placeholderText: "Generating...",
    loadingMessage: `Loading ${window.BACKEND} model:`,
    // model
    nets: {},
    tokenizer: null,
    max_context: 1024,
    lastSeenToks: [],

    progress: null,

    async init() {
      var device = null;
      var webgpuErrorMessage = null;
      if (window.BACKEND === "WebGPU") {
        try {
          device = await getDevice.call(this);
          console.log("WebGPU device initialized");
        } catch (error) {
          window.BACKEND = "WASM";
          console.log(`error: ${error}\nFailed to launch WebGPU. Loading WASM model instead...`); // return;
          webgpuErrorMessage = this.loadingMessage;
        }
      }

      window.MODEL_BASE_URL = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MODEL_BASE_URL : window.MOBILE_MODEL_BASE_URL;
      this.max_context = (window.BACKEND === "WebGPU" && !window.isMobile) ? window.PC_MAX_CONTEXT : window.MOBILE_MAX_CONTEXT;

      const kernelsReady = (async () => {
        if (window.BACKEND === "WASM") {var exports = await import(`./net_clang.js?version=${Date.now()}`);}
        else if (window.BACKEND === "WebGPU" && !window.isMobile) {var exports = await import(`${PC_WEBGPU_EXPORT}?version=${Date.now()}`);}
        else if (window.BACKEND === "WebGPU" && window.isMobile) {var exports = await import(`${MOBILE_WEBGPU_EXPORT}?version=${Date.now()}`);}
        self.transformer = exports.default;
      })();

      const response = await fetch(`${window.MODEL_BASE_URL}/net_metadata.json`);
      // TODO: cache metadata (and everything else, including tokenizer)
      // TODO: use service worker to reload page when offline
      const data = await response.json();
      data.metadata.files = data.metadata.files.map((file, index) => ({...file, index}));
      const state_dict = data.metadata.state_dict;

      /*
      - allocating memory to WASM on mobile has longstanding issues: https://github.com/WebAssembly/design/issues/1397

      - the below pattern, while yielding a successfully-functioning model when it doesn't crash, causes regular crashes on iOS Safari (iPhone 15, iOS 18.3):
        - call WASM malloc (to fit all tensors, or one per tensor) for all tensors up front, then load tensor byte chunks into the buffers in random order

      - the below pattern has been stable on iOS Safari (iPhone 15, iOS 18.3):
        - call only one WASM malloc at a time before filling the allocated bytes, as small as possible (malloc up to 256 MiB has been tested)
        - fill the malloc'd memory in linear order from start to end (what has been tested is calling wasm.HEAPU8.set on 16 MiB chunks from start to end)
        - use ALLOW_MEMORY_GROWTH=1 in wasm compilation, minimize initial memory

      - additional considerations affecting loading design, for WASM:
        - it seems that copying bytes into wasm memory cannot be zero-copy without SharedArrayBuffer, which isn't currently used due to increased hosting complexity
        - non-zero copies create memory pressure, which is not reliably capped because of lack of control over garbage collection
        - to minimize peak memory pressure if GC is delayed, we process (i.e. download + copy into WASM) large tensors (> 16 MiB) one at a time, in descending size order
      */
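      // A minimal sketch of the stable pattern above (illustrative only, never called here;
      // `wasm` and `bytes` are assumed stand-ins for the Emscripten module and one tensor's data):
      // one malloc up front, then fill it front-to-back in bounded chunks so each JS-side copy stays small.
      function exampleLinearWasmCopy(wasm, bytes, chunkSize = 16 * 1024 * 1024) {
        const ptr = wasm._malloc(bytes.length); // single allocation before any fill
        for (let off = 0; off < bytes.length; off += chunkSize) {
          // copy one bounded chunk at a time, strictly in ascending address order
          wasm.HEAPU8.set(bytes.subarray(off, Math.min(off + chunkSize, bytes.length)), ptr + off);
        }
        return ptr;
      }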
      data.tensor_file_groups = []; // see above: for WASM, limit processing of multi-file Tensors to one at a time, in descending order based on Tensor size
      const unsplit_tensors = [];
      const sortedEntries = Object.entries(state_dict).sort(([, objA], [, objB]) => objB.size - objA.size);

      let totalSize = 0;
      const seen = new Set();
      for (const [k, v] of sortedEntries) {
        const files_in_tensor = [];
        for (const part of v.parts) {
          part.key = k;
          if (part.empty) state_dict[k].empty = true; // assumes no other parts of this weight exist and are non-empty
          else {
            const file = data.metadata.files[part.file];
            if (!seen.has(file.index)) {
              seen.add(file.index);
              files_in_tensor.push(file);
            }
            totalSize += part.size;
            part.dtype = v.dtype;
            if (!data.metadata.files[part.file].parts) data.metadata.files[part.file].parts = [];
            data.metadata.files[part.file].size ??= 0;
            data.metadata.files[part.file].size += part.size;
            data.metadata.files[part.file].parts.push(part);
          }
        }
        if (files_in_tensor.length > 1) data.tensor_file_groups.push({contiguous: true, files: files_in_tensor}); // [tensorN_file0, tensorN_file1, ...]
        else if (files_in_tensor.length > 0) unsplit_tensors.push(files_in_tensor[0]);
      }
      data.tensor_file_groups.push({contiguous: false, files: unsplit_tensors});
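      // e.g. a tensor split across (hypothetical) files 2 and 3, plus two single-file tensors in files 0 and 1, yields:
      // [{contiguous: true, files: [file2, file3]}, {contiguous: false, files: [file0, file1]}]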

      data.totalSize = totalSize;
      totalSize = totalSize / 0.8; // give space in progress bar for initializing model bufs, and tokenizer
      this.progress = makeProgress.call(this, totalSize); // creates closure with totalSize

      try {
        this.progress(0.01 * totalSize, "Loading tokenizer:");
        const wasmResponse = await fetch(`${window.MODEL_BASE_URL}/tiktoken_bg.wasm`);
        this.progress(0.01 * totalSize);
        const wasmBytes = await wasmResponse.arrayBuffer();
        await tiktokenReady;
        await window.tiktokenInit((imports) => WebAssembly.instantiate(wasmBytes, imports));
        this.progress(0.01 * totalSize);

        this.tokenizer = await createTokenizer(`${window.MODEL_BASE_URL}/llama3-2.tiktoken`);
        const tokenizer_works = (new TextDecoder().decode(this.tokenizer.decode(this.tokenizer.encode("hello world"))) === "hello world");
        console.log("tokenizer works:", tokenizer_works);
        this.progress(0.01 * totalSize);
      } catch (error) {this.progress(-1, `Error launching tokenizer: ${error}`); console.log(error); return;}

      try {
        const loadModelMessage = (webgpuErrorMessage) ? webgpuErrorMessage : `Loading ${window.BACKEND} model:`;
        this.progress(0, loadModelMessage);
        await kernelsReady;
        const model = await load_state_dict(data, device, this.progress);

        if (window.BACKEND === "WebGPU") {
          this.nets = {"transformer": model};
        }
        else if (window.BACKEND === "WASM") {
          const msg = await sendMessageToWorker(model, {header: "load_state_dict", data: "done"});
          this.nets = {"transformer": async (tok, start_pos) => sendMessageToWorker(model, {header: "token", data: [tok, start_pos]})};
        }
        this.progress(0.01 * totalSize, `Launching ${window.BACKEND} model:`);
        this.loadingMessage = ""; // Triggers removal of loading bar, display of prompt box
      } catch (error) {this.progress(-1, `Error launching model: ${error}`); console.log(error); return;}
    },

    // current state
    cstate: {
      time: null,
      messages: [],
    },

    // historical state
    histories: JSON.parse(localStorage.getItem("histories")) || [],

    home: 0,
    generating: false,
    maxContextReached: false,
    cancelGeneration: false,
    endpoint: `${window.location.origin}/v1`,

    // performance tracking
    time_till_first: 0,
    tokens_per_second: 0,
    total_tokens: 0,

    removeHistory(cstate) {
      const index = this.histories.findIndex((state) => {
        return state.time === cstate.time;
      });
      if (index !== -1) {
        this.histories.splice(index, 1);
        localStorage.setItem("histories", JSON.stringify(this.histories));
      }
    },

    async handleSend() {
      const el = document.getElementById("input-form");
      const value = el.value.trim();
      if (!value) return;

      if (this.generating) return;
      this.maxContextReached = false;
      this.placeholderText = "Generating...";
      this.generating = true;
      this.cancelGeneration = false;
      if (this.home === 0) this.home = 1;

      // ensure that going back in history will go back to home
      window.history.pushState({}, "", window.TINYCHAT_ROOT || "/");

      // add message to list
      this.cstate.messages.push({ role: "user", content: value });

      // clear textarea
      el.value = "";
      el.style.height = "auto";
      el.style.height = el.scrollHeight + "px";

      // reset performance tracking
      const prefill_start = Date.now();
      let start_time = 0;
      let tokens = 0;
      this.tokens_per_second = 0;

      let gottenFirstChunk = false;
      try {
        for await (
          const chunk of this.openaiChatCompletion(this.cstate.messages)
        ) {
          if (!gottenFirstChunk) {
            this.cstate.messages.push({ role: "assistant", content: "" });
            gottenFirstChunk = true;
          }

          // add chunk to the last message
          // TODO: handle localStorage overflow
          // possible example: this.cstate.messages[...] was undefined when trying to prompt within an old cstate (chat session)
          this.cstate.messages[this.cstate.messages.length - 1].content += chunk;

          // calculate performance tracking
          tokens += 1;
          this.total_tokens += 1;
          if (start_time === 0) {
            start_time = Date.now();
            this.time_till_first = start_time - prefill_start;
          } else {
            const diff = Date.now() - start_time;
            if (diff > 0) {
              this.tokens_per_second = tokens / (diff / 1000);
            }
          }
          this.checkMaxContext(this.total_tokens);
          if (this.cancelGeneration) break;
        }
      } finally {
        // update the state in histories or add it if it doesn't exist
        const index = this.histories.findIndex((cstate) => {
          return cstate.time === this.cstate.time;
        });
        this.cstate.time = Date.now();
        if (index !== -1) {
          // update the time
          this.histories[index] = this.cstate;
        } else {
          this.histories.push(this.cstate);
        }
        // update in local storage
        localStorage.setItem("histories", JSON.stringify(this.histories));

        if (!this.maxContextReached) this.generating = false;
        if (this.cancelGeneration && !this.maxContextReached) this.cstate = { time: null, messages: [] };
      }
    },

    async handleEnter(event) {
      // if shift is not pressed
      if (!event.shiftKey) {
        event.preventDefault();
        await this.handleSend();
      }
    },

    updateTotalTokens(messages) {
      try {
        let toks = [this.tokenizer.bos_id];
        messages.forEach((message) => {
          if (!message.role || !message.content) {
            throw new Error("Each message must have a 'role' and 'content' property.");
          }
          toks = toks.concat(this.tokenizer.encodeMessage(message.role, message.content));
        });

        // the model will be prompted for an assistant reply next, so count its role header too
        if (messages.length > 0 && messages[messages.length - 1].role === "user") {
          toks = toks.concat(this.tokenizer.encodeRole("assistant"));
        }
        this.total_tokens = toks.length;
      } catch (error) {
        console.error("Error updating total tokens:", error);
      }
    },

    checkMaxContext(num_tokens) {
      if (num_tokens >= this.max_context) {
        this.cancelGeneration = true;
        this.maxContextReached = true;
        this.placeholderText = `Max context reached: ${this.max_context} tokens`;
      }
    },

    async *openaiChatCompletion(messages) {
      let tokens = [this.tokenizer.bos_id];
      for (const message of messages) {
        tokens = tokens.concat(this.tokenizer.encodeMessage(message.role, message.content));
      }
      tokens = tokens.concat(this.tokenizer.encodeRole("assistant"));
      this.checkMaxContext(tokens.length); // don't waste time prefilling if we know we're over the token limit
      let startPos = 0;
      const prefillToks = tokens.slice(0, -1);

      // Skip the largest possible sequence of tokens already represented at the beginning of the model's kv caches
      for (let i = 0; i <= prefillToks.length; i++) {
        startPos = i;
        if (i == prefillToks.length) break;
        if (i == this.lastSeenToks.length) break;
        if (prefillToks[i] !== this.lastSeenToks[i]) break;
      }
      const unprocessedPrefillToks = prefillToks.slice(startPos);
      this.lastSeenToks = prefillToks.slice(0, startPos);
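      // worked example (hypothetical values): lastSeenToks=[5,6,7] and prefillToks=[5,6,9,4] mismatch at i=2,
      // so startPos=2, unprocessedPrefillToks=[9,4], and the kv cache entries for [5,6] are reused as-is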

      this.progress = makeProgress(unprocessedPrefillToks.length);
      this.loadingMessage = (window.BACKEND === "WebGPU") ? "Reading input:" : "Loading (enable WebGPU for speed):";
      this.progress(0, this.loadingMessage);
      for (const tok of unprocessedPrefillToks) {
        if (this.cancelGeneration) {this.loadingMessage = ""; return;}
        if (window.BACKEND === "WebGPU") {await this.nets["transformer"](new Int32Array([tok]), new Int32Array([startPos]));}
        else {await this.nets["transformer"](tok, startPos);}
        this.lastSeenToks.push(tok);
        startPos += 1;
        this.progress(1);
      }
      this.loadingMessage = ""; // hides progress bar

      let lastTok = tokens[tokens.length - 1];
      while (true) {
        if (window.BACKEND === "WebGPU") {var tok = await this.nets["transformer"](new Int32Array([lastTok]), new Int32Array([startPos])); tok = tok[0][0];}
        else {var tok = await this.nets["transformer"](lastTok, startPos);}
        this.lastSeenToks.push(lastTok); // lets us skip prefilling with these tokens at the next prompt in this chain
        startPos += 1;
        lastTok = tok;
        if (this.tokenizer.stop_tokens.has(lastTok)) break;
        yield new TextDecoder().decode(this.tokenizer.decode([lastTok]));
      }
    },
  }));
});

const { markedHighlight } = globalThis.markedHighlight;
marked.use(markedHighlight({
  langPrefix: "hljs language-",
  highlight(code, lang, _info) {
    const language = hljs.getLanguage(lang) ? lang : "plaintext";
    return hljs.highlight(code, { language }).value;
  },
}));

// **** eventsource-parser ****
class EventSourceParserStream extends TransformStream {
  constructor() {
    let parser;

    super({
      start(controller) {
        parser = createParser((event) => {
          if (event.type === "event") {
            controller.enqueue(event);
          }
        });
      },

      transform(chunk) {
        parser.feed(chunk);
      },
    });
  }
}
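
// Usage sketch for the stream above (not called anywhere in this file; the endpoint URL is
// hypothetical): decode a fetch body to text, then parse it into discrete server-sent events.
async function exampleReadSSE(url) {
  const response = await fetch(url);
  const events = response.body
    .pipeThrough(new TextDecoderStream())
    .pipeThrough(new EventSourceParserStream());
  const reader = events.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    console.log("event data:", value.data); // each value is one parsed event object
  }
}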

function createParser(onParse) {
  let isFirstChunk;
  let buffer;
  let startingPosition;
  let startingFieldLength;
  let eventId;
  let eventName;
  let data;
  reset();
  return {
    feed,
    reset,
  };
  function reset() {
    isFirstChunk = true;
    buffer = "";
    startingPosition = 0;
    startingFieldLength = -1;
    eventId = void 0;
    eventName = void 0;
    data = "";
  }
  function feed(chunk) {
    buffer = buffer ? buffer + chunk : chunk;
    if (isFirstChunk && hasBom(buffer)) {
      buffer = buffer.slice(BOM.length);
    }
    isFirstChunk = false;
    const length = buffer.length;
    let position = 0;
    let discardTrailingNewline = false;
    while (position < length) {
      if (discardTrailingNewline) {
        if (buffer[position] === "\n") {
          ++position;
        }
        discardTrailingNewline = false;
      }
      let lineLength = -1;
      let fieldLength = startingFieldLength;
      let character;
      for (
        let index = startingPosition;
        lineLength < 0 && index < length;
        ++index
      ) {
        character = buffer[index];
        if (character === ":" && fieldLength < 0) {
          fieldLength = index - position;
        } else if (character === "\r") {
          discardTrailingNewline = true;
          lineLength = index - position;
        } else if (character === "\n") {
          lineLength = index - position;
        }
      }
      if (lineLength < 0) {
        startingPosition = length - position;
        startingFieldLength = fieldLength;
        break;
      } else {
        startingPosition = 0;
        startingFieldLength = -1;
      }
      parseEventStreamLine(buffer, position, fieldLength, lineLength);
      position += lineLength + 1;
    }
    if (position === length) {
      buffer = "";
    } else if (position > 0) {
      buffer = buffer.slice(position);
    }
  }
  function parseEventStreamLine(lineBuffer, index, fieldLength, lineLength) {
    if (lineLength === 0) {
      if (data.length > 0) {
        onParse({
          type: "event",
          id: eventId,
          event: eventName || void 0,
          data: data.slice(0, -1), // remove trailing newline
        });

        data = "";
        eventId = void 0;
      }
      eventName = void 0;
      return;
    }
    const noValue = fieldLength < 0;
    const field = lineBuffer.slice(
      index,
      index + (noValue ? lineLength : fieldLength),
    );
    let step = 0;
    if (noValue) {
      step = lineLength;
    } else if (lineBuffer[index + fieldLength + 1] === " ") {
      step = fieldLength + 2;
    } else {
      step = fieldLength + 1;
    }
    const position = index + step;
    const valueLength = lineLength - step;
    const value = lineBuffer.slice(position, position + valueLength).toString();
    if (field === "data") {
      data += value ? "".concat(value, "\n") : "\n";
    } else if (field === "event") {
      eventName = value;
    } else if (field === "id" && !value.includes("\0")) {
      eventId = value;
    } else if (field === "retry") {
      const retry = parseInt(value, 10);
      if (!Number.isNaN(retry)) {
        onParse({
          type: "reconnect-interval",
          value: retry,
        });
      }
    }
  }
}
const BOM = [239, 187, 191];
function hasBom(buffer) {
  return BOM.every((charCode, index) => buffer.charCodeAt(index) === charCode);
}
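
// For reference: feeding the parser the single chunk "data: hello\n\n" fires one callback,
// onParse({type: "event", id: undefined, event: undefined, data: "hello"}); fields split
// across feed() calls are buffered until a blank line terminates the event.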

const PAT_STR = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";

async function createTokenizer(bpeUrl) {
  const num_base_tokens = 128000;
  const special_tokens = {
    "<|begin_of_text|>": 128000,
    "<|end_of_text|>": 128001,
    "<|start_header_id|>": 128006,
    "<|end_header_id|>": 128007,
    "<|eot_id|>": 128009
  };
  const model = await window.tiktokenLoad({
    "load_tiktoken_bpe": bpeUrl,
    "special_tokens": special_tokens,
    "pat_str": PAT_STR
  });
  const tokenizer = new window.Tiktoken(model.bpe_ranks, model.special_tokens, model.pat_str);

  return {
    get bos_id() {
      return special_tokens["<|begin_of_text|>"];
    },

    get stop_tokens() {
      return new Set([
        special_tokens["<|end_of_text|>"],
        special_tokens["<|eot_id|>"],
      ]);
    },

    decode(toks) {
      const filtered = toks.filter((t) => t < num_base_tokens);
      return tokenizer.decode(filtered);
    },

    encode(text, allow_special = false) {
      const allowedSpecial = allow_special ? "all" : new Set();
      const disallowedSpecial = new Set();
      return tokenizer.encode(text, allowedSpecial, disallowedSpecial);
    },

    encodeRole(role) {
      const tokens = [];
      tokens.push(special_tokens["<|start_header_id|>"]);
      tokens.push(...this.encode(role));
      tokens.push(special_tokens["<|end_header_id|>"]);
      tokens.push(...this.encode("\n\n"));
      return tokens;
    },

    encodeMessage(role, content) {
      const roleTokens = this.encodeRole(role);
      const contentTokens = this.encode(content.trim());
      return [...roleTokens, ...contentTokens, special_tokens["<|eot_id|>"]];
    },
  };
}
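
// Usage sketch (not executed here; the .tiktoken path is hypothetical): round-trip a chat
// message through the tokenizer built above.
async function exampleTokenizerUsage() {
  const tok = await createTokenizer("./llama3-2.tiktoken");
  const toks = tok.encodeMessage("user", "hello world"); // role header + content + <|eot_id|>
  // decode() drops special tokens (ids >= 128000), so only the plain text round-trips
  console.log(new TextDecoder().decode(tok.decode(toks)));
}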
11
examples/tinychat/tinychat-browser/make_tiktoken_js.sh
Executable file
11
examples/tinychat/tinychat-browser/make_tiktoken_js.sh
Executable file
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
cd "$(dirname "$0")"
npm init -y && \
npm install --save-dev webpack webpack-cli && \
npm install tiktoken && \
jq '.scripts.build = "webpack"' package.json > package.tmp.json && \
mv package.tmp.json package.json && \
npm run build && \
mv dist/*.wasm ./tiktoken_bg.wasm && \
mv dist/* ./ && \
rm -rf dist node_modules package-lock.json package.json
5
examples/tinychat/tinychat-browser/tiktoken-export.js
Normal file
5
examples/tinychat/tinychat-browser/tiktoken-export.js
Normal file
@@ -0,0 +1,5 @@
// Force Webpack to copy the WASM
import 'tiktoken/tiktoken_bg.wasm';
import { init, get_encoding, encoding_for_model, Tiktoken } from 'tiktoken/init';
import { load } from 'tiktoken/load';
export { init, get_encoding, encoding_for_model, Tiktoken, load };
25
examples/tinychat/tinychat-browser/webpack.config.js
Normal file
25
examples/tinychat/tinychat-browser/webpack.config.js
Normal file
@@ -0,0 +1,25 @@
const path = require("path");

module.exports = {
  mode: "production",
  entry: "./tiktoken-export.js",
  output: {
    filename: "tiktoken.js",
    path: path.resolve(__dirname, "dist"),
    library: {
      type: "module"
    }
  },
  experiments: {
    outputModule: true,
    asyncWebAssembly: true
  },
  module: {
    rules: [
      {
        test: /\.wasm$/,
        type: "asset/resource",
      }
    ]
  }
};
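
// Note: with the build script added by make_tiktoken_js.sh, `npm run build` runs webpack with
// this config, emitting dist/tiktoken.js as an ES module (outputModule) and copying the .wasm
// imported by tiktoken-export.js alongside it (asyncWebAssembly + asset/resource).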
62
examples/tinychat/tinychat-browser/worker.js
Normal file
62
examples/tinychat/tinychat-browser/worker.js
Normal file
@@ -0,0 +1,62 @@
const kernelsReady = (async () => {
  // can't get browser to use updated versions except with cache-busting query string
  const exports = await import(`./net_clang.js?version=${Date.now()}`);
  Object.assign(self, exports);
})();

async function init(event) {
  await kernelsReady;
  self.model = await self.transformer();
  self.addEventListener("message", loadStateDict);
  self.removeEventListener("message", init);
  self.postMessage("success");
}

function loadStateDict(event) {
  if (event.data === "done") {
    self.addEventListener("message", inference);
    self.removeEventListener("message", loadStateDict);
  }
  else {
    if (event.data.length > 1) {
      // the bytes from files are set contiguously in WASM memory
      const malloc_size = event.data.reduce((sum, file) => sum + file.bytes.length, 0);
      const malloc_ptr = self.model.wasm._malloc(malloc_size);
      let cursor = 0;
      for (const file of event.data) {
        self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr + cursor);
        for (const part of file.parts) {
          if (part.target_start_pos === 0) {
            // tell WASM code where the tensor is in memory
            self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + cursor);
          }
          cursor += part.size;
        }
        file.bytes = null;
      }
    }
    else {
      // the bytes from files are not guaranteed to be set contiguously in WASM memory
      const file = event.data[0];
      const malloc_ptr = self.model.wasm._malloc(file.size);
      self.model.wasm.HEAPU8.set(file.bytes, malloc_ptr);
      for (const part of file.parts) {
        if (part.target_start_pos === 0) {
          self.model.wasm._set_buf(self.transformer_name_to_id[part.key], malloc_ptr + part.file_start_pos);
        }
      }
      file.bytes = null;
    }
  }
  self.postMessage("success");
}

function inference(event) {
  const [tok, start_pos] = event.data;
  const int32tok = new Int32Array([tok]);
  const model_out = self.model.run(new Uint8Array(int32tok.buffer), start_pos);
  const int32nextTok = new Int32Array(model_out[0].buffer);
  self.postMessage(int32nextTok[0]);
}

self.addEventListener("message", init);
209
extra/amdpci/headers/hdp_6_0_0_offset.h
Normal file
209
extra/amdpci/headers/hdp_6_0_0_offset.h
Normal file
@@ -0,0 +1,209 @@
/*
 * Copyright 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE COPYRIGHT HOLDER(S) OR AUTHOR(S) BE LIABLE FOR ANY CLAIM, DAMAGES OR
 * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 *
 */
#ifndef _hdp_6_0_0_OFFSET_HEADER
#define _hdp_6_0_0_OFFSET_HEADER



// addressBlock: hdp_hdpdec
// base address: 0x3c80
#define regHDP_NONSURFACE_BASE 0x0040
#define regHDP_NONSURFACE_BASE_BASE_IDX 0
#define regHDP_NONSURFACE_INFO 0x0041
#define regHDP_NONSURFACE_INFO_BASE_IDX 0
#define regHDP_NONSURFACE_BASE_HI 0x0042
#define regHDP_NONSURFACE_BASE_HI_BASE_IDX 0
#define regHDP_SURFACE_WRITE_FLAGS 0x00c4
#define regHDP_SURFACE_WRITE_FLAGS_BASE_IDX 0
#define regHDP_SURFACE_READ_FLAGS 0x00c5
#define regHDP_SURFACE_READ_FLAGS_BASE_IDX 0
#define regHDP_SURFACE_WRITE_FLAGS_CLR 0x00c6
#define regHDP_SURFACE_WRITE_FLAGS_CLR_BASE_IDX 0
#define regHDP_SURFACE_READ_FLAGS_CLR 0x00c7
#define regHDP_SURFACE_READ_FLAGS_CLR_BASE_IDX 0
#define regHDP_NONSURF_FLAGS 0x00c8
#define regHDP_NONSURF_FLAGS_BASE_IDX 0
#define regHDP_NONSURF_FLAGS_CLR 0x00c9
#define regHDP_NONSURF_FLAGS_CLR_BASE_IDX 0
#define regHDP_HOST_PATH_CNTL 0x00cc
#define regHDP_HOST_PATH_CNTL_BASE_IDX 0
#define regHDP_SW_SEMAPHORE 0x00cd
#define regHDP_SW_SEMAPHORE_BASE_IDX 0
#define regHDP_DEBUG0 0x00ce
#define regHDP_DEBUG0_BASE_IDX 0
#define regHDP_LAST_SURFACE_HIT 0x00d0
#define regHDP_LAST_SURFACE_HIT_BASE_IDX 0
#define regHDP_OUTSTANDING_REQ 0x00d2
#define regHDP_OUTSTANDING_REQ_BASE_IDX 0
#define regHDP_MISC_CNTL 0x00d3
#define regHDP_MISC_CNTL_BASE_IDX 0
#define regHDP_MEM_POWER_CTRL 0x00d4
#define regHDP_MEM_POWER_CTRL_BASE_IDX 0
#define regHDP_MMHUB_CNTL 0x00d5
#define regHDP_MMHUB_CNTL_BASE_IDX 0
#define regHDP_VERSION 0x00d7
#define regHDP_VERSION_BASE_IDX 0
#define regHDP_CLK_CNTL 0x00d8
#define regHDP_CLK_CNTL_BASE_IDX 0
#define regHDP_MEMIO_CNTL 0x00f6
#define regHDP_MEMIO_CNTL_BASE_IDX 0
#define regHDP_MEMIO_ADDR 0x00f7
#define regHDP_MEMIO_ADDR_BASE_IDX 0
#define regHDP_MEMIO_STATUS 0x00f8
#define regHDP_MEMIO_STATUS_BASE_IDX 0
#define regHDP_MEMIO_WR_DATA 0x00f9
#define regHDP_MEMIO_WR_DATA_BASE_IDX 0
#define regHDP_MEMIO_RD_DATA 0x00fa
#define regHDP_MEMIO_RD_DATA_BASE_IDX 0
#define regHDP_XDP_DIRECT2HDP_FIRST 0x0100
#define regHDP_XDP_DIRECT2HDP_FIRST_BASE_IDX 0
#define regHDP_XDP_D2H_FLUSH 0x0101
#define regHDP_XDP_D2H_FLUSH_BASE_IDX 0
#define regHDP_XDP_D2H_BAR_UPDATE 0x0102
#define regHDP_XDP_D2H_BAR_UPDATE_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_3 0x0103
#define regHDP_XDP_D2H_RSVD_3_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_4 0x0104
#define regHDP_XDP_D2H_RSVD_4_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_5 0x0105
#define regHDP_XDP_D2H_RSVD_5_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_6 0x0106
#define regHDP_XDP_D2H_RSVD_6_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_7 0x0107
#define regHDP_XDP_D2H_RSVD_7_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_8 0x0108
#define regHDP_XDP_D2H_RSVD_8_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_9 0x0109
#define regHDP_XDP_D2H_RSVD_9_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_10 0x010a
#define regHDP_XDP_D2H_RSVD_10_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_11 0x010b
#define regHDP_XDP_D2H_RSVD_11_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_12 0x010c
#define regHDP_XDP_D2H_RSVD_12_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_13 0x010d
#define regHDP_XDP_D2H_RSVD_13_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_14 0x010e
#define regHDP_XDP_D2H_RSVD_14_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_15 0x010f
#define regHDP_XDP_D2H_RSVD_15_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_16 0x0110
#define regHDP_XDP_D2H_RSVD_16_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_17 0x0111
#define regHDP_XDP_D2H_RSVD_17_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_18 0x0112
#define regHDP_XDP_D2H_RSVD_18_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_19 0x0113
#define regHDP_XDP_D2H_RSVD_19_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_20 0x0114
#define regHDP_XDP_D2H_RSVD_20_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_21 0x0115
#define regHDP_XDP_D2H_RSVD_21_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_22 0x0116
#define regHDP_XDP_D2H_RSVD_22_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_23 0x0117
#define regHDP_XDP_D2H_RSVD_23_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_24 0x0118
#define regHDP_XDP_D2H_RSVD_24_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_25 0x0119
#define regHDP_XDP_D2H_RSVD_25_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_26 0x011a
#define regHDP_XDP_D2H_RSVD_26_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_27 0x011b
#define regHDP_XDP_D2H_RSVD_27_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_28 0x011c
#define regHDP_XDP_D2H_RSVD_28_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_29 0x011d
#define regHDP_XDP_D2H_RSVD_29_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_30 0x011e
#define regHDP_XDP_D2H_RSVD_30_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_31 0x011f
#define regHDP_XDP_D2H_RSVD_31_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_32 0x0120
#define regHDP_XDP_D2H_RSVD_32_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_33 0x0121
#define regHDP_XDP_D2H_RSVD_33_BASE_IDX 0
#define regHDP_XDP_D2H_RSVD_34 0x0122
#define regHDP_XDP_D2H_RSVD_34_BASE_IDX 0
#define regHDP_XDP_DIRECT2HDP_LAST 0x0123
#define regHDP_XDP_DIRECT2HDP_LAST_BASE_IDX 0
#define regHDP_XDP_P2P_BAR_CFG 0x0124
#define regHDP_XDP_P2P_BAR_CFG_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_OFFSET 0x0125
#define regHDP_XDP_P2P_MBX_OFFSET_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR0 0x0126
#define regHDP_XDP_P2P_MBX_ADDR0_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR1 0x0127
#define regHDP_XDP_P2P_MBX_ADDR1_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR2 0x0128
#define regHDP_XDP_P2P_MBX_ADDR2_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR3 0x0129
#define regHDP_XDP_P2P_MBX_ADDR3_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR4 0x012a
#define regHDP_XDP_P2P_MBX_ADDR4_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR5 0x012b
#define regHDP_XDP_P2P_MBX_ADDR5_BASE_IDX 0
#define regHDP_XDP_P2P_MBX_ADDR6 0x012c
#define regHDP_XDP_P2P_MBX_ADDR6_BASE_IDX 0
#define regHDP_XDP_HDP_MBX_MC_CFG 0x012d
#define regHDP_XDP_HDP_MBX_MC_CFG_BASE_IDX 0
#define regHDP_XDP_HDP_MC_CFG 0x012e
#define regHDP_XDP_HDP_MC_CFG_BASE_IDX 0
#define regHDP_XDP_HST_CFG 0x012f
#define regHDP_XDP_HST_CFG_BASE_IDX 0
#define regHDP_XDP_HDP_IPH_CFG 0x0131
#define regHDP_XDP_HDP_IPH_CFG_BASE_IDX 0
#define regHDP_XDP_P2P_BAR0 0x0134
#define regHDP_XDP_P2P_BAR0_BASE_IDX 0
#define regHDP_XDP_P2P_BAR1 0x0135
#define regHDP_XDP_P2P_BAR1_BASE_IDX 0
#define regHDP_XDP_P2P_BAR2 0x0136
#define regHDP_XDP_P2P_BAR2_BASE_IDX 0
#define regHDP_XDP_P2P_BAR3 0x0137
#define regHDP_XDP_P2P_BAR3_BASE_IDX 0
#define regHDP_XDP_P2P_BAR4 0x0138
#define regHDP_XDP_P2P_BAR4_BASE_IDX 0
#define regHDP_XDP_P2P_BAR5 0x0139
#define regHDP_XDP_P2P_BAR5_BASE_IDX 0
#define regHDP_XDP_P2P_BAR6 0x013a
#define regHDP_XDP_P2P_BAR6_BASE_IDX 0
#define regHDP_XDP_P2P_BAR7 0x013b
#define regHDP_XDP_P2P_BAR7_BASE_IDX 0
#define regHDP_XDP_FLUSH_ARMED_STS 0x013c
#define regHDP_XDP_FLUSH_ARMED_STS_BASE_IDX 0
#define regHDP_XDP_FLUSH_CNTR0_STS 0x013d
#define regHDP_XDP_FLUSH_CNTR0_STS_BASE_IDX 0
#define regHDP_XDP_BUSY_STS 0x013e
#define regHDP_XDP_BUSY_STS_BASE_IDX 0
#define regHDP_XDP_STICKY 0x013f
#define regHDP_XDP_STICKY_BASE_IDX 0
#define regHDP_XDP_CHKN 0x0140
#define regHDP_XDP_CHKN_BASE_IDX 0
#define regHDP_XDP_BARS_ADDR_39_36 0x0144
#define regHDP_XDP_BARS_ADDR_39_36_BASE_IDX 0
#define regHDP_XDP_MC_VM_FB_LOCATION_BASE 0x0145
#define regHDP_XDP_MC_VM_FB_LOCATION_BASE_BASE_IDX 0
#define regHDP_XDP_MMHUB_ERROR 0x014a
#define regHDP_XDP_MMHUB_ERROR_BASE_IDX 0

#endif
646
extra/amdpci/headers/hdp_6_0_0_sh_mask.h
Normal file
646
extra/amdpci/headers/hdp_6_0_0_sh_mask.h
Normal file
@@ -0,0 +1,646 @@
/*
 * Copyright 2021 Advanced Micro Devices, Inc.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE COPYRIGHT HOLDER(S) OR AUTHOR(S) BE LIABLE FOR ANY CLAIM, DAMAGES OR
 * OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 *
 */
#ifndef _hdp_6_0_0_SH_MASK_HEADER
#define _hdp_6_0_0_SH_MASK_HEADER


// addressBlock: hdp_hdpdec
//HDP_NONSURFACE_BASE
#define HDP_NONSURFACE_BASE__NONSURF_BASE_39_8__SHIFT 0x0
#define HDP_NONSURFACE_BASE__NONSURF_BASE_39_8_MASK 0xFFFFFFFFL
//HDP_NONSURFACE_INFO
#define HDP_NONSURFACE_INFO__NONSURF_SWAP__SHIFT 0x4
#define HDP_NONSURFACE_INFO__NONSURF_VMID__SHIFT 0x8
#define HDP_NONSURFACE_INFO__NONSURF_SWAP_MASK 0x00000030L
#define HDP_NONSURFACE_INFO__NONSURF_VMID_MASK 0x00000F00L
//HDP_NONSURFACE_BASE_HI
#define HDP_NONSURFACE_BASE_HI__NONSURF_BASE_47_40__SHIFT 0x0
#define HDP_NONSURFACE_BASE_HI__NONSURF_BASE_47_40_MASK 0x000000FFL
//HDP_SURFACE_WRITE_FLAGS
#define HDP_SURFACE_WRITE_FLAGS__SURF0_WRITE_FLAG__SHIFT 0x0
#define HDP_SURFACE_WRITE_FLAGS__SURF1_WRITE_FLAG__SHIFT 0x1
#define HDP_SURFACE_WRITE_FLAGS__SURF0_WRITE_FLAG_MASK 0x00000001L
#define HDP_SURFACE_WRITE_FLAGS__SURF1_WRITE_FLAG_MASK 0x00000002L
//HDP_SURFACE_READ_FLAGS
#define HDP_SURFACE_READ_FLAGS__SURF0_READ_FLAG__SHIFT 0x0
#define HDP_SURFACE_READ_FLAGS__SURF1_READ_FLAG__SHIFT 0x1
#define HDP_SURFACE_READ_FLAGS__SURF0_READ_FLAG_MASK 0x00000001L
#define HDP_SURFACE_READ_FLAGS__SURF1_READ_FLAG_MASK 0x00000002L
//HDP_SURFACE_WRITE_FLAGS_CLR
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF0_WRITE_FLAG_CLR__SHIFT 0x0
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF1_WRITE_FLAG_CLR__SHIFT 0x1
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF0_WRITE_FLAG_CLR_MASK 0x00000001L
#define HDP_SURFACE_WRITE_FLAGS_CLR__SURF1_WRITE_FLAG_CLR_MASK 0x00000002L
//HDP_SURFACE_READ_FLAGS_CLR
#define HDP_SURFACE_READ_FLAGS_CLR__SURF0_READ_FLAG_CLR__SHIFT 0x0
#define HDP_SURFACE_READ_FLAGS_CLR__SURF1_READ_FLAG_CLR__SHIFT 0x1
#define HDP_SURFACE_READ_FLAGS_CLR__SURF0_READ_FLAG_CLR_MASK 0x00000001L
#define HDP_SURFACE_READ_FLAGS_CLR__SURF1_READ_FLAG_CLR_MASK 0x00000002L
//HDP_NONSURF_FLAGS
#define HDP_NONSURF_FLAGS__NONSURF_WRITE_FLAG__SHIFT 0x0
#define HDP_NONSURF_FLAGS__NONSURF_READ_FLAG__SHIFT 0x1
#define HDP_NONSURF_FLAGS__NONSURF_WRITE_FLAG_MASK 0x00000001L
#define HDP_NONSURF_FLAGS__NONSURF_READ_FLAG_MASK 0x00000002L
//HDP_NONSURF_FLAGS_CLR
#define HDP_NONSURF_FLAGS_CLR__NONSURF_WRITE_FLAG_CLR__SHIFT 0x0
#define HDP_NONSURF_FLAGS_CLR__NONSURF_READ_FLAG_CLR__SHIFT 0x1
#define HDP_NONSURF_FLAGS_CLR__NONSURF_WRITE_FLAG_CLR_MASK 0x00000001L
#define HDP_NONSURF_FLAGS_CLR__NONSURF_READ_FLAG_CLR_MASK 0x00000002L
//HDP_HOST_PATH_CNTL
#define HDP_HOST_PATH_CNTL__WR_STALL_TIMER__SHIFT 0x9
#define HDP_HOST_PATH_CNTL__RD_STALL_TIMER__SHIFT 0xb
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_PRELOAD_CFG__SHIFT 0x12
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER__SHIFT 0x13
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_EN__SHIFT 0x15
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_64B_EN__SHIFT 0x16
#define HDP_HOST_PATH_CNTL__ALL_SURFACES_DIS__SHIFT 0x1d
#define HDP_HOST_PATH_CNTL__WR_STALL_TIMER_MASK 0x00000600L
#define HDP_HOST_PATH_CNTL__RD_STALL_TIMER_MASK 0x00001800L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_PRELOAD_CFG_MASK 0x00040000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_TIMER_MASK 0x00180000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_EN_MASK 0x00200000L
#define HDP_HOST_PATH_CNTL__WRITE_COMBINE_64B_EN_MASK 0x00400000L
#define HDP_HOST_PATH_CNTL__ALL_SURFACES_DIS_MASK 0x20000000L
//HDP_SW_SEMAPHORE
#define HDP_SW_SEMAPHORE__SW_SEMAPHORE__SHIFT 0x0
#define HDP_SW_SEMAPHORE__SW_SEMAPHORE_MASK 0xFFFFFFFFL
//HDP_DEBUG0
#define HDP_DEBUG0__HDP_DEBUG__SHIFT 0x0
#define HDP_DEBUG0__HDP_DEBUG_MASK 0xFFFFFFFFL
//HDP_LAST_SURFACE_HIT
#define HDP_LAST_SURFACE_HIT__LAST_SURFACE_HIT__SHIFT 0x0
#define HDP_LAST_SURFACE_HIT__LAST_SURFACE_HIT_MASK 0x00000003L
//HDP_OUTSTANDING_REQ
#define HDP_OUTSTANDING_REQ__WRITE_REQ__SHIFT 0x0
#define HDP_OUTSTANDING_REQ__READ_REQ__SHIFT 0x8
#define HDP_OUTSTANDING_REQ__WRITE_REQ_MASK 0x000000FFL
#define HDP_OUTSTANDING_REQ__READ_REQ_MASK 0x0000FF00L
//HDP_MISC_CNTL
#define HDP_MISC_CNTL__IDLE_HYSTERESIS_CNTL__SHIFT 0x2
#define HDP_MISC_CNTL__OUTSTANDING_WRITE_COUNT_1024__SHIFT 0x5
#define HDP_MISC_CNTL__MMHUB_EARLY_WRACK_ENABLE__SHIFT 0x8
#define HDP_MISC_CNTL__SIMULTANEOUS_READS_WRITES__SHIFT 0xb
#define HDP_MISC_CNTL__READ_BUFFER_WATERMARK__SHIFT 0xe
#define HDP_MISC_CNTL__NACK_ENABLE__SHIFT 0x13
#define HDP_MISC_CNTL__ATOMIC_NACK_ENABLE__SHIFT 0x14
#define HDP_MISC_CNTL__FED_ENABLE__SHIFT 0x15
#define HDP_MISC_CNTL__ATOMIC_FED_ENABLE__SHIFT 0x16
#define HDP_MISC_CNTL__MMHUB_WRBURST_ENABLE__SHIFT 0x18
#define HDP_MISC_CNTL__MMHUB_WRBURST_SIZE__SHIFT 0x1e
#define HDP_MISC_CNTL__IDLE_HYSTERESIS_CNTL_MASK 0x0000000CL
#define HDP_MISC_CNTL__OUTSTANDING_WRITE_COUNT_1024_MASK 0x00000020L
#define HDP_MISC_CNTL__MMHUB_EARLY_WRACK_ENABLE_MASK 0x00000100L
#define HDP_MISC_CNTL__SIMULTANEOUS_READS_WRITES_MASK 0x00000800L
#define HDP_MISC_CNTL__READ_BUFFER_WATERMARK_MASK 0x0000C000L
#define HDP_MISC_CNTL__NACK_ENABLE_MASK 0x00080000L
#define HDP_MISC_CNTL__ATOMIC_NACK_ENABLE_MASK 0x00100000L
#define HDP_MISC_CNTL__FED_ENABLE_MASK 0x00200000L
#define HDP_MISC_CNTL__ATOMIC_FED_ENABLE_MASK 0x00400000L
#define HDP_MISC_CNTL__MMHUB_WRBURST_ENABLE_MASK 0x01000000L
#define HDP_MISC_CNTL__MMHUB_WRBURST_SIZE_MASK 0x40000000L
//HDP_MEM_POWER_CTRL
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_CTRL_EN__SHIFT 0x0
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_LS_EN__SHIFT 0x1
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DS_EN__SHIFT 0x2
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_SD_EN__SHIFT 0x3
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_IDLE_HYSTERESIS__SHIFT 0x4
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_UP_RECOVER_DELAY__SHIFT 0x8
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DOWN_ENTER_DELAY__SHIFT 0xe
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_CTRL_EN__SHIFT 0x10
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_LS_EN__SHIFT 0x11
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DS_EN__SHIFT 0x12
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_SD_EN__SHIFT 0x13
#define HDP_MEM_POWER_CTRL__RC_MEM_IDLE_HYSTERESIS__SHIFT 0x14
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_UP_RECOVER_DELAY__SHIFT 0x18
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DOWN_ENTER_DELAY__SHIFT 0x1e
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_CTRL_EN_MASK 0x00000001L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_LS_EN_MASK 0x00000002L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DS_EN_MASK 0x00000004L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_SD_EN_MASK 0x00000008L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_IDLE_HYSTERESIS_MASK 0x00000070L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_UP_RECOVER_DELAY_MASK 0x00003F00L
#define HDP_MEM_POWER_CTRL__ATOMIC_MEM_POWER_DOWN_ENTER_DELAY_MASK 0x0000C000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_CTRL_EN_MASK 0x00010000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_LS_EN_MASK 0x00020000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DS_EN_MASK 0x00040000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_SD_EN_MASK 0x00080000L
#define HDP_MEM_POWER_CTRL__RC_MEM_IDLE_HYSTERESIS_MASK 0x00700000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_UP_RECOVER_DELAY_MASK 0x3F000000L
#define HDP_MEM_POWER_CTRL__RC_MEM_POWER_DOWN_ENTER_DELAY_MASK 0xC0000000L
//HDP_MMHUB_CNTL
#define HDP_MMHUB_CNTL__HDP_MMHUB_RO__SHIFT 0x0
#define HDP_MMHUB_CNTL__HDP_MMHUB_GCC__SHIFT 0x1
#define HDP_MMHUB_CNTL__HDP_MMHUB_SNOOP__SHIFT 0x2
#define HDP_MMHUB_CNTL__HDP_MMHUB_RO_MASK 0x00000001L
#define HDP_MMHUB_CNTL__HDP_MMHUB_GCC_MASK 0x00000002L
#define HDP_MMHUB_CNTL__HDP_MMHUB_SNOOP_MASK 0x00000004L
//HDP_VERSION
#define HDP_VERSION__MINVER__SHIFT 0x0
#define HDP_VERSION__MAJVER__SHIFT 0x8
#define HDP_VERSION__REV__SHIFT 0x10
#define HDP_VERSION__MINVER_MASK 0x000000FFL
#define HDP_VERSION__MAJVER_MASK 0x0000FF00L
#define HDP_VERSION__REV_MASK 0x00FF0000L
//HDP_CLK_CNTL
#define HDP_CLK_CNTL__REG_CLK_ENABLE_COUNT__SHIFT 0x0
#define HDP_CLK_CNTL__ATOMIC_MEM_CLK_SOFT_OVERRIDE__SHIFT 0x1a
#define HDP_CLK_CNTL__RC_MEM_CLK_SOFT_OVERRIDE__SHIFT 0x1b
#define HDP_CLK_CNTL__DBUS_CLK_SOFT_OVERRIDE__SHIFT 0x1c
#define HDP_CLK_CNTL__DYN_CLK_SOFT_OVERRIDE__SHIFT 0x1d
#define HDP_CLK_CNTL__XDP_REG_CLK_SOFT_OVERRIDE__SHIFT 0x1e
#define HDP_CLK_CNTL__HDP_REG_CLK_SOFT_OVERRIDE__SHIFT 0x1f
#define HDP_CLK_CNTL__REG_CLK_ENABLE_COUNT_MASK 0x0000000FL
#define HDP_CLK_CNTL__ATOMIC_MEM_CLK_SOFT_OVERRIDE_MASK 0x04000000L
#define HDP_CLK_CNTL__RC_MEM_CLK_SOFT_OVERRIDE_MASK 0x08000000L
#define HDP_CLK_CNTL__DBUS_CLK_SOFT_OVERRIDE_MASK 0x10000000L
#define HDP_CLK_CNTL__DYN_CLK_SOFT_OVERRIDE_MASK 0x20000000L
#define HDP_CLK_CNTL__XDP_REG_CLK_SOFT_OVERRIDE_MASK 0x40000000L
#define HDP_CLK_CNTL__HDP_REG_CLK_SOFT_OVERRIDE_MASK 0x80000000L
//HDP_MEMIO_CNTL
#define HDP_MEMIO_CNTL__MEMIO_SEND__SHIFT 0x0
#define HDP_MEMIO_CNTL__MEMIO_OP__SHIFT 0x1
#define HDP_MEMIO_CNTL__MEMIO_BE__SHIFT 0x2
#define HDP_MEMIO_CNTL__MEMIO_WR_STROBE__SHIFT 0x6
#define HDP_MEMIO_CNTL__MEMIO_RD_STROBE__SHIFT 0x7
#define HDP_MEMIO_CNTL__MEMIO_ADDR_UPPER__SHIFT 0x8
#define HDP_MEMIO_CNTL__MEMIO_CLR_WR_ERROR__SHIFT 0xe
#define HDP_MEMIO_CNTL__MEMIO_CLR_RD_ERROR__SHIFT 0xf
#define HDP_MEMIO_CNTL__MEMIO_VF__SHIFT 0x10
#define HDP_MEMIO_CNTL__MEMIO_VFID__SHIFT 0x11
#define HDP_MEMIO_CNTL__MEMIO_SEND_MASK 0x00000001L
#define HDP_MEMIO_CNTL__MEMIO_OP_MASK 0x00000002L
#define HDP_MEMIO_CNTL__MEMIO_BE_MASK 0x0000003CL
#define HDP_MEMIO_CNTL__MEMIO_WR_STROBE_MASK 0x00000040L
#define HDP_MEMIO_CNTL__MEMIO_RD_STROBE_MASK 0x00000080L
#define HDP_MEMIO_CNTL__MEMIO_ADDR_UPPER_MASK 0x00003F00L
#define HDP_MEMIO_CNTL__MEMIO_CLR_WR_ERROR_MASK 0x00004000L
#define HDP_MEMIO_CNTL__MEMIO_CLR_RD_ERROR_MASK 0x00008000L
#define HDP_MEMIO_CNTL__MEMIO_VF_MASK 0x00010000L
#define HDP_MEMIO_CNTL__MEMIO_VFID_MASK 0x003E0000L
//HDP_MEMIO_ADDR
#define HDP_MEMIO_ADDR__MEMIO_ADDR_LOWER__SHIFT 0x0
#define HDP_MEMIO_ADDR__MEMIO_ADDR_LOWER_MASK 0xFFFFFFFFL
//HDP_MEMIO_STATUS
#define HDP_MEMIO_STATUS__MEMIO_WR_STATUS__SHIFT 0x0
#define HDP_MEMIO_STATUS__MEMIO_RD_STATUS__SHIFT 0x1
#define HDP_MEMIO_STATUS__MEMIO_WR_ERROR__SHIFT 0x2
#define HDP_MEMIO_STATUS__MEMIO_RD_ERROR__SHIFT 0x3
#define HDP_MEMIO_STATUS__MEMIO_WR_STATUS_MASK 0x00000001L
#define HDP_MEMIO_STATUS__MEMIO_RD_STATUS_MASK 0x00000002L
#define HDP_MEMIO_STATUS__MEMIO_WR_ERROR_MASK 0x00000004L
#define HDP_MEMIO_STATUS__MEMIO_RD_ERROR_MASK 0x00000008L
//HDP_MEMIO_WR_DATA
#define HDP_MEMIO_WR_DATA__MEMIO_WR_DATA__SHIFT 0x0
#define HDP_MEMIO_WR_DATA__MEMIO_WR_DATA_MASK 0xFFFFFFFFL
//HDP_MEMIO_RD_DATA
#define HDP_MEMIO_RD_DATA__MEMIO_RD_DATA__SHIFT 0x0
#define HDP_MEMIO_RD_DATA__MEMIO_RD_DATA_MASK 0xFFFFFFFFL
//HDP_XDP_DIRECT2HDP_FIRST
#define HDP_XDP_DIRECT2HDP_FIRST__RESERVED__SHIFT 0x0
#define HDP_XDP_DIRECT2HDP_FIRST__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_FLUSH
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_FLUSH_NUM__SHIFT 0x0
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ENC_DATA__SHIFT 0x4
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ADDR_SEL__SHIFT 0x8
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_XPB_CLG__SHIFT 0xb
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_SEND_HOST__SHIFT 0x10
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_ALTER_FLUSH_NUM__SHIFT 0x12
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_0__SHIFT 0x13
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_1__SHIFT 0x14
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_FLUSH_NUM_MASK 0x0000000FL
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ENC_DATA_MASK 0x000000F0L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_MBX_ADDR_SEL_MASK 0x00000700L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_XPB_CLG_MASK 0x0000F800L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_SEND_HOST_MASK 0x00010000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_ALTER_FLUSH_NUM_MASK 0x00040000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_0_MASK 0x00080000L
#define HDP_XDP_D2H_FLUSH__D2H_FLUSH_RSVD_1_MASK 0x00100000L
//HDP_XDP_D2H_BAR_UPDATE
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_ADDR__SHIFT 0x0
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_FLUSH_NUM__SHIFT 0x10
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_BAR_NUM__SHIFT 0x14
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_ADDR_MASK 0x0000FFFFL
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_FLUSH_NUM_MASK 0x000F0000L
#define HDP_XDP_D2H_BAR_UPDATE__D2H_BAR_UPDATE_BAR_NUM_MASK 0x00700000L
//HDP_XDP_D2H_RSVD_3
#define HDP_XDP_D2H_RSVD_3__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_3__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_4
#define HDP_XDP_D2H_RSVD_4__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_4__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_5
#define HDP_XDP_D2H_RSVD_5__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_5__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_6
#define HDP_XDP_D2H_RSVD_6__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_6__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_7
#define HDP_XDP_D2H_RSVD_7__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_7__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_8
#define HDP_XDP_D2H_RSVD_8__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_8__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_9
#define HDP_XDP_D2H_RSVD_9__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_9__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_10
#define HDP_XDP_D2H_RSVD_10__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_10__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_11
#define HDP_XDP_D2H_RSVD_11__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_11__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_12
#define HDP_XDP_D2H_RSVD_12__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_12__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_13
#define HDP_XDP_D2H_RSVD_13__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_13__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_14
#define HDP_XDP_D2H_RSVD_14__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_14__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_15
#define HDP_XDP_D2H_RSVD_15__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_15__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_16
#define HDP_XDP_D2H_RSVD_16__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_16__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_17
#define HDP_XDP_D2H_RSVD_17__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_17__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_18
#define HDP_XDP_D2H_RSVD_18__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_18__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_19
#define HDP_XDP_D2H_RSVD_19__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_19__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_20
#define HDP_XDP_D2H_RSVD_20__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_20__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_21
#define HDP_XDP_D2H_RSVD_21__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_21__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_22
#define HDP_XDP_D2H_RSVD_22__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_22__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_23
#define HDP_XDP_D2H_RSVD_23__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_23__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_24
#define HDP_XDP_D2H_RSVD_24__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_24__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_25
#define HDP_XDP_D2H_RSVD_25__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_25__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_26
#define HDP_XDP_D2H_RSVD_26__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_26__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_27
#define HDP_XDP_D2H_RSVD_27__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_27__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_28
#define HDP_XDP_D2H_RSVD_28__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_28__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_29
#define HDP_XDP_D2H_RSVD_29__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_29__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_30
#define HDP_XDP_D2H_RSVD_30__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_30__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_31
#define HDP_XDP_D2H_RSVD_31__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_31__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_32
#define HDP_XDP_D2H_RSVD_32__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_32__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_33
#define HDP_XDP_D2H_RSVD_33__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_33__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_D2H_RSVD_34
#define HDP_XDP_D2H_RSVD_34__RESERVED__SHIFT 0x0
#define HDP_XDP_D2H_RSVD_34__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_DIRECT2HDP_LAST
#define HDP_XDP_DIRECT2HDP_LAST__RESERVED__SHIFT 0x0
#define HDP_XDP_DIRECT2HDP_LAST__RESERVED_MASK 0xFFFFFFFFL
//HDP_XDP_P2P_BAR_CFG
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_ADDR_SIZE__SHIFT 0x0
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_BAR_FROM__SHIFT 0x4
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_ADDR_SIZE_MASK 0x0000000FL
#define HDP_XDP_P2P_BAR_CFG__P2P_BAR_CFG_BAR_FROM_MASK 0x00000030L
//HDP_XDP_P2P_MBX_OFFSET
#define HDP_XDP_P2P_MBX_OFFSET__P2P_MBX_OFFSET__SHIFT 0x0
#define HDP_XDP_P2P_MBX_OFFSET__P2P_MBX_OFFSET_MASK 0x0001FFFFL
//HDP_XDP_P2P_MBX_ADDR0
#define HDP_XDP_P2P_MBX_ADDR0__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR0__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR0__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR1
#define HDP_XDP_P2P_MBX_ADDR1__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR1__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR1__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR2
#define HDP_XDP_P2P_MBX_ADDR2__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR2__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR2__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR3
#define HDP_XDP_P2P_MBX_ADDR3__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR3__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR3__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR4
#define HDP_XDP_P2P_MBX_ADDR4__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR4__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR4__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR5
#define HDP_XDP_P2P_MBX_ADDR5__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR5__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR5__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_P2P_MBX_ADDR6
#define HDP_XDP_P2P_MBX_ADDR6__VALID__SHIFT 0x0
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_35_19__SHIFT 0x3
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_39_36__SHIFT 0x14
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_47_40__SHIFT 0x18
#define HDP_XDP_P2P_MBX_ADDR6__VALID_MASK 0x00000001L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_35_19_MASK 0x000FFFF8L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_39_36_MASK 0x00F00000L
#define HDP_XDP_P2P_MBX_ADDR6__ADDR_47_40_MASK 0xFF000000L
//HDP_XDP_HDP_MBX_MC_CFG
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_QOS__SHIFT 0x0
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SWAP__SHIFT 0x4
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_VMID__SHIFT 0x8
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_RO__SHIFT 0xc
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_GCC__SHIFT 0xd
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SNOOP__SHIFT 0xe
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_QOS_MASK 0x0000000FL
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SWAP_MASK 0x00000030L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_VMID_MASK 0x00000F00L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_RO_MASK 0x00001000L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_GCC_MASK 0x00002000L
#define HDP_XDP_HDP_MBX_MC_CFG__HDP_MBX_MC_CFG_TAP_WRREQ_SNOOP_MASK 0x00004000L
//HDP_XDP_HDP_MC_CFG
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SNOOP__SHIFT 0x3
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SWAP__SHIFT 0x4
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_VMID__SHIFT 0x8
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_RO__SHIFT 0xc
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_GCC__SHIFT 0xd
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_XDP_HIGHER_PRI_THRESH__SHIFT 0xe
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SNOOP_MASK 0x00000008L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_SWAP_MASK 0x00000030L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_VMID_MASK 0x00000F00L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_RO_MASK 0x00001000L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_HST_TAP_REQ_GCC_MASK 0x00002000L
#define HDP_XDP_HDP_MC_CFG__HDP_MC_CFG_XDP_HIGHER_PRI_THRESH_MASK 0x000FC000L
//HDP_XDP_HST_CFG
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_EN__SHIFT 0x0
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER__SHIFT 0x1
#define HDP_XDP_HST_CFG__HST_CFG_WR_BURST_EN__SHIFT 0x3
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_64B_EN__SHIFT 0x4
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_PRELOAD_CFG__SHIFT 0x5
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_EN_MASK 0x00000001L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_MASK 0x00000006L
#define HDP_XDP_HST_CFG__HST_CFG_WR_BURST_EN_MASK 0x00000008L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_64B_EN_MASK 0x00000010L
#define HDP_XDP_HST_CFG__HST_CFG_WR_COMBINE_TIMER_PRELOAD_CFG_MASK 0x00000020L
//HDP_XDP_HDP_IPH_CFG
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_INVERSE_PEER_TAG_MATCHING__SHIFT 0xc
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_P2P_RD_EN__SHIFT 0xd
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_INVERSE_PEER_TAG_MATCHING_MASK 0x00001000L
#define HDP_XDP_HDP_IPH_CFG__HDP_IPH_CFG_P2P_RD_EN_MASK 0x00002000L
//HDP_XDP_P2P_BAR0
#define HDP_XDP_P2P_BAR0__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR0__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR0__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR0__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR0__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR0__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR1
#define HDP_XDP_P2P_BAR1__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR1__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR1__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR1__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR1__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR1__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR2
#define HDP_XDP_P2P_BAR2__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR2__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR2__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR2__ADDR_MASK 0x0000FFFFL
#define HDP_XDP_P2P_BAR2__FLUSH_MASK 0x000F0000L
#define HDP_XDP_P2P_BAR2__VALID_MASK 0x00100000L
//HDP_XDP_P2P_BAR3
#define HDP_XDP_P2P_BAR3__ADDR__SHIFT 0x0
#define HDP_XDP_P2P_BAR3__FLUSH__SHIFT 0x10
#define HDP_XDP_P2P_BAR3__VALID__SHIFT 0x14
#define HDP_XDP_P2P_BAR3__ADDR_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_P2P_BAR3__FLUSH_MASK 0x000F0000L
|
||||
#define HDP_XDP_P2P_BAR3__VALID_MASK 0x00100000L
|
||||
//HDP_XDP_P2P_BAR4
|
||||
#define HDP_XDP_P2P_BAR4__ADDR__SHIFT 0x0
|
||||
#define HDP_XDP_P2P_BAR4__FLUSH__SHIFT 0x10
|
||||
#define HDP_XDP_P2P_BAR4__VALID__SHIFT 0x14
|
||||
#define HDP_XDP_P2P_BAR4__ADDR_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_P2P_BAR4__FLUSH_MASK 0x000F0000L
|
||||
#define HDP_XDP_P2P_BAR4__VALID_MASK 0x00100000L
|
||||
//HDP_XDP_P2P_BAR5
|
||||
#define HDP_XDP_P2P_BAR5__ADDR__SHIFT 0x0
|
||||
#define HDP_XDP_P2P_BAR5__FLUSH__SHIFT 0x10
|
||||
#define HDP_XDP_P2P_BAR5__VALID__SHIFT 0x14
|
||||
#define HDP_XDP_P2P_BAR5__ADDR_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_P2P_BAR5__FLUSH_MASK 0x000F0000L
|
||||
#define HDP_XDP_P2P_BAR5__VALID_MASK 0x00100000L
|
||||
//HDP_XDP_P2P_BAR6
|
||||
#define HDP_XDP_P2P_BAR6__ADDR__SHIFT 0x0
|
||||
#define HDP_XDP_P2P_BAR6__FLUSH__SHIFT 0x10
|
||||
#define HDP_XDP_P2P_BAR6__VALID__SHIFT 0x14
|
||||
#define HDP_XDP_P2P_BAR6__ADDR_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_P2P_BAR6__FLUSH_MASK 0x000F0000L
|
||||
#define HDP_XDP_P2P_BAR6__VALID_MASK 0x00100000L
|
||||
//HDP_XDP_P2P_BAR7
|
||||
#define HDP_XDP_P2P_BAR7__ADDR__SHIFT 0x0
|
||||
#define HDP_XDP_P2P_BAR7__FLUSH__SHIFT 0x10
|
||||
#define HDP_XDP_P2P_BAR7__VALID__SHIFT 0x14
|
||||
#define HDP_XDP_P2P_BAR7__ADDR_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_P2P_BAR7__FLUSH_MASK 0x000F0000L
|
||||
#define HDP_XDP_P2P_BAR7__VALID_MASK 0x00100000L
|
||||
//HDP_XDP_FLUSH_ARMED_STS
|
||||
#define HDP_XDP_FLUSH_ARMED_STS__FLUSH_ARMED_STS__SHIFT 0x0
|
||||
#define HDP_XDP_FLUSH_ARMED_STS__FLUSH_ARMED_STS_MASK 0xFFFFFFFFL
|
||||
//HDP_XDP_FLUSH_CNTR0_STS
|
||||
#define HDP_XDP_FLUSH_CNTR0_STS__FLUSH_CNTR0_STS__SHIFT 0x0
|
||||
#define HDP_XDP_FLUSH_CNTR0_STS__FLUSH_CNTR0_STS_MASK 0x03FFFFFFL
|
||||
//HDP_XDP_BUSY_STS
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_0__SHIFT 0x0
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_1__SHIFT 0x1
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_2__SHIFT 0x2
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_3__SHIFT 0x3
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_4__SHIFT 0x4
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_5__SHIFT 0x5
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_6__SHIFT 0x6
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_7__SHIFT 0x7
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_8__SHIFT 0x8
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_9__SHIFT 0x9
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_10__SHIFT 0xa
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_11__SHIFT 0xb
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_12__SHIFT 0xc
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_13__SHIFT 0xd
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_14__SHIFT 0xe
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_15__SHIFT 0xf
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_16__SHIFT 0x10
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_17__SHIFT 0x11
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_18__SHIFT 0x12
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_19__SHIFT 0x13
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_20__SHIFT 0x14
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_21__SHIFT 0x15
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_22__SHIFT 0x16
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_23__SHIFT 0x17
|
||||
#define HDP_XDP_BUSY_STS__Z_FENCE_BIT__SHIFT 0x18
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_0_MASK 0x00000001L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_1_MASK 0x00000002L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_2_MASK 0x00000004L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_3_MASK 0x00000008L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_4_MASK 0x00000010L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_5_MASK 0x00000020L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_6_MASK 0x00000040L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_7_MASK 0x00000080L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_8_MASK 0x00000100L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_9_MASK 0x00000200L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_10_MASK 0x00000400L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_11_MASK 0x00000800L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_12_MASK 0x00001000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_13_MASK 0x00002000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_14_MASK 0x00004000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_15_MASK 0x00008000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_16_MASK 0x00010000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_17_MASK 0x00020000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_18_MASK 0x00040000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_19_MASK 0x00080000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_20_MASK 0x00100000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_21_MASK 0x00200000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_22_MASK 0x00400000L
|
||||
#define HDP_XDP_BUSY_STS__BUSY_BITS_23_MASK 0x00800000L
|
||||
#define HDP_XDP_BUSY_STS__Z_FENCE_BIT_MASK 0x01000000L
|
||||
//HDP_XDP_STICKY
|
||||
#define HDP_XDP_STICKY__STICKY_STS__SHIFT 0x0
|
||||
#define HDP_XDP_STICKY__STICKY_W1C__SHIFT 0x10
|
||||
#define HDP_XDP_STICKY__STICKY_STS_MASK 0x0000FFFFL
|
||||
#define HDP_XDP_STICKY__STICKY_W1C_MASK 0xFFFF0000L
|
||||
//HDP_XDP_CHKN
|
||||
#define HDP_XDP_CHKN__CHKN_0_RSVD__SHIFT 0x0
|
||||
#define HDP_XDP_CHKN__CHKN_1_RSVD__SHIFT 0x8
|
||||
#define HDP_XDP_CHKN__CHKN_2_RSVD__SHIFT 0x10
|
||||
#define HDP_XDP_CHKN__CHKN_3_RSVD__SHIFT 0x18
|
||||
#define HDP_XDP_CHKN__CHKN_0_RSVD_MASK 0x000000FFL
|
||||
#define HDP_XDP_CHKN__CHKN_1_RSVD_MASK 0x0000FF00L
|
||||
#define HDP_XDP_CHKN__CHKN_2_RSVD_MASK 0x00FF0000L
|
||||
#define HDP_XDP_CHKN__CHKN_3_RSVD_MASK 0xFF000000L
|
||||
//HDP_XDP_BARS_ADDR_39_36
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR0_ADDR_39_36__SHIFT 0x0
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR1_ADDR_39_36__SHIFT 0x4
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR2_ADDR_39_36__SHIFT 0x8
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR3_ADDR_39_36__SHIFT 0xc
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR4_ADDR_39_36__SHIFT 0x10
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR5_ADDR_39_36__SHIFT 0x14
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR6_ADDR_39_36__SHIFT 0x18
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR7_ADDR_39_36__SHIFT 0x1c
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR0_ADDR_39_36_MASK 0x0000000FL
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR1_ADDR_39_36_MASK 0x000000F0L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR2_ADDR_39_36_MASK 0x00000F00L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR3_ADDR_39_36_MASK 0x0000F000L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR4_ADDR_39_36_MASK 0x000F0000L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR5_ADDR_39_36_MASK 0x00F00000L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR6_ADDR_39_36_MASK 0x0F000000L
|
||||
#define HDP_XDP_BARS_ADDR_39_36__BAR7_ADDR_39_36_MASK 0xF0000000L
|
||||
//HDP_XDP_MC_VM_FB_LOCATION_BASE
|
||||
#define HDP_XDP_MC_VM_FB_LOCATION_BASE__FB_BASE__SHIFT 0x0
|
||||
#define HDP_XDP_MC_VM_FB_LOCATION_BASE__FB_BASE_MASK 0x03FFFFFFL
|
||||
//HDP_XDP_MMHUB_ERROR
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_01__SHIFT 0x1
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_10__SHIFT 0x2
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_11__SHIFT 0x3
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_FED__SHIFT 0x4
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_01__SHIFT 0x5
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_10__SHIFT 0x6
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_11__SHIFT 0x7
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_01__SHIFT 0x9
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_10__SHIFT 0xa
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_11__SHIFT 0xb
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_FED__SHIFT 0xc
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_01__SHIFT 0xd
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_10__SHIFT 0xe
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_11__SHIFT 0xf
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_01__SHIFT 0x11
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_10__SHIFT 0x12
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_11__SHIFT 0x13
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_01__SHIFT 0x15
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_10__SHIFT 0x16
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_11__SHIFT 0x17
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_01_MASK 0x00000002L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_10_MASK 0x00000004L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BRESP_11_MASK 0x00000008L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_FED_MASK 0x00000010L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_01_MASK 0x00000020L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_10_MASK 0x00000040L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_BUSER_NACK_11_MASK 0x00000080L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_01_MASK 0x00000200L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_10_MASK 0x00000400L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RRESP_11_MASK 0x00000800L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_FED_MASK 0x00001000L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_01_MASK 0x00002000L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_10_MASK 0x00004000L
|
||||
#define HDP_XDP_MMHUB_ERROR__HDP_RUSER_NACK_11_MASK 0x00008000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_01_MASK 0x00020000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_10_MASK 0x00040000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BRESP_11_MASK 0x00080000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_01_MASK 0x00200000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_10_MASK 0x00400000L
|
||||
#define HDP_XDP_MMHUB_ERROR__XDP_BUSER_NACK_11_MASK 0x00800000L
|
||||
|
||||
#endif
|
||||
@@ -6,7 +6,9 @@ from tinygrad.engine.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.helpers import Context
from tinygrad.dtype import dtypes
from tinygrad.ops import Ops
import json
from collections import OrderedDict

EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]

@@ -26,6 +28,7 @@ def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str]
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
cargs += [var for var in fxn.vars if getattr(var, "op", None) is Ops.DEFINE_VAR] # symbolic vars; is it necessary or sufficient to check for DEFINE_VAR?
statements.append((fxn.function_name, cargs, fxn.global_size, fxn.local_size))

return functions, statements, {name:(size, dtype, key) for (name,size,dtype,key) in bufs.values()}, bufs_to_save
@@ -54,60 +57,105 @@ def jit_model(model, *args) -> Tuple[TinyJit,Dict[int,str]]:
special_names[id(output.lazydata.base.realized)] = f'output{i}'
return run, special_names

def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str]) -> str:
cprog = ["#include <tgmath.h>"]
def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:
headers = ["#include <tgmath.h>"]
cprog = list(functions.values())
dtype_map = {dtypes.int: "int", dtypes.float: "float", dtypes.uchar: "unsigned char", dtypes.char: "signed char", dtypes.half: "__fp16", dtypes.uint: "unsigned int"}
inputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in input_names + list(symbolic_vars.values())]
outputs = [(name, dtype_map[bufs[name][1]], bufs[name][0]) for name in output_names]
forward_args = ",".join(f"{dtype}{'*' if name not in symbolic_vars.values() else ''} {name}" for name,dtype,_ in (outputs+inputs if wasm else inputs+outputs))

for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
if not wasm:
for name,cl in bufs_to_save.items():
weight = ''.join(["\\x%02X"%x for x in bytes(cl._buf)])
cprog.append(f"unsigned char {name}_data[] = \"{weight}\";")
cprog += [f"{dtype_map[dtype]} {name}[{len}];" if name not in bufs_to_save else f"{dtype_map[dtype]} *{name} = ({dtype_map[dtype]} *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in input_names+output_names]
cprog += [f"void net({forward_args}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(headers + cprog)
else:
if bufs_to_save:
headers += ["#include <stddef.h>"]
bufs_to_save = {k:v for k,v in bufs.items() if v[2] in weight_names} # causes random seeds to be set as zeroes, not exported as a model weight
buf_to_name = OrderedDict((buf_name, {"name": weight_names[data[2]], "idx": i}) for i, (buf_name, data) in enumerate(bufs_to_save.items()))
cprog.append(f"void* bufs[{len(buf_to_name)}];")
cprog.append(f"""void set_buf(size_t index, void* ptr) {{\n bufs[index] = ptr;\n}}""")

inputs = ", ".join([f'float* {input}' for input in input_names])
outputs = ", ".join([f'float* {output}' for output in output_names])
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,(len,dtype,_key) in bufs.items() if name not in ['input', 'outputs']]
cprog += list(functions.values())
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(cprog)
for name in set(bufs.keys()) - set(bufs_to_save.keys()) - set(input_names + output_names):
n_bytes, dtype, _ = bufs[name]
cprog += [f"{dtype_map[dtype]} {name}[{n_bytes // dtype.itemsize}];"]

cprog += [f"void net({forward_args})"] + ["{"]
get_weight_ptr = lambda x: f"({dtype_map[bufs_to_save[x][1]]} *)bufs[{buf_to_name[x]['idx']}]" if x in bufs_to_save else x
cprog += [f" {name}({', '.join(map(get_weight_ptr, args))});" for (name, args, _global_size, _local_size) in statements] + ["}"]
weightMapping = "" if not bufs_to_save else f"""\nconst weightNames = [{", ".join([f'"{weight_name}"' for weight_name in [v["name"] for v in buf_to_name.values()]])}];
const {model_name}_name_to_id = Object.fromEntries(weightNames.map((name, index) => [name, index]));\n"""
top = f"""import {model_name}Module from './{model_name}.js'{weightMapping}"""
whitespace = "\n "
js_wrapper = f"""{top}\nvar {model_name} = async function() {{
const wasm = await {model_name}Module();
{whitespace.join(f"const {name}Ptr = wasm._malloc({n_bytes});" for name, _, n_bytes in outputs+inputs if name not in symbolic_vars.values())}
return {{
run: ({",".join(name for name,_,_ in inputs)}) => {{
{(whitespace + " ").join(f"wasm.HEAPU8.set({name}, {name}Ptr);" for name,_,_ in inputs if name not in symbolic_vars.values())}
wasm._net({", ".join(f"{name}{'Ptr' if name not in symbolic_vars.values() else ''}" for name,_,_ in outputs+inputs)});
{(whitespace + " ").join(f"const {name} = wasm.HEAPU8.slice({name}Ptr, {name}Ptr + {n_bytes});" for name,_,n_bytes in outputs)}
return [{", ".join(f"{name}" for name,_,_ in outputs)}];
}},
wasm: wasm
}}
}}\nexport {{ {model_name}, {model_name}_name_to_id }};"""

return '\n'.join(headers + cprog), js_wrapper

def dtype_to_js_type(dtype: DType) -> str:
return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
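A quick sanity sketch of what dtype_to_js_type produces, assuming the standard tinygrad dtypes (illustration only, not part of the diff):

from tinygrad.dtype import dtypes
# unsigned -> Uint*, signed/bool -> Int*, everything else -> Float*
assert dtype_to_js_type(dtypes.uint8) == "Uint8Array"
assert dtype_to_js_type(dtypes.int32) == "Int32Array"
assert dtype_to_js_type(dtypes.float32) == "Float32Array"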

def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name) -> Tuple[str,int,int]:
exported_name = "model" if model_name == None else model_name
def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]:
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names += list(symbolic_vars.values())
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]

buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage"
create_bind_group_layouts = ",".join([
"device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format(
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'storage' }} }}" for argIdx, _ in enumerate(args)])
",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)])
)
for _, (_, args, _, _) in enumerate(statements)
])
layouts = f"const layouts=[{create_bind_group_layouts}]"
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])

buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf"
map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])"
_bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
const {exported_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};

const getTensorMetadata = (safetensorBuffer) => {{
getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};\n""" if not stream_weights else ""
return f"""
const {model_name} = (() => {{
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}};

{getTensorMetadata}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};

const createUniformBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}})
}}

const createInfinityUniformBuf = (device) => {{
const size = 4;
const buf = device.createBuffer({{
@@ -121,9 +169,8 @@ const createInfinityUniformBuf = (device) => {{
}};

const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
const buf = device.createBuffer({{ size, usage: GPUBufferUsage.STORAGE{" | GPUBufferUsage.COPY_DST" if stream_weights else ", mappedAtCreation: true"} }});
{"data.bytes = buf;" if stream_weights else "new Uint8Array(buf.getMappedRange()).set(data); buf.unmap();"}
return buf;
}};

@@ -145,8 +192,8 @@ const addComputePass = (device, commandEncoder, pipeline, layout, infinityUnifor

{kernel_code}

const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor);
const setupNet = async (device, {"state_dict" if stream_weights else "safetensor"}) => {{
{"const metadata = getTensorMetadata(safetensor);" if not stream_weights else ""}
const infinityBuf = createInfinityUniformBuf(device);

{layouts}
@@ -185,12 +232,12 @@ const setupNet = async (device, safetensor) => {{
}}
}}
const load = async (device, weight_path) => {{ return await fetch(weight_path).then(x => x.arrayBuffer()).then(x => setupNet(device, new Uint8Array(x))); }}
return {{ load }};
return {{ load, setupNet }};
}})();
export default {exported_name};
export default {model_name};
"""

def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False):
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU are supported"
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
@@ -198,11 +245,30 @@ def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
weight_names = {id(x.lazydata.base.realized): name for name, x in state.items()}
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]

# handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad
symbolic_vars = OrderedDict()
for i, (_, args, global_size, _) in enumerate(statements):
for j, var in enumerate(args):
if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str):
if var not in symbolic_vars:
symbolic_vars[var] = var.arg[0]
bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var])
statements[i][1][j] = symbolic_vars[var]

if global_size:
for j, dim in enumerate(global_size):
if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}:
name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src)
global_size[j] = f"_{name.arg[0]}[0] + {val.arg}"

prg = ""
if target == "clang":
prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names)
elif target == "wasm":
return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True)
elif target == "webgpu":
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name)
prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights)
else:
prg = json.dumps({
"backend": Device.DEFAULT,
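Per the branches above, the "wasm" target is the one that returns a pair (the C source plus the JS wrapper) straight from export_model_clang. A minimal usage sketch with a hypothetical model and input:

from tinygrad import Tensor
# my_model is hypothetical; the wasm branch returns ('\n'.join(headers + cprog), js_wrapper)
c_src, js_wrapper = export_model(my_model, "wasm", Tensor.randn(1, 3, 224, 224), model_name="net")
# c_src is then compiled (e.g. with emscripten) into net.js; js_wrapper imports it and
# exposes run() plus set_buf-backed weight loading via net_name_to_id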
1166 extra/hip_gpu_driver/sienna_cichlid_ip_offset.h Normal file
File diff suppressed because it is too large
@@ -195,6 +195,7 @@ def get_onnx_ops():
return [pads[i]-pads[i]//2 for i in range(len(pads))] + [pads[i]//2 for i in range(len(pads))]

def _resolve_pool_pads(x:Tensor, p_, k_, d_, s_, auto_pad:AUTO_PAD_OPTIONS):
if auto_pad == "VALID": return [0]*(len(k_)*2)
i_, (s_,d_,p_) = x.shape[-len(k_):], (make_tuple(x, len(k_)*2) for x in (s_, d_, p_))
if auto_pad == "NOTSET": return _onnx_pads_to_tiny_pads(p_ if len(p_)==len(k_)*2 else p_*2)
o_ = [((i - (1 if auto_pad in ("SAME_UPPER", "SAME_LOWER") else k)) // s + 1) for i,k,s in zip(i_, k_, s_)]
@@ -673,7 +674,8 @@ def get_onnx_ops():
x_sh = list(x.shape)
ret_shape = x_sh[:axis] + list(indices.shape) + x_sh[axis+1:]
if indices.ndim > 1: indices = indices.flatten()
indices = [_cached_to_python_const(indices)] if indices.shape == () else [x_sh[axis]+x if x<0 else x for x in _cached_to_python_const(indices)]
indices = [_cached_to_python_const(indices)] if indices.shape == () else _cached_to_python_const(indices)
indices = [x_sh[axis]+x if x<0 else x for x in indices]
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices] # type: ignore
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot

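Splitting the normalization onto its own line means it now runs for the scalar-index path too, which the old one-liner missed. A small standalone illustration, using the shapes from the Gather test later in this diff:

x_sh, axis = [1, 3, 3], 1
indices = [-2]  # scalar index -2 along axis 1, wrapped in a list
indices = [x_sh[axis] + x if x < 0 else x for x in indices]
assert indices == [1]  # -2 wraps to 3 + (-2) = 1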
@@ -174,6 +174,7 @@ decomps = [
aten.threshold_backward,
aten.softplus_backward,
aten.elu, # elu has a scale + input_scale param
aten.elu_backward,
aten.softplus,
aten.threshold,
aten.nll_loss_forward,
@@ -270,8 +271,8 @@ tiny_backend_out = {**{f"aten.{x}.out":getattr(Tensor,x) for x in simple_tensor_
# TODO: this might result in overflow issues
"aten.round.decimals_out": lambda self,decimals: (self*10**decimals).round()/10**decimals,
# TODO: support this in tinygrad
"aten.bitwise_left_shift.Tensor_out": lambda input,other: Tensor(input << other.numpy()),
"aten.bitwise_right_shift.Tensor_out": lambda input,other: Tensor(input >> other.numpy()),
"aten.bitwise_left_shift.Tensor_out": lambda x,y: x*(2**y),
"aten.bitwise_right_shift.Tensor_out": lambda x,y: x//(2**y),
# not in tinygrad. are there decomps for these?
"aten.log10.out": lambda self: self.log2() * (math.log(2) / math.log(10)),
"aten.log1p.out": lambda self: (self+1).log(),
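These replacements drop the .numpy() round-trip in favor of pure tensor arithmetic; for non-negative integers the shift identities are exact, and log10 is plain change of base. A standalone check with illustrative values:

import math
x, y = 13, 3
assert x * (2**y) == x << y    # left shift == multiply by 2**y
assert x // (2**y) == x >> y   # right shift == floor divide by 2**y
# log10(v) = log2(v) * log(2)/log(10), by change of base
assert math.isclose(math.log10(7.0), math.log2(7.0) * (math.log(2) / math.log(10)))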
@@ -326,11 +327,12 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.var.correction": Tensor.var,
"aten.var_mean.correction": Tensor.var_mean,
# NOTE: axis=[] in torch means all, change tinygrad?
"aten.sum.IntList_out": lambda self,axis,keepdim=False,out=None:
out.replace(Tensor.sum(self, axis if axis is None or len(axis) else None, keepdim), allow_shape_mismatch=True),
"aten.sum.IntList_out": lambda self,axis,keepdim=False,dtype=None,out=None:
out.replace(self.sum(axis if axis is None or len(axis) else None, keepdim,
acc_dtype = _from_torch_dtype(dtype) if dtype is not None else None), allow_shape_mismatch=True),
"aten.scatter.value": Tensor.scatter,
"aten.scatter.value_reduce": Tensor.scatter,
"aten.gather": Tensor.gather,
"aten.gather": lambda self, dim, index: self.gather(dim, index.cast(dtypes.int)),
"aten.where.self": Tensor.where, # NOTE: this is needed as well as the out type
"aten._softmax": lambda self,dim,half_to_float: self.softmax(dim),
"aten._log_softmax": lambda self,dim,half_to_float: self.log_softmax(dim),
@@ -343,10 +345,11 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
# these don't work in out form, they have size 0
"aten.abs": Tensor.abs,
"aten.logical_not": Tensor.logical_not,
"aten.logical_or_": lambda x, y: x.assign(x | y),
"aten.multinomial": Tensor.multinomial,
"aten.pad": Tensor.pad,
"aten.reflection_pad2d": functools.partial(Tensor.pad, mode="reflect"),
"aten.masked_fill_.Scalar": lambda self,mask,value: self.assign(mask.where(self, value)),
"aten.masked_fill_.Scalar": lambda self, mask, value: self.assign(self.masked_fill(mask, value)),
"aten.masked_fill.Scalar": Tensor.masked_fill,
"aten.masked_fill.Tensor": Tensor.masked_fill,
"aten.all": Tensor.all,
@@ -361,6 +364,14 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.atanh": Tensor.atanh,
"aten.fill_.Tensor": Tensor.full,
"aten.flip": Tensor.flip,
"aten.scatter_reduce.two": Tensor.scatter_reduce,
"aten.squeeze_.dim": lambda self, dim: self.replace(self.squeeze(dim), allow_shape_mismatch=True),
"aten.add.Tensor": lambda input,other,alpha=1: input+alpha*other,
"aten.linspace": lambda start, stop, steps, dtype=None, **kwargs:
Tensor.linspace(start, stop, steps, **({"dtype": _from_torch_dtype(dtype)} if dtype is not None else {})),
"aten::view.dtype": lambda self, dtype: self.bitcast(_from_torch_dtype(dtype)),
"aten.constant_pad_nd": lambda self, padding, value=0.0: self.pad(padding, mode="constant", value=value),
"aten.logsumexp": lambda self, axis, keepdim=False: self.logsumexp(axis[0], keepdim=keepdim),
"aten.squeeze.dim": Tensor.squeeze,
"aten.unsqueeze": Tensor.unsqueeze,
"aten.roll": Tensor.roll,
@@ -368,6 +379,9 @@ tiny_backend = {**{k:wrap_out(v) for k,v in tiny_backend_out.items()}, **{
"aten.repeat": Tensor.repeat,
"aten.lerp.Tensor": Tensor.lerp,
"aten.expand": Tensor.expand,
"aten.t": Tensor.transpose,
"aten.detach": Tensor.detach,
"aten.max.dim": lambda self, dim, keepdim=False: (self.max(dim, keepdim), self.argmax(dim, keepdim).cast(dtype=dtypes.int64))
}}

def wrap_fxn(k,f):
@@ -388,11 +402,11 @@ for k,v in tiny_backend.items(): torch.library.impl(k.replace("aten.", "aten::")
if TORCH_DEBUG:
from torch.utils._python_dispatch import TorchDispatchMode
class DispatchLog(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args, kwargs=None):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
#print(f"Dispatch Log: {func}(*{args}, **{kwargs})")
print(f"Dispatch Log: {func}")
return func(*args, **(kwargs or {}))
DispatchLog().__enter__()
(_dispatch_log:=DispatchLog()).__enter__() # NOTE: must be kept alive

# NOTE: patch torch optimizer step to avoid continuously growing the computation graph
def realize_optimizer_step(optimizer: torch.optim.Optimizer, *args, **kwargs):

@@ -75,7 +75,7 @@ class TestTorchBackend(unittest.TestCase):
a = torch.ones(4, device=device)
b = torch.ones(4, device=device)
c = a == b
print(c.cpu().numpy())
print(c.cpu())

def test_maxpool2d_backward(self):
x = torch.arange(3*3, device=device).reshape(1, 1, 3, 3).requires_grad_(True)
@@ -86,10 +86,10 @@ class TestTorchBackend(unittest.TestCase):
x = torch.zeros(4, device=device, dtype=torch.int64)
y = torch.ones(4, device=device, dtype=torch.float32).to(dtype=torch.int64)
res1 = x ^ y # an operation that only works on int types
print(res1.cpu().numpy())
print(res1.cpu())
y = y.cpu().float().to(device=device, dtype=torch.int64)
res2 = x ^ y
print(res2.cpu().numpy())
print(res2.cpu())

@unittest.skip("meh")
def test_str(self):

@@ -30,7 +30,9 @@ def run(sz, n_gpus=6, iters=10, use_ring=False):
with Context(RING=(2 if use_ring else 0), DEBUG=max(DEBUG.value, 2)): return test(devs, N, iters=iters)

def main():
ONLY_RING = getenv("ONLY_RING", 0)
n_gpus = getenv("GPUS", 6)
iters = getenv("ITERS", 10)

if getenv("BENCHMARK_SPLIT"):
l, r = 0, 512
@@ -44,10 +46,10 @@ def main():
else:
sz = getenv("SZ", 1000) * 10**6 # size of data on each gpu
print(f"Using {sz/10**9:.2f} GB of numbers on each of {n_gpus} GPUs, {n_gpus*sz/10**9:.2f} GB total.")
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus)
(naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus)
(ring_gflops, ring_gbs, ring_secs) = run(sz, use_ring=True, n_gpus=n_gpus, iters=iters)
if not ONLY_RING: (naive_gflops, naive_gbs, naive_secs) = run(sz, use_ring=False, n_gpus=n_gpus, iters=iters)
print(f"Ring:\n {ring_secs:.6f} seconds/iter\n {ring_gflops:.2f} GFLOP/s\n {ring_gbs:.2f} GB/s")
print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")
if not ONLY_RING: print(f"Naive:\n {naive_secs:.6f} seconds/iter\n {naive_gflops:.2f} GFLOP/s\n {naive_gbs:.2f} GB/s")

if __name__ == "__main__":
main()

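For scale, the defaults above work out as follows; setting ONLY_RING=1 skips the naive allreduce measurement entirely (a sketch of the arithmetic, not part of the diff):

sz = 1000 * 10**6          # SZ defaults to 1000 -> 1.00 GB of data per GPU
total_gb = 6 * sz / 10**9  # GPUS defaults to 6  -> 6.00 GB total
assert total_gb == 6.0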
10 test/external/external_fuzz_ampt.py vendored
@@ -72,7 +72,7 @@ class AMPTFuzzer:

pattern = self.generate_pattern(ptr, size)
pages = self.fill_memory(ptr, size, pattern)
self.allocations[ptr] = (size, pattern, pages)
self.allocations[ptr.va_addr] = (size, pattern, pages, ptr)
self.alloc_payload += size
print(f"Allocated {size} bytes at {ptr.va_addr:x}, pattern: {pattern:02x}")
return ptr
@@ -81,15 +81,15 @@ class AMPTFuzzer:
if not self.allocations: return False

ptr = random.choice(list(self.allocations.keys()))
size, pattern, pages = self.allocations[ptr]
size, pattern, pages, vm = self.allocations[ptr]

# Verify pattern before freeing
if not self.verify_memory(pages, pattern):
raise RuntimeError(f"Memory corruption detected at {ptr.va_addr:x}!")
raise RuntimeError(f"Memory corruption detected at {vm.va_addr:x}!")

print(f"Freeing {size} bytes at {ptr.va_addr:x}, pattern verified: {pattern:02x}")
print(f"Freeing {size} bytes at {vm.va_addr:x}, pattern verified: {pattern:02x}")
self.alloc_payload -= size
self.d.mm.vfree(ptr)
self.d.mm.vfree(vm)
del self.allocations[ptr]
return True

16 test/external/external_test_am.py vendored
@@ -153,5 +153,21 @@ class TestAMPageTable(unittest.TestCase):
mm0.map_range(0x1000000, 2 << 20, paddrs=[(0x10000, 2 << 20)])
mm0.unmap_range(0x1000000, 2 << 20)

def test_frag_size(self):
mm0 = self.d[0].mm

def must_cover_checker(va, sz):
ans = (1 << (mm0._frag_size(va=va, sz=sz, must_cover=True) + 12))
assert va % ans == 0 and sz % ans == 0 and (va % (2 * ans) != 0 or sz % (2 * ans) != 0), f"va {va:#x} sz {sz:#x} ans {ans:#x}"

def not_cover_checker(va, sz):
ans = (1 << (mm0._frag_size(va=va, sz=sz, must_cover=False) + 12))
assert va % ans == 0 and ans <= sz and (va % (2 * ans) != 0 or (2 * ans) > sz), f"va {va:#x} sz {sz:#x} ans {ans:#x}"

for va, sz in [(0x0, 0x1000), (0x1000, 0x2000), (0x1000, 0x3000), (0x2000, 0x2000), (0x4000, 0x8000), (0x8000, 0x4000), (0x10000, 0x4000),
(0x0, 0x4000), (0x10000, 0x4000), (0x10000, 0x40000), (0x10001000, 0x40000), (0x100001000, 0x3000)]:
must_cover_checker(va, sz)
not_cover_checker(va, sz)

if __name__ == "__main__":
unittest.main()

21 test/external/external_test_onnx_ops.py vendored
@@ -27,6 +27,27 @@ class TestMainOnnxOps(TestOnnxOps):
outputs = ["out"]
self.helper_test_single_op("Reshape", inputs, attributes, outputs)

def test_conv(self):
# test VALID auto_pad
inputs = {
"x": np.random.randn(1, 3, 384, 384).astype(np.float32),
"w": np.random.randn(1152, 3, 14, 14).astype(np.float32),
"b": np.random.randn(1152).astype(np.float32)
}
attributes = {'auto_pad': 'VALID', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (14, 14), 'strides': (14, 14)}
outputs = ["y"]
self.helper_test_single_op("Conv", inputs, attributes, outputs, atol=1e-4)

def test_gather(self):
# test const negative indices
inputs = {
"input": np.random.randn(1, 3, 3).astype(np.float32),
"indices": np.array(-2, dtype=np.int64),
}
attributes = {'axis': 1}
outputs = ["y"]
self.helper_test_single_op("Gather", inputs, attributes, outputs)

def test_quantize_linear(self):
test_cases = [
{"test_case": "round_half_to_even", "qdtype": np.int8, "qzero_point": 0, "x": [-1.5, -0.5, 0.5, 1.5], "scale": 1.0},

4 test/external/speed_v_theoretical.py vendored
@@ -91,11 +91,11 @@ class TestKernelSpeed(unittest.TestCase):

# theoretical is nv_tflops=165, amd_tflops=123
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=115, amd_tflops=80)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=70)
def test_gemm_8192(self): self._test_matmul(8192, nv_tflops=125, amd_tflops=65)

# theoretical is nv_gbs=1008, amd_gbs=960
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=760)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=750)

if __name__ == '__main__':
unittest.main()

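These thresholds follow from the usual 2*N^3 FLOP count for an N x N GEMM divided by wall time; a worked sketch for the 8192 case (the 0.016 s timing is hypothetical):

N = 8192
flops = 2 * N**3               # ~1.1e12 FLOPs per matmul
tflops = flops / 0.016 / 1e12  # ~68.7 TFLOP/s, just above the relaxed amd_tflops=65 bar
print(f"{tflops:.1f} TFLOP/s")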
@@ -5,9 +5,8 @@ import tinygrad.runtime.autogen.amd_gpu as amd_gpu

SDMA_MAX_COPY_SIZE = 0x400000

BASE_ADDR = 0x00001260
PACKET3_SET_SH_REG_START = 0x2c00
SUB = PACKET3_SET_SH_REG_START - BASE_ADDR
SUB = PACKET3_SET_SH_REG_START - amd_gpu.GC_BASE__INST0_SEG0

regCOMPUTE_PGM_LO = 0x1bac - SUB
regCOMPUTE_USER_DATA_0 = 0x1be0 - SUB

@@ -343,6 +343,11 @@ class TestDiskTensor(unittest.TestCase):
on_dev = t.to(Device.DEFAULT).realize()
np.testing.assert_equal(on_dev.numpy(), t.numpy())

@unittest.skipUnless(OSX, "seems to only be an issue on macOS with file size >2 GiB")
def test_copy_to_cpu_not_truncated(self):
with open((fn:=temp("dt_copy_to_cpu_not_truncated")), "wb") as f: f.write(b'\x01' * (size := int(2 * 1024**3)) + (test := b"test"))
x = Tensor.empty(size + len(test), dtype=dtypes.uint8, device=f"disk:{fn}").to("CPU").realize()
assert x[size:].data().tobytes() == test

class TestPathTensor(unittest.TestCase):
def setUp(self):

@@ -359,33 +359,28 @@ fix_kernel_ops = PatternMatcher([
# remove CONTIGUOUS/DEVICE
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])

# TODO: replace this with the KERNEL UOp
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]

def schedule_uop(sink:UOp, var_vals:dict[Variable, int]) -> ScheduleItem:
assert sink.op is Ops.ASSIGN and sink.src[1].op is Ops.KERNEL, f"{sink} must be ASSIGN"
def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
assert k.op is Ops.KERNEL, f"kernel isn't kernel, it's {k}"
# substitute kernel sources for the target buffer
ast = sink.src[1].arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in sink.src[1].src if s.op is Ops.ASSIGN}).sink()
ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink()
# add buffer ops
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[sink.buf_uop], bottom_up=True)
ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
# create subbuffer
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
return ScheduleItem(ast, tuple(dedup([x.buffer for x in bufs])), sink.src[1].arg.metadata)
return k.replace(arg=Kernel(ast, k.arg.metadata))

PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@@ -395,6 +390,12 @@ if CAPTURE_PROCESS_REPLAY:

# **** schedule creation and toposort

@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]

@track_rewrites(name_fxn=lambda r: f"Schedule {pluralize('Kernel', len(r[0]))}"+(f" (with_{pluralize('Var', len(r[1]))})" if len(r[1]) != 0 else ""))
def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
# merge_views + sym
@@ -408,7 +409,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
realize_map = group_realizes(sink)
# map tensor metadata to simplified ops
ops_metadata = {v:k.metadata for k,v in tensor_map.items() if k.base.op not in {Ops.CONST, Ops.DEVICE} and isinstance(k.metadata, Metadata)}
# create kernels
# create_kernels
kernel_map = graph_rewrite_map(sink, create_kernels, ctx=KernelContext(realize_map, ops_metadata), bottom_up=True)
sched_sink = kernel_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)
@@ -459,7 +460,9 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
var_vals: dict[Variable, int] = {}
while queue:
u = queue.popleft()
schedule.append(schedule_uop(u, var_vals))
# TODO: move this to create_kernels
k = fix_kernel_ast(u.src[1], var_vals)
schedule.append(ScheduleItem(k.arg.ast, tuple(s.buf_uop.buffer for s in k.src), k.arg.metadata))
# increment the refcount of the target buf (this is required by the JIT and memory planner)
u.buf_uop.buffer.ref(1)
for x in children.get(u, []):

@@ -82,7 +82,7 @@ class LARS(Optimizer):
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
r = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
r:Tensor|float = (r1 > 0).where((r2 > 0).where(self.tcoef * r1 / (r2 + self.wd * r1), 1.0), 1.0)
else: r = 1.0
g = g + self.wd * t.detach()
# classic momentum does post learning rate update
@@ -141,7 +141,7 @@ class LAMB(Optimizer):
if not self.adam:
r1 = t.detach().square().sum().sqrt()
r2 = up.square().sum().sqrt()
r = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
else:
r = 1.0
t.assign((t.detach() - self.lr * r * up).cast(t.dtype))

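Only the type hints change here; the guarded expression itself is the LARS trust ratio, r = tcoef * ||w|| / (||g|| + wd * ||w||) when both norms are positive, else 1.0. A scalar check with made-up norms:

import math
tcoef, wd, r1, r2 = 0.001, 1e-4, 2.0, 0.5  # hypothetical ||w|| and ||g||
assert math.isclose(tcoef * r1 / (r2 + wd * r1), 0.002 / 0.5002)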
@@ -513,6 +513,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return dsrcs[0]._device if len(dsrcs:=[x for x in self.src if x._device is not None]) != 0 else None
@property
def buf_uop(self) -> UOp:
if self.op is Ops.BUFFER: return self
assert self.op is Ops.ASSIGN, f"must be ASSIGN {self.op}"
return self.src[0].base
@property

1319 tinygrad/runtime/autogen/am/hdp_6_0_0.py Normal file
File diff suppressed because it is too large
@@ -22,8 +22,8 @@ WAIT_REG_MEM_FUNCTION_GEQ = 5 # >=

COMPUTE_SHADER_EN, FORCE_START_AT_000, CS_W32_EN = (1 << 0), (1 << 2), (1 << 15)

def gfxreg(reg): return reg + 0x00001260 - amd_gpu.PACKET3_SET_SH_REG_START
def nbioreg(reg): return reg + 0x00000d20 # NBIO_BASE__INST0_SEG2
def gfxreg(reg): return reg + amd_gpu.GC_BASE__INST0_SEG0 - amd_gpu.PACKET3_SET_SH_REG_START
def nbioreg(reg): return reg + amd_gpu.NBIO_BASE__INST0_SEG2

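Behavior is unchanged; the magic literals just get their autogen names. A quick check assuming GC_BASE__INST0_SEG0 == 0x1260 and PACKET3_SET_SH_REG_START == 0x2c00 (the constants replaced above, also visible in the sdma hunk earlier):

GC_BASE, SH_REG_START = 0x00001260, 0x2c00
reg = 0x1bac  # raw COMPUTE_PGM_LO offset, from the sdma hunk earlier
# gfxreg rebases an absolute register offset into SET_SH_REG packet space
assert reg + GC_BASE - SH_REG_START == 0x20c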
class AMDSignal(HCQSignal):
|
||||
def __init__(self, base_addr:int|None=None, **kwargs):
|
||||
@@ -148,7 +148,7 @@ class AMDComputeQueue(HWQueue):
|
||||
for i, value in enumerate(cmds): dev.compute_queue.ring[(dev.compute_queue.put_value + i) % len(dev.compute_queue.ring)] = value
|
||||
|
||||
dev.compute_queue.put_value += len(cmds)
|
||||
dev.compute_queue.signal_doorbell()
|
||||
dev.compute_queue.signal_doorbell(dev)
|
||||
|
||||
class AMDCopyQueue(HWQueue):
|
||||
def __init__(self, max_copy_size=0x40000000):
|
||||
@@ -232,7 +232,7 @@ class AMDCopyQueue(HWQueue):
|
||||
dev.sdma_queue.ring[0:rem_packet_cnt] = array.array('I', cmds[tail_blit_dword:])
|
||||
dev.sdma_queue.put_value += rem_packet_cnt * 4
|
||||
|
||||
dev.sdma_queue.signal_doorbell()
|
||||
dev.sdma_queue.signal_doorbell(dev)
|
||||
|
||||
class AMDProgram(HCQProgram):
|
||||
def __init__(self, dev:AMDDevice, name:str, lib:bytes):
|
||||
@@ -242,26 +242,26 @@ class AMDProgram(HCQProgram):
|
||||
image, sections, _ = elf_loader(self.lib)
|
||||
self.lib_gpu = self.dev.allocator.alloc(round_up(image.nbytes, 0x1000), BufferSpec(cpu_access=True, nolru=True))
|
||||
ctypes.memmove(self.lib_gpu.va_addr, mv_address(image), image.nbytes)
|
||||
|
||||
entry_point = min(sh.header.sh_addr for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS and sh.header.sh_flags & libc.SHF_ALLOC)
|
||||
self.group_segment_size = image[entry_point:entry_point+4].cast("I")[0]
|
||||
self.private_segment_size = image[entry_point+4:entry_point+8].cast("I")[0]
|
||||
self.kernargs_segment_size = image[entry_point+8:entry_point+12].cast("I")[0]
|
||||
|
||||
rodata_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".rodata"), -1)
|
||||
text_entry = next((sh.header.sh_addr for sh in sections if sh.name == ".text"), -1)
|
||||
assert rodata_entry >= 0 and text_entry >= 0, ".text or .rodata section not found"
|
||||
self.group_segment_size = image[rodata_entry:rodata_entry+4].cast("I")[0]
|
||||
self.private_segment_size = image[rodata_entry+4:rodata_entry+8].cast("I")[0]
|
||||
self.kernargs_segment_size = image[rodata_entry+8:rodata_entry+12].cast("I")[0]
|
||||
lds_size = ((self.group_segment_size + 511) // 512) & 0x1FF
|
||||
if lds_size > (self.dev.dev_iface.props['lds_size_in_kb'] * 1024) // 512: raise RuntimeError("Too many resources requested: group_segment_size")
|
||||
|
||||
# Ensure scratch size
|
||||
self.dev._ensure_has_local_memory(self.private_segment_size)
|
||||
|
||||
code = hsa.amd_kernel_code_t.from_address(self.lib_gpu.va_addr + entry_point) # NOTE: this is wrong, it's not this object
|
||||
code = hsa.amd_kernel_code_t.from_address(self.lib_gpu.va_addr + rodata_entry) # NOTE: this is wrong, it's not this object
|
||||
assert code.kernel_code_properties & 0x400 == 0x400 # ENABLE_WAVEFRONT_SIZE32
|
||||
|
||||
# Set rsrc1.priv=1 on gfx11 to workaround cwsr.
|
||||
self.rsrc1: int = code.compute_pgm_rsrc1 | ((1 << 20) if 110000 <= self.dev.target < 120000 else 0)
|
||||
self.rsrc2: int = code.compute_pgm_rsrc2 | (lds_size << 15)
|
||||
self.prog_addr: int = self.lib_gpu.va_addr + entry_point + code.kernel_code_entry_byte_offset
|
||||
|
||||
self.prog_addr: int = self.lib_gpu.va_addr + rodata_entry + code.kernel_code_entry_byte_offset
|
||||
if code.kernel_code_entry_byte_offset == 0: self.prog_addr = self.lib_gpu.va_addr + text_entry
|
||||
# Some programs use hsa_kernel_dispatch_packet_t to read workgroup sizes during execution.
|
||||
# The packet is represented as a pointer and set up in SGPRs. Space for the packet is allocated as part of the kernel arguments.
|
||||
self.enable_dispatch_ptr: int = code.kernel_code_properties & hsa.AMD_KERNEL_CODE_PROPERTIES_ENABLE_SGPR_DISPATCH_PTR
|
||||
@@ -293,11 +293,14 @@ class AMDQueueDesc:
|
||||
doorbell: memoryview
|
||||
put_value: int = 0
|
||||
|
||||
def signal_doorbell(self):
|
||||
def signal_doorbell(self, dev):
|
||||
self.write_ptr[0] = self.put_value
|
||||
|
||||
# Ensure all prior writes are visible to the GPU.
|
||||
if CPUProgram.atomic_lib is not None: CPUProgram.atomic_lib.atomic_thread_fence(__ATOMIC_SEQ_CST:=5)
|
||||
|
||||
# Flush hdp if queue is in dev mem.
|
||||
if dev.driverless and getenv("AMD_ALLOC_QUEUE_DEV_MEM", 1): dev.dev_iface.adev.gmc.flush_hdp()
|
||||
self.doorbell[0] = self.put_value
|
||||
|
||||
class KFDIface:
|
||||
@@ -450,7 +453,8 @@ class PCIIface:
|
||||
HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/driver/unbind", os.O_WRONLY).write(self.pcibus)
|
||||
|
||||
supported_sizes = int(HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDONLY).read(), 16)
|
||||
HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDWR).write(str(supported_sizes.bit_length() - 1))
|
||||
try: HWInterface(f"/sys/bus/pci/devices/{self.pcibus}/resource0_resize", os.O_RDWR).write(str(supported_sizes.bit_length() - 1))
|
||||
except OSError as e: raise RuntimeError(f"Cannot resize BAR: {e}. Ensure the resizable BAR option is enabled on your system.") from e
|
||||
|
||||
# Try to init vfio. Use it if successful.
if PCIIface.vfio:
@@ -580,7 +584,7 @@ class AMDDevice(HCQCompiled):
sgrp_size_per_cu, lds_size_per_cu, hwreg_size_per_cu = 0x4000, 0x10000, 0x1000
vgpr_size_per_cu = 0x60000 if self.target in {110000, 110001, 120000, 120001} else 0x40000
wg_data_size = round_up((vgpr_size_per_cu + sgrp_size_per_cu + lds_size_per_cu + hwreg_size_per_cu) * (self.max_cu_id + 1), mmap.PAGESIZE)
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE)
ctl_stack_size = round_up(12 * (self.max_cu_id + 1) * (self.max_wave_id + 1) + 8 + 40, mmap.PAGESIZE) if self.target//10000 != 10 else 0x7000
debug_memory_size = round_up((self.max_cu_id + 1) * (self.max_wave_id + 1) * 32, 64)

self.compute_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_COMPUTE, 0x800000, ctx_save_restore_size=wg_data_size + ctl_stack_size,

@@ -84,7 +84,8 @@ class DiskAllocator(Allocator):
# OSX doesn't seem great at mmap, this is faster
with io.FileIO(self.dev.fd, "a+b", closefd=False) as fo:
fo.seek(src.offset)
fo.readinto(dest)
bytes_read = 0
while (n := fo.readinto(dest[bytes_read:])) is not None and n > 0: bytes_read += n
else:
dest[:] = src._buf()

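The loop replacing the single readinto() above guards against short reads: a raw read may legally fill only part of the buffer (and may return None in non-blocking mode), which could silently truncate large copies. A minimal standalone reproduction of the corrected pattern:

import io, tempfile

with tempfile.NamedTemporaryFile(delete=False) as f: f.write(b"x" * 1000)
dest = memoryview(bytearray(1000))
with io.FileIO(f.name, "rb") as fo:
    bytes_read = 0
    # keep reading until the buffer is full or EOF; a lone readinto() may stop early
    while (n := fo.readinto(dest[bytes_read:])) is not None and n > 0: bytes_read += n
assert bytes_read == 1000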
@@ -175,6 +175,15 @@ class AMMemoryManager:
self.pa_allocator = TLSFAllocator(vram_size - (64 << 20)) # per device
self.root_page_table = AMPageTableEntry(self.adev, self.palloc(0x1000, zero=not self.adev.smi_dev, boot=True), lv=am.AMDGPU_VM_PDB1)

def _frag_size(self, va, sz, must_cover=True):
"""
Calculate the TLB fragment size for a given virtual address and size.
If must_cover is True, the fragment size must cover the size, otherwise the biggest fragment size that fits the size is returned.
Fragment 0 is 4KB, 1 is 8KB and so on.
"""
va_pwr2_div, sz_pwr2_div, sz_pwr2_max = va & -(va) if va > 0 else (1 << 63), sz & -(sz), (1 << (sz.bit_length() - 1))
return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12

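A standalone recomputation of _frag_size to make the bit tricks concrete: fragment f covers 4 KiB << f, and the usable fragment is bounded by both the power-of-2 alignment of the address and the (covering or fitting) power of 2 of the size.

def frag_size(va: int, sz: int, must_cover: bool = True) -> int:
    va_pwr2_div = va & -va if va > 0 else (1 << 63)          # largest power of 2 dividing va
    sz_pwr2_div, sz_pwr2_max = sz & -sz, 1 << (sz.bit_length() - 1)
    return (min(va_pwr2_div, sz_pwr2_div) if must_cover else min(va_pwr2_div, sz_pwr2_max)).bit_length() - 1 - 12

assert frag_size(0x200000, 0x200000) == 9                    # 2 MiB-aligned 2 MiB -> 4 KiB << 9 = 2 MiB
assert frag_size(0x201000, 0x3000) == 0                      # only 4 KiB-aligned -> smallest fragment
assert frag_size(0x400000, 0x300000, must_cover=False) == 9  # biggest pow2 fitting 3 MiB is 2 MiB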
def map_range(self, vaddr:int, size:int, paddrs:list[tuple[int, int]], uncached=False, system=False, snooped=False) -> AMMapping:
if AM_DEBUG >= 2: print(f"am {self.adev.devfmt}: mapping {vaddr=:#x} ({size=:#x})")

@@ -185,8 +194,8 @@ class AMMemoryManager:
for off, pt, pte_idx, pte_cnt, pte_covers in ctx.next(psize):
for pte_off in range(pte_cnt):
assert pt.entries[pte_idx + pte_off] & am.AMDGPU_PTE_VALID == 0, f"PTE already mapped: {pt.entries[pte_idx + pte_off]:#x}"
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers,
uncached=uncached, system=system, snooped=snooped, frag=0 if pte_covers == 0x1000 else 0x9, valid=True)
pt.set_entry(pte_idx + pte_off, paddr + off + pte_off * pte_covers, uncached=uncached, system=system, snooped=snooped,
frag=self._frag_size(ctx.vaddr+off, pte_cnt * pte_covers), valid=True)

# Invalidate TLB after mappings.
self.adev.gmc.flush_tlb(ip='GC', vmid=0)
@@ -212,12 +221,21 @@ class AMMemoryManager:
if contigous: paddrs = [(self.palloc(size, zero=True), size)]
else:
paddrs = []
try:
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for _, _, _, seg_cnt, seg_size in ctx.next(size): paddrs += [(self.palloc(seg_size, zero=False), seg_size) for _ in range(seg_cnt)]
except MemoryError:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise
ctx = AMPageTableTraverseContext(self.adev, self.root_page_table, va, create_pts=True)
for off, _, _, seg_cnt, seg_size in ctx.next(size):
while seg_cnt > 0:
# Try to allocate as long a segment (a power of 2) as possible
cont_seg_sz, paddr = 1 << (self._frag_size(ctx.vaddr+off, seg_cnt*seg_size) + 12), None
while cont_seg_sz >= seg_size:
try: paddr = self.palloc(cont_seg_sz, zero=True)
except MemoryError: cont_seg_sz //= 2
else: break

if paddr is not None: paddrs += [(paddr, cont_seg_sz)]
else:
for paddr, _ in paddrs: self.pa_allocator.free(paddr)
raise MemoryError(f"Failed to allocate contiguous {cont_seg_sz=:#x} bytes (size={size:#x})")
seg_cnt, off = seg_cnt - cont_seg_sz // seg_size, off + cont_seg_sz

return self.map_range(va, size, paddrs, uncached=uncached)

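The shape of the loop above, sketched standalone: carve a range into power-of-2 pieces no larger than the current VA alignment allows, advancing the offset after each piece. This toy version always "succeeds", so it only illustrates the splitting, not the MemoryError fallback, and its size rounding (largest power of 2 that fits) is a simplification of the real _frag_size call.

def split_pow2(va: int, size: int, seg_size: int = 0x1000) -> list[int]:
    def fits(va, sz): return (min(va & -va, 1 << (sz.bit_length() - 1))).bit_length() - 1 - 12
    out, off, seg_cnt = [], 0, size // seg_size
    while seg_cnt > 0:
        cont_seg_sz = 1 << (fits(va + off, seg_cnt * seg_size) + 12)  # largest aligned pow2 piece
        out.append(cont_seg_sz)
        seg_cnt, off = seg_cnt - cont_seg_sz // seg_size, off + cont_seg_sz
    return out

assert split_pow2(0x400000, 0x300000) == [0x200000, 0x100000]  # 3 MiB -> 2 MiB + 1 MiB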
@@ -295,7 +313,7 @@ class AMDev:
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")

self.smu.set_clocks(level=-1) # last level, max perf.
self.gfx.set_clockgating_state()
for ip in [self.soc21, self.gfx]: ip.set_clockgating_state()
self.reg("regSCRATCH_REG7").write(am_version)
if DEBUG >= 2: print(f"am {self.devfmt}: boot done")

@@ -382,7 +400,9 @@ class AMDev:

def _build_regs(self):
mods = [("MP0", self._ip_module("mp", am.MP0_HWIP)), ("NBIO", self._ip_module("nbio", am.NBIO_HWIP)), ("GC", self._ip_module("gc", am.GC_HWIP)),
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP))]
("MP1", mp_11_0), ("MMHUB", self._ip_module("mmhub", am.MMHUB_HWIP)), ("OSSSYS", self._ip_module("osssys", am.OSSSYS_HWIP)),
("HDP", self._ip_module("hdp", am.HDP_HWIP))]

for base, module in mods:
rpref = "mm" if base == "MP1" else "reg" # MP1 regs start with mm
reg_names: set[str] = set(k[len(rpref):] for k in module.__dict__.keys() if k.startswith(rpref) and not k.endswith("_BASE_IDX"))

@@ -7,11 +7,13 @@ class AM_IP:
def __init__(self, adev): self.adev = adev
def init(self): raise NotImplementedError("IP block init must be implemented")
def fini(self): pass
def set_clockgating_state(self): pass

class AM_SOC21(AM_IP):
def init(self):
self.adev.regRCC_DEV0_EPF2_STRAP2.update(strap_no_soft_reset_dev0_f2=0x0)
self.adev.regRCC_DEV0_EPF0_RCC_DOORBELL_APER_EN.write(0x1)
def set_clockgating_state(self): self.adev.regHDP_MEM_POWER_CTRL.update(atomic_mem_power_ctrl_en=1, atomic_mem_power_ds_en=1)

class AM_GMC(AM_IP):
def __init__(self, adev):
@@ -34,7 +36,7 @@ class AM_GMC(AM_IP):

def init(self): self.init_hub("MM")

def flush_hdp(self): self.adev.regBIF_BX_PF0_GPU_HDP_FLUSH_REQ.write(0xffffffff)
def flush_hdp(self): self.adev.wreg(self.adev.reg("regBIF_BX0_REMAP_HDP_MEM_FLUSH_CNTL").read() // 4, 0x0)
def flush_tlb(self, ip:Literal["MM", "GC"], vmid, flush_type=0):
self.flush_hdp()

@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, hashlib, weakref
from contextlib import ContextDecorator
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex, ParamSpec, TypeVar
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
from tinygrad.dtype import _from_np_dtype, _to_np_dtype
from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
@@ -46,11 +46,10 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp]) -> None:

# **** Tensor helper functions ****

def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None):
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:str|tuple[str, ...], arg=None) -> UOp:
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return UOp.multi(*[UOp.metaop(op, shape, dtype, d, arg) for d in device], axis=None)

def _fromnp(x: 'np.ndarray') -> UOp: # type: ignore [name-defined] # noqa: F821
ret = UOp.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY")
# fake realize
@@ -97,7 +96,7 @@ def _align_left(*shapes:tuple[sint, ...]) -> tuple[tuple[sint, ...], ...]:
def _broadcast_shape(*shapes:tuple[sint, ...]) -> tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))

def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]):
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:tuple[int, ...]) -> Tensor:
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
values = values * mask
for dim in axes: mask, values = functools.reduce(lambda x,y: (x[0]|y[0], y[0].where(y[1], x[1])), zip(mask.split(1, dim), values.split(1, dim)))
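The reduce in _masked_setitem is easiest to see on scalars: folding (mask, value) pairs left to right ORs the masks while letting the last masked value win, which is why repeated indices keep their final write. A plain-Python sketch of the same fold:

from functools import reduce

pairs = [(True, 1), (False, 0), (True, 3)]  # three writes to one slot; the middle one is masked off
mask, value = reduce(lambda x, y: (x[0] | y[0], y[1] if y[0] else x[1]), pairs)
assert (mask, value) == (True, 3)           # the slot was written, and the last masked value survives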
@@ -293,8 +292,8 @@ class Tensor(SimpleMathTrait):
if 0 in self.shape: return memoryview(bytearray(0))
# NOTE: this realizes on the object from as_buffer being a Python object
cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
buf = cast(UOp, cpu.lazydata).base.realized
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
buf = cpu.lazydata.base.realized
assert buf is not None, f"{cpu.lazydata.base} was not realized"
if self.device != "CPU": buf.options = BufferSpec(nolru=True)
return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)

@@ -310,7 +309,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)

def item(self) -> ConstType:
"""
@@ -376,7 +375,7 @@ class Tensor(SimpleMathTrait):
if self.grad is not None: ret.grad = self.grad.to(device)
return ret

def to_(self, device:str|tuple[str, ...]|None):
def to_(self, device:str|tuple[str, ...]|None) -> Tensor:
"""
Moves the tensor to the given device in place.
"""
@@ -398,7 +397,7 @@ class Tensor(SimpleMathTrait):
mlb = self.lazydata.shard(devices, self._resolve_dim(axis) if axis is not None else None)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

def shard_(self, devices:tuple[str, ...], axis:int|None=None):
def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor:
"""
Shards the tensor across the given devices in place.
"""
@@ -415,7 +414,7 @@ class Tensor(SimpleMathTrait):
# ***** creation entrypoint *****

@staticmethod
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs):
def _metaop(op, shape, device:str|tuple[str, ...]|None=None, dtype:DTypeLike|None=None, arg=None, **kwargs) -> Tensor:
dtype = to_dtype(dtype) if dtype is not None else dtypes.default_float
if isinstance(device, tuple):
return Tensor(UOp.multi(*[UOp.metaop(op, shape, dtype, Device.canonicalize(d), arg) for d in device], axis=None),
@@ -423,7 +422,7 @@ class Tensor(SimpleMathTrait):
return Tensor(UOp.metaop(op, shape, dtype, Device.canonicalize(device), arg), device, dtype, **kwargs)

@staticmethod
def empty(*shape, **kwargs):
def empty(*shape, **kwargs) -> Tensor:
"""
Creates an empty tensor with the given shape.

@@ -468,7 +467,7 @@ class Tensor(SimpleMathTrait):
_device_seeds: dict[str, Tensor] = {}
_device_rng_counters: dict[str, Tensor] = {}
@staticmethod
def manual_seed(seed=0):
def manual_seed(seed=0) -> None:
"""
Sets the seed for random operations.

@@ -486,7 +485,7 @@ class Tensor(SimpleMathTrait):
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}

@staticmethod
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor) -> Tensor:
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = x._apply_uop(UOp.threefry, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
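A standalone check of the packing used in _threefry_random_bits: two u32 counter streams ride in one u64 (counts1 in the high word), get mixed, then split back out. The layout arithmetic, minus the tensors:

lo, hi = 0x89ABCDEF, 0x01234567
x = (hi << 32) | lo                  # same layout as the tensor expression above
assert x & 0xFFFFFFFF == lo          # low 32 bits recover counts0
assert (x >> 32) & 0xFFFFFFFF == hi  # high 32 bits recover counts1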
@@ -1069,7 +1068,7 @@ class Tensor(SimpleMathTrait):
if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}")
X, pads = self, tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX)
if mode == "constant":
def _constant(x:Tensor,px,v):
def _constant(x:Tensor,px,v) -> Tensor:
return x._apply_uop(UOp.pad, arg=px) if v == 0 else (x._apply_uop(UOp.pad, arg=px)+Tensor.ones_like(x)._apply_uop(UOp.pad, arg=px).where(0,v))
return _constant(X, pX, value) if all(resolve(p >= 0) for p in flatten(pX)) else \
_constant(X.shrink(tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, X.shape))), pads, value)
@@ -1464,7 +1463,7 @@ class Tensor(SimpleMathTrait):
order[dim0], order[dim1] = order[dim1], order[dim0]
return self.permute(order)

def flatten(self, start_dim=0, end_dim=-1):
def flatten(self, start_dim=0, end_dim=-1) -> Tensor:
"""
Flattens the tensor by reshaping it into a one-dimensional tensor.
If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` and ending with `end_dim` are flattened.
@@ -1480,7 +1479,7 @@ class Tensor(SimpleMathTrait):
start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])

def unflatten(self, dim:int, sizes:tuple[int,...]):
def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
"""
Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.

@@ -1565,7 +1564,7 @@ class Tensor(SimpleMathTrait):
ret = self._apply_uop(UOp.r, op=op, axis=axis)
return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis))

def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def sum(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the sum of the elements of the tensor along the specified axis or axes.

@@ -1592,7 +1591,7 @@ class Tensor(SimpleMathTrait):
ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(Ops.ADD, axis, keepdim)
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None):
def prod(self, axis:int|Sequence[int]|None=None, keepdim=False, acc_dtype:DTypeLike|None=None) -> Tensor:
"""
Returns the product of the elements of the tensor along the specified axis or axes.

@@ -1618,7 +1617,7 @@ class Tensor(SimpleMathTrait):
"""
return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(Ops.MUL, axis, keepdim)

def max(self, axis:int|Sequence[int]|None=None, keepdim=False):
def max(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the maximum value of the tensor along the specified axis or axes.

@@ -1641,9 +1640,9 @@ class Tensor(SimpleMathTrait):
"""
return self._reduce(Ops.MAX, axis, keepdim)

def _inverse(self): return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()
def _inverse(self) -> Tensor: return -self if self.is_floating_point() else ~self if dtypes.is_int(self.dtype) else self.logical_not()

def min(self, axis:int|Sequence[int]|None=None, keepdim=False):
def min(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the minimum value of the tensor along the specified axis or axes.

@@ -1666,7 +1665,7 @@ class Tensor(SimpleMathTrait):
"""
return self._inverse().max(axis=axis, keepdim=keepdim)._inverse()

def any(self, axis:int|Sequence[int]|None=None, keepdim=False):
def any(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if any element evaluates to `True` along the specified axis or axes.

@@ -1688,7 +1687,7 @@ class Tensor(SimpleMathTrait):
"""
return self.bool().max(axis, keepdim)

def all(self, axis:int|Sequence[int]|None=None, keepdim=False):
def all(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Tests if all elements evaluate to `True` along the specified axis or axes.

@@ -1730,7 +1729,7 @@ class Tensor(SimpleMathTrait):
is_nan_close = (self.isnan() & other.isnan()) & equal_nan
return is_finite_close | is_infinite_close | is_nan_close

def mean(self, axis:int|Sequence[int]|None=None, keepdim=False):
def mean(self, axis:int|Sequence[int]|None=None, keepdim=False) -> Tensor:
"""
Returns the mean value of the tensor along the specified axis or axes.

@@ -1754,9 +1753,10 @@ class Tensor(SimpleMathTrait):
"""
output_dtype = self.dtype if dtypes.is_float(self.dtype) else dtypes.float32
numerator = self.cast(sum_acc_dtype(self.dtype)).sum(axis=axis, keepdim=keepdim)
return numerator.div(prod([si for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])).cast(output_dtype)
return numerator.div(prod([cast(int, si) for si, so in zip(self.shape, self.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])) \
.cast(output_dtype)

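The divisor in mean() falls out of comparing the input shape with the keepdim-reduced shape: the axes whose extent changed are exactly the reduced ones, and their product is the element count per output value. Standalone:

from math import prod

shape, axis = (2, 3, 4), (1, 2)
reduced = tuple(1 if i in axis else s for i, s in enumerate(shape))  # shape after sum(keepdim=True)
divisor = prod(si for si, so in zip(shape, reduced) if si != so)
assert divisor == 12                                                 # 3*4 elements averaged per output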
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the variance of the tensor along the specified axis or axes.

@@ -1782,7 +1782,7 @@ class Tensor(SimpleMathTrait):
n = prod([si for si, so in zip(self.shape, squares.sum(axis=axis, keepdim=True).shape) if resolve(si != so)])
return squares.sum(axis=axis, keepdim=keepdim).div(smax([0, n-correction]))

def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def var_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the variance and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.var` and `Tensor.mean` to match `torch.var_mean`.
@@ -1799,7 +1799,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction), self.mean(axis, keepdim)

def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> Tensor:
"""
Returns the standard deviation of the tensor along the specified axis or axes.

@@ -1823,7 +1823,7 @@ class Tensor(SimpleMathTrait):
"""
return self.var(axis, keepdim, correction).sqrt()

def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1):
def std_mean(self, axis:int|Sequence[int]|None=None, keepdim=False, correction=1) -> tuple[Tensor, Tensor]:
"""
Calculates the standard deviation and mean over the dimensions specified by dim.
Syntactic sugar around `Tensor.std` and `Tensor.mean` to match `torch.std_mean`.
@@ -1840,13 +1840,13 @@ class Tensor(SimpleMathTrait):
"""
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)

def _softmax(self, axis, dtype:DTypeLike|None=None):
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
m = self - self.max(axis=axis, keepdim=True).detach()
if dtype is not None: m = m.cast(dtype)
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)

def softmax(self, axis=-1, dtype:DTypeLike|None=None):
def softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the softmax function to the tensor along the specified axis.

@@ -1869,7 +1869,7 @@ class Tensor(SimpleMathTrait):
_, e, ss = self._softmax(axis, dtype)
return e.div(ss)

def log_softmax(self, axis=-1, dtype:DTypeLike|None=None):
def log_softmax(self, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
"""
Applies the log-softmax function to the tensor along the specified axis.

@@ -1892,7 +1892,7 @@ class Tensor(SimpleMathTrait):
m, _, ss = self._softmax(axis, dtype)
return m - ss.log()

def logsumexp(self, axis=None, keepdim=False):
def logsumexp(self, axis=None, keepdim=False) -> Tensor:
"""
Computes the log-sum-exp of the tensor along the specified axis or axes.

@@ -1919,7 +1919,7 @@ class Tensor(SimpleMathTrait):
m = self.max(axis=axis, keepdim=True)
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + m.squeeze(axis)

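The max-shift in logsumexp is a pure identity, log(sum(exp(x))) == m + log(sum(exp(x - m))) for any m, chosen as the max so the exponentials cannot overflow. A quick standalone check with inputs that would overflow naively:

import math

xs = [1000.0, 1001.0, 999.0]  # math.exp(1000.0) overflows float64
m = max(xs)
lse = m + math.log(sum(math.exp(x - m) for x in xs))
lse2 = 1001.0 + math.log(sum(math.exp(x - 1001.0) for x in xs))  # any other shift gives the same result
assert abs(lse - lse2) < 1e-9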
def logcumsumexp(self, axis=0):
def logcumsumexp(self, axis=0) -> Tensor:
"""
Computes the log-cumsum-exp of the tensor along the specified axis or axes.

@@ -1954,7 +1954,7 @@ class Tensor(SimpleMathTrait):
ret = ((x_expand - x_cummax).exp() * mask).sum(-1).log() + x_cummax.squeeze(-1)
return ret.reshape(*x.shape).transpose(-1, axis)

def argmax(self, axis=None, keepdim=False):
def argmax(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the maximum value of the tensor along the specified axis.

@@ -1981,7 +1981,7 @@ class Tensor(SimpleMathTrait):
idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32)

def argmin(self, axis=None, keepdim=False):
def argmin(self, axis=None, keepdim=False) -> Tensor:
"""
Returns the indices of the minimum value of the tensor along the specified axis.

@@ -2353,7 +2353,7 @@ class Tensor(SimpleMathTrait):
ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0), value=identity_element(op, self.dtype)).unflatten(-1, (-1, SPLIT))._cumalu(-1, op)
base = ret[..., -1]._cumalu(-1, op, _include_initial=True)
base = base.unsqueeze(-1).expand(*base.shape, ret.shape[-1])
def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
def fix(x: Tensor) -> Tensor: return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1)
return fix(ret) + fix(base) if op is Ops.ADD else fix(ret).maximum(fix(base))

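The SPLIT scheme above computes a long scan as per-chunk scans plus a scan over chunk totals used as per-chunk bases; for Ops.ADD that is ordinary carry propagation. A list-based sketch of the same decomposition:

SPLIT = 4
xs = list(range(1, 11))                    # 10 elements, left-padded to a multiple of SPLIT
pad = (-len(xs)) % SPLIT
padded = [0] * pad + xs
out, base = [], 0
for i in range(0, len(padded), SPLIT):
    run = 0
    for v in padded[i:i+SPLIT]:
        run += v
        out.append(base + run)             # in-chunk cumsum plus the chunk's base
    base += run                            # carry this chunk's total forward
assert out[pad:] == [sum(xs[:i+1]) for i in range(len(xs))]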
def cumsum(self, axis:int=0) -> Tensor:
@@ -2565,7 +2565,7 @@ class Tensor(SimpleMathTrait):

# ***** unary ops *****

def logical_not(self):
def logical_not(self) -> Tensor:
"""
Computes the logical NOT of the tensor element-wise.

@@ -2574,7 +2574,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.bool)._apply_broadcasted_uop(UOp.ne, True)
def neg(self):

def neg(self) -> Tensor:
"""
Negates the tensor element-wise.

@@ -2583,17 +2584,20 @@ class Tensor(SimpleMathTrait):
```
"""
return self*-1 if self.dtype != dtypes.bool else self.logical_not()
def contiguous(self):

def contiguous(self) -> Tensor:
"""
Returns a contiguous tensor.
"""
return self._apply_uop(UOp.contiguous)
def contiguous_backward(self):

def contiguous_backward(self) -> Tensor:
"""
Inserts a contiguous operation in the backward pass.
"""
return self._apply_uop(UOp.contiguous_backward)
def log(self):

def log(self) -> Tensor:
"""
Computes the natural logarithm element-wise.

@@ -2604,7 +2608,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.log2()*math.log(2)
def log2(self):

def log2(self) -> Tensor:
"""
Computes the base-2 logarithm element-wise.

@@ -2615,7 +2620,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.log2)
def exp(self):

def exp(self) -> Tensor:
"""
Computes the exponential function element-wise.

@@ -2626,7 +2632,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.mul(1/math.log(2)).exp2()
def exp2(self):

def exp2(self) -> Tensor:
"""
Computes the base-2 exponential function element-wise.

@@ -2637,7 +2644,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.exp2)
def relu(self):

def relu(self) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function element-wise.

@@ -2649,7 +2657,7 @@ class Tensor(SimpleMathTrait):
"""
return (self>0).where(self, 0)

def sigmoid(self):
def sigmoid(self) -> Tensor:
"""
Applies the Sigmoid function element-wise.

@@ -2661,7 +2669,7 @@ class Tensor(SimpleMathTrait):
"""
return (1 + (self * (-1/math.log(2))).exp2()).reciprocal()

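The sigmoid above trades exp() for exp2() via e^-x == 2^(-x/ln 2), so 1/(1+e^-x) becomes 1/(1 + 2^(x * -1/ln 2)). A one-value standalone check of the identity:

import math

x = 0.7
via_exp2 = 1.0 / (1.0 + 2.0 ** (x * (-1.0 / math.log(2))))
assert abs(via_exp2 - 1.0 / (1.0 + math.exp(-x))) < 1e-12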
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5):
def hardsigmoid(self, alpha:float=1/6, beta:float=0.5) -> Tensor:
"""
Applies the Hardsigmoid function element-wise.
NOTE: default `alpha` and `beta` values are taken from torch
@@ -2675,7 +2683,7 @@ class Tensor(SimpleMathTrait):
"""
return (alpha * self + beta).relu() - (alpha * self + beta - 1).relu()

def sqrt(self):
def sqrt(self) -> Tensor:
"""
Computes the square root of the tensor element-wise.

@@ -2684,7 +2692,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sqrt)
def rsqrt(self):

def rsqrt(self) -> Tensor:
"""
Computes the reciprocal of the square root of the tensor element-wise.

@@ -2693,7 +2702,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.sqrt().reciprocal()
def sin(self):

def sin(self) -> Tensor:
"""
Computes the sine of the tensor element-wise.

@@ -2702,7 +2712,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(least_upper_float(self.dtype))._apply_uop(UOp.sin)
def cos(self):

def cos(self) -> Tensor:
"""
Computes the cosine of the tensor element-wise.

@@ -2711,7 +2722,8 @@ class Tensor(SimpleMathTrait):
```
"""
return ((math.pi/2)-self).sin()
def tan(self):

def tan(self) -> Tensor:
"""
Computes the tangent of the tensor element-wise.

@@ -2721,7 +2733,7 @@ class Tensor(SimpleMathTrait):
"""
return self.sin() / self.cos()

def asin(self):
def asin(self) -> Tensor:
"""
Computes the inverse sine (arcsine) of the tensor element-wise.

@@ -2734,7 +2746,7 @@ class Tensor(SimpleMathTrait):
x = math.pi / 2 - (1.0 - self.abs()).sqrt() * polyN(self.abs(), coefficients)
return self.sign() * x

def acos(self):
def acos(self) -> Tensor:
"""
Computes the inverse cosine (arccosine) of the tensor element-wise.

@@ -2744,7 +2756,7 @@ class Tensor(SimpleMathTrait):
"""
return math.pi / 2 - self.asin()

def atan(self):
def atan(self) -> Tensor:
"""
Computes the inverse tangent (arctan) of the tensor element-wise.

@@ -2765,6 +2777,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self.cast(dtypes.int32).cast(self.dtype)

def ceil(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards positive infinity.
@@ -2774,6 +2787,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self > (b := self.trunc())).where(b+1, b)

def floor(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise towards negative infinity.
@@ -2783,6 +2797,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self < (b := self.trunc())).where(b-1, b)

def round(self: Tensor) -> Tensor:
"""
Rounds the tensor element-wise with rounding half to even.
@@ -2802,6 +2817,7 @@ class Tensor(SimpleMathTrait):
```
"""
return (self == float("inf")) * detect_positive + (self == float("-inf")) * detect_negative

def isnan(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is NaN, otherwise returns False.
@@ -2811,6 +2827,7 @@ class Tensor(SimpleMathTrait):
```
"""
return self != self

def isfinite(self:Tensor) -> Tensor:
"""
Checks the tensor element-wise to return True where the element is finite, otherwise returns False.
@@ -2834,7 +2851,7 @@ class Tensor(SimpleMathTrait):
return (self+(((end - self).cast(dtypes.int8) * w_i + (1<<W_PREC-1)).cast(dtypes.uint16) >> W_PREC)).cast(dtypes.uint8)
return self + (end - self) * weight

def square(self):
def square(self) -> Tensor:
"""
Squares the tensor element-wise.
Equivalent to `self*self`.
@@ -2844,7 +2861,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self*self
def clamp(self, min_=None, max_=None):

def clamp(self, min_=None, max_=None) -> Tensor:
"""
Clips (clamps) the values in the tensor between `min_` and `max_` element-wise.
If `min_` is `None`, there is no lower bound. If `max_` is None, there is no upper bound.
@@ -2856,12 +2874,14 @@ class Tensor(SimpleMathTrait):
if min_ is None and max_ is None: raise RuntimeError("at least one of 'min_' or 'max_' must not be None")
ret = self.maximum(min_) if min_ is not None else self
return ret.minimum(max_) if max_ is not None else ret
def clip(self, min_=None, max_=None):

def clip(self, min_=None, max_=None) -> Tensor:
"""
Alias for `Tensor.clamp`.
"""
return self.clamp(min_, max_)
def sign(self):

def sign(self) -> Tensor:
"""
Returns the sign of the tensor element-wise.

@@ -2870,7 +2890,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self.ne(0).where((self<0).where(self.full_like(-1), self.full_like(1)), self.full_like(0)) + self*0
def abs(self):

def abs(self) -> Tensor:
"""
Computes the absolute value of the tensor element-wise.

@@ -2879,7 +2900,8 @@ class Tensor(SimpleMathTrait):
```
"""
return self * self.sign()
def reciprocal(self):

def reciprocal(self) -> Tensor:
"""
Compute `1/x` element-wise.

@@ -2891,7 +2913,7 @@ class Tensor(SimpleMathTrait):

# ***** activation functions *****

def elu(self, alpha=1.0):
def elu(self, alpha=1.0) -> Tensor:
"""
Applies the Exponential Linear Unit (ELU) function element-wise.

@@ -2904,7 +2926,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - alpha*(1-self.exp()).relu()

def celu(self, alpha=1.0):
def celu(self, alpha=1.0) -> Tensor:
"""
Applies the Continuously differentiable Exponential Linear Unit (CELU) function element-wise.

@@ -2917,7 +2939,7 @@ class Tensor(SimpleMathTrait):
"""
return self.maximum(0) + (alpha * ((self / alpha).exp() - 1)).minimum(0)

def selu(self, alpha=1.67326, gamma=1.0507):
def selu(self, alpha=1.67326, gamma=1.0507) -> Tensor:
"""
Applies the Scaled Exponential Linear Unit (SELU) function element-wise.

@@ -2930,7 +2952,7 @@ class Tensor(SimpleMathTrait):
"""
return gamma * (self >= 0).detach().where(self, alpha * (self.exp() - 1))

def swish(self):
def swish(self) -> Tensor:
"""
See `.silu()`

@@ -2942,7 +2964,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.sigmoid()

def silu(self):
def silu(self) -> Tensor:
"""
Applies the Sigmoid Linear Unit (SiLU) function element-wise.

@@ -2955,7 +2977,7 @@ class Tensor(SimpleMathTrait):
"""
return self.swish() # The SiLU function is also known as the swish function.

def relu6(self):
def relu6(self) -> Tensor:
"""
Applies the ReLU6 function element-wise.

@@ -2968,7 +2990,7 @@ class Tensor(SimpleMathTrait):
"""
return self.relu() - (self-6).relu()

def hardswish(self):
def hardswish(self) -> Tensor:
"""
Applies the Hardswish function element-wise.

@@ -2981,7 +3003,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self+3).relu6() * (1/6)

def tanh(self):
def tanh(self) -> Tensor:
"""
Applies the Hyperbolic Tangent (tanh) function element-wise.

@@ -2993,7 +3015,7 @@ class Tensor(SimpleMathTrait):
"""
return 2.0 * ((2.0 * self).sigmoid()) - 1.0

def sinh(self):
def sinh(self) -> Tensor:
"""
Applies the Hyperbolic Sine (sinh) function element-wise.

@@ -3005,7 +3027,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() - self.neg().exp()) / 2

def cosh(self):
def cosh(self) -> Tensor:
"""
Applies the Hyperbolic Cosine (cosh) function element-wise.

@@ -3017,7 +3039,7 @@ class Tensor(SimpleMathTrait):
"""
return (self.exp() + self.neg().exp()) / 2

def atanh(self):
def atanh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Tangent (atanh) function element-wise.

@@ -3029,7 +3051,7 @@ class Tensor(SimpleMathTrait):
"""
return ((1 + self)/(1 - self)).log() / 2

def asinh(self):
def asinh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Sine (asinh) function element-wise.

@@ -3041,7 +3063,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() + 1).sqrt()).log()

def acosh(self):
def acosh(self) -> Tensor:
"""
Applies the Inverse Hyperbolic Cosine (acosh) function element-wise.

@@ -3053,7 +3075,7 @@ class Tensor(SimpleMathTrait):
"""
return (self + (self.square() - 1).sqrt()).log()

def hardtanh(self, min_val=-1, max_val=1):
def hardtanh(self, min_val=-1, max_val=1) -> Tensor:
"""
Applies the Hardtanh function element-wise.

@@ -3065,7 +3087,7 @@ class Tensor(SimpleMathTrait):
"""
return self.clip(min_val, max_val)

def erf(self):
def erf(self) -> Tensor:
"""
Applies the error function element-wise.

@@ -3079,7 +3101,7 @@ class Tensor(SimpleMathTrait):
t = 1.0 / (1.0 + 0.3275911 * self.abs())
return self.sign() * (1.0 - t * polyN(t, [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]) * (-self.square()).exp())

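The coefficients above are the Abramowitz & Stegun 7.1.26 rational approximation of erf (max absolute error about 1.5e-7); polyN is Horner evaluation with the highest-degree coefficient first, inlined below so the check is standalone:

import math

def erf_approx(x: float) -> float:
    t = 1.0 / (1.0 + 0.3275911 * abs(x))
    poly = 0.0
    for c in [1.061405429, -1.453152027, 1.421413741, -0.284496736, 0.254829592]: poly = poly * t + c
    return math.copysign(1.0 - t * poly * math.exp(-x * x), x)

assert all(abs(erf_approx(x) - math.erf(x)) < 1.5e-7 for x in [-2.0, -0.5, 0.0, 0.3, 1.0, 3.0])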
def gelu(self):
def gelu(self) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function element-wise.

@@ -3092,7 +3114,7 @@ class Tensor(SimpleMathTrait):
"""
return 0.5 * self * (1 + (math.sqrt(2 / math.pi) * (self + 0.044715 * self ** 3)).tanh())

def quick_gelu(self):
def quick_gelu(self) -> Tensor:
"""
Applies the Sigmoid GELU approximation element-wise.

@@ -3104,7 +3126,7 @@ class Tensor(SimpleMathTrait):
"""
return self * (self * 1.702).sigmoid()

def leaky_relu(self, neg_slope=0.01):
def leaky_relu(self, neg_slope=0.01) -> Tensor:
"""
Applies the Leaky ReLU function element-wise.

@@ -3119,7 +3141,7 @@ class Tensor(SimpleMathTrait):
"""
return (self<0).where(neg_slope*self, self)

def mish(self):
def mish(self) -> Tensor:
"""
Applies the Mish function element-wise.

@@ -3132,7 +3154,7 @@ class Tensor(SimpleMathTrait):
"""
return self * self.softplus().tanh()

def softplus(self, beta=1):
def softplus(self, beta=1) -> Tensor:
"""
Applies the Softplus function element-wise.

@@ -3144,7 +3166,7 @@ class Tensor(SimpleMathTrait):
"""
return (1/beta) * (1 + (self*beta).exp()).log()

def softsign(self):
def softsign(self) -> Tensor:
"""
Applies the Softsign function element-wise.

@@ -3360,7 +3382,7 @@ class Tensor(SimpleMathTrait):
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1

def lshift(self, x:int):
def lshift(self, x:int) -> Tensor:
"""
Computes left arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self << x`.
@@ -3372,7 +3394,7 @@ class Tensor(SimpleMathTrait):
assert dtypes.is_unsigned(self.dtype) and isinstance(x, int) and x >= 0, f"not supported {self.dtype=} {x=}"
return self.mul(2 ** x)

def rshift(self, x:int):
def rshift(self, x:int) -> Tensor:
"""
Computes right arithmetic shift of `self` by `x` bits. `self` must have unsigned dtype.
Equivalent to `self >> x`.
@@ -3434,7 +3456,7 @@ class Tensor(SimpleMathTrait):
t, x = self._broadcasted(x)
return t._inverse().maximum(x._inverse())._inverse()

def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint):
def where(self:Tensor, x:Tensor|ConstType|sint, y:Tensor|ConstType|sint) -> Tensor:
"""
Return a tensor of elements selected from either `x` or `y`, depending on `self`.
`output_i = x_i if self_i else y_i`.
@@ -3458,7 +3480,7 @@ class Tensor(SimpleMathTrait):
cond, y = cond._broadcasted(y, match_dtype=False)
return cond.cast(dtypes.bool)._apply_uop(UOp.where, *x._broadcasted(y))

def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType): return mask.where(value, self)
def masked_fill(self:Tensor, mask:Tensor, value:Tensor|ConstType) -> Tensor: return mask.where(value, self)

def copysign(self, other) -> Tensor:
"""
@@ -3503,7 +3525,7 @@ class Tensor(SimpleMathTrait):

# ***** functional nn ops *****

def linear(self, weight:Tensor, bias:Tensor|None=None):
def linear(self, weight:Tensor, bias:Tensor|None=None) -> Tensor:
"""
Applies a linear transformation to `self` using `weight` and `bias`.

@@ -3519,7 +3541,7 @@ class Tensor(SimpleMathTrait):
x = self.mul(weight) if len(weight.shape) == 1 else self.dot(weight)
return x.add(bias) if bias is not None else x

def sequential(self, ll:list[Callable[[Tensor], Tensor]]):
def sequential(self, ll:list[Callable[[Tensor], Tensor]]) -> Tensor:
"""
Applies a sequence of functions to `self`, chaining the output of each function to the input of the next.

@@ -3592,7 +3614,7 @@ class Tensor(SimpleMathTrait):
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)

# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1) -> Tensor:
if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3821,7 +3843,7 @@ class Tensor(SimpleMathTrait):

# ***** cast ops *****

def llvm_bf16_cast(self, dtype:DTypeLike):
def llvm_bf16_cast(self, dtype:DTypeLike) -> Tensor:
# hack for devices that don't support bfloat16
assert self.dtype == dtypes.bfloat16
return self.to("LLVM").cast(dtype)
@@ -4011,8 +4033,10 @@ class Tensor(SimpleMathTrait):
ret = ret.reshape(bs, oy, ox, cout).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))

def _metadata_wrapper(fn):
def _wrapper(*args, **kwargs):
P = ParamSpec("P")
T = TypeVar("T")
def _metadata_wrapper(fn: Callable[P, T]) -> Callable[P, T]:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
if _METADATA.get() is not None: return fn(*args, **kwargs)

if TRACEMETA >= 2:

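The ParamSpec/TypeVar pair added above makes the decorator signature-preserving for type checkers: P captures the wrapped function's parameters and T its return type (Python 3.10+). A minimal standalone sketch of the same pattern, with a print standing in for the metadata bookkeeping:

from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

def traced(fn: Callable[P, T]) -> Callable[P, T]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        print(f"calling {fn.__name__}")  # stand-in for the _METADATA context handling
        return fn(*args, **kwargs)
    return wrapper

@traced
def add(a: int, b: int) -> int: return a + b

assert add(1, 2) == 3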