Merge branch 'master' into retinanet_mlperf

Francis Lata
2024-12-10 16:35:00 +00:00
65 changed files with 1314 additions and 1053 deletions


@@ -197,11 +197,11 @@ jobs:
# - name: Run LLaMA 7B on 6 GPUs
# run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
- name: Run LLaMA-3 8B BEAM
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Run LLaMA-2 70B
run: NV=1 RUN_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
@@ -380,11 +380,11 @@ jobs:
# - name: Run LLaMA 7B on 6 GPUs
# run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
- name: Run LLaMA-3 8B BEAM
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Run LLaMA-2 70B
run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
@@ -508,10 +508,6 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: openpilot compile 0.9.4
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt
- name: openpilot compile 0.9.7
run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_compile_0_9_7.txt
- name: validate openpilot 0.9.7
run: PYTHONPATH=. FLOAT16=0 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt
- name: benchmark openpilot 0.9.4


@@ -1,7 +1,7 @@
name: Unit Tests
env:
# increment this when downloads substantially change to avoid the internet
DOWNLOAD_CACHE_VERSION: '7'
DOWNLOAD_CACHE_VERSION: '8'
RUN_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: .
@@ -293,22 +293,15 @@ jobs:
PYTHONPATH="." GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py --durations=20
PYTHONPATH="." GPU=1 IMAGE=2 python3 test/models/test_end2end.py TestEnd2End.test_linear_mnist
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model compile and size
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot compile3
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot fastvits model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'onnx' }}
name: Test ONNX (GPU)
run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
@@ -387,7 +380,7 @@ jobs:
WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \
test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
--durations=20
test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
- name: Run process replay tests
run: |
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
@@ -439,7 +432,7 @@ jobs:
- name: Test Beam Search
run: PYTHONPATH="." METAL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Fuzz Test linearizer
run: PYTHONPATH="." METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=24 FUZZ_MAX_SIZE=1000000 python test/external/fuzz_linearizer.py
run: PYTHONPATH="." METAL=1 DEPTH=4 FUZZ_N=50 FUZZ_MAX_SIZE=1000000 python test/external/fuzz_linearizer.py
# - name: Fuzz Test models schedule
# run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
- name: Run TRANSCENDENTAL math
@@ -528,7 +521,7 @@ jobs:
if: matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
run: |
cd ${{ github.workspace }}/gpuocelot/ocelot/build
sudo ninja install -d explain
sudo cp libgpuocelot.so /usr/lib/libgpuocelot.so
- name: Install packages (amd)
if: matrix.backend == 'amd'
run: |


@@ -220,7 +220,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--download_model", action="store_true", help="Download a model")
parser.add_argument("--model", type=Path, help="Model path")
parser.add_argument("--size", choices=["1B", "8B", "70B"], default="8B", help="Model size")
parser.add_argument("--size", choices=["1B", "8B", "70B"], default="1B", help="Model size")
parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
parser.add_argument("--quantize", choices=["int8", "nf4", "float16"], help="Quantization method")
parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
@@ -234,8 +234,8 @@ if __name__ == "__main__":
parser.add_argument("--profile", action="store_true", help="Output profile data")
args = parser.parse_args()
assert (args.model and not args.download_model) or (not args.model and args.download_model), "either download or provide model"
if args.download_model:
# download_model is the default without a model passed in
if args.download_model or not args.model:
if args.size == "1B":
fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
args.model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf", "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")


@@ -1,211 +0,0 @@
#!/usr/bin/env python3
import os, sys, io, pathlib, json, struct
import numpy as np
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))
if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"
import onnx
from typing import Tuple, List, Optional, Dict, cast
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.device import Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule
from tinygrad.ops import Ops
from tinygrad.tensor import _to_np_dtype
Device.DEFAULT = "GPU"
def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
Tensor.no_grad = True
Tensor.training = False
# load the model
onnx_model = onnx.load(io.BytesIO(onnx_data))
run_onnx = get_run_onnx(onnx_model)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
# run the model
inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
schedule = create_schedule([ret.lazydata])
# filter schedule that don't depend on the inputs
input_lb = [x.lazydata.base.buffer for x in inputs.values()]
depends = set(input_lb)
for si in schedule:
if any(b in depends for b in si.inputs):
for out in si.outputs: depends.add(out)
# run all kernels that don't depend on the inputs
# NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
# confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
return schedule, schedule_independent, inputs
def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
import onnx
#import pyopencl as cl
#from extra.thneed import Thneed
import numpy as np
onnx_model = onnx.load(io.BytesIO(onnx_data))
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
Tensor.manual_seed(1337)
new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()}
if getenv("ORT"):
# test with onnxruntime
import onnxruntime as ort
onnx_session = ort.InferenceSession(onnx_data)
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
new_torch_out = onnx_output[0]
print("got ort outputs")
else:
# test with torch
from test.models.test_onnx import run_onnx_torch
new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
print("got torch outputs")
# if you don't have a schedule
if eis is None:
run_onnx = get_run_onnx(onnx_model)
new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("classic self-test passed!")
return
# set inputs
for k,v in inputs.items(): v.lazydata.base.realized.copyin(new_np_inputs[k].data)
# run code (all buffers have been allocated)
GlobalCounters.reset()
output = eis[-1].bufs[0]
for ei in eis: ei.run()
new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!")
if __name__ == "__main__":
onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes()
# quick test for ONNX issues
#thneed_test_onnx(onnx_data, None)
#exit(0)
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
print(f"{len(schedule_input)} inputs")
run_schedule(schedule_independent)
run_schedule(schedule_input)
with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
schedule = memory_planner(schedule)
for si in schedule:
for b in si.outputs:
assert not b.is_allocated(), "output should not be allocated"
image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
print(f"**** compiling real kernels {image_count}/{len(schedule)} images ****")
eis = list(tqdm(lower_schedule(schedule), total=len(schedule)))
print("kernel count:", len(eis))
assert len(eis) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
# new simple thneed
def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")
seen_buffers = set()
input_buffers = [x.lazydata.buffer for x in inputs.values()]
jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
jdat["inputs"] = {k:to_ref(v.lazydata.buffer) for k,v in inputs.items()}
jdat["outputs"] = [to_ref(eis[-1].bufs[0])]
weights = []
for i,ei in enumerate(eis):
#print("***", i)
for b in ei.bufs:
needs_load = b.is_allocated() and b not in input_buffers
#print(b, needs_load)
if b in seen_buffers: continue
seen_buffers.add(b)
if isinstance(b.dtype, ImageDType):
base_dtype = dtypes.float16 if b.dtype.fmt == 'e' else dtypes.float32
row_pitch = (b.dtype.shape[0]*4*base_dtype.itemsize + 63)//64 * 64
size = row_pitch * b.dtype.shape[1]
jdat['objects'].append({
"id": to_ref(b), "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
"width": b.dtype.shape[0], "height": b.dtype.shape[1], "row_pitch": row_pitch, "float32": b.dtype.base == dtypes.float32,
})
if needs_load:
t = Tensor.empty(b.dtype.shape, dtype=b.dtype)
t.lazydata.buffer = b
data = t.cast(dtypes.float32).pad(((0, row_pitch//(4*base_dtype.itemsize)-b.dtype.shape[0]), (0,0), (0,0))).contiguous().numpy()
# NOTE: this cast must be done in numpy for platforms that don't support half
if base_dtype == dtypes.float16: data = data.astype(np.float16)
weights.append(data.tobytes())
assert len(weights[-1]) == size, "wrong size buffer"
else:
jdat['objects'].append({
"id": to_ref(b), "arg_type": b.dtype.name + "*", "needs_load": needs_load, "size": b.nbytes,
})
if needs_load:
weights.append(b.as_buffer())
assert len(weights[-1]) == b.nbytes, "wrong size buffer"
saved_binaries = set()
binaries = []
gated_read_image_count = 0
GlobalCounters.reset()
with Context(DEBUG=max(DEBUG.value, 2)):
for ei in eis:
prg = cast(CompiledRunner, ei.prg)
assert len(prg.p.vars) == 0
if prg.p.function_name not in saved_binaries:
jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
binaries.append(prg.lib)
saved_binaries.add(prg.p.function_name)
gated_read_image_count += prg.p.src.count("?read_image")
ei.run()
jdat['kernels'].append({
"name": prg.p.function_name,
"work_dim": len(prg.p.global_size),
"global_work_size": prg.p.global_size,
"local_work_size": prg.p.local_size,
"num_args": len(ei.bufs),
"args": [to_ref(b) for b in ei.bufs],
"arg_size": [8]*len(ei.bufs),
})
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
assert gated_read_image_count <= allowed_gated_read_image, \
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
with open(output_fn, "wb") as f:
j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
f.write(struct.pack("I", len(j)))
f.write(j)
for w in weights: f.write(w)
for b in binaries: f.write(b)
print("saved", f.tell(), "bytes")
FLOAT16 = getenv("FLOAT16", 0)
if FLOAT16 == 0:
try:
test_vs_onnx(onnx_data, eis, inputs)
except ModuleNotFoundError as e:
print(f"TEST NOT HAPPENING {e}")


@@ -5,9 +5,10 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype
from tinygrad.engine.realize import CompiledRunner
import onnx
from onnx.helper import tensor_dtype_to_np_dtype
@@ -16,12 +17,11 @@ from extra.onnx import get_run_onnx # TODO: port to main tinygrad
OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
OUTPUT = "/tmp/openpilot.pkl"
def compile():
def compile(onnx_file):
onnx_model = onnx.load(onnx_file)
Tensor.no_grad = True
Tensor.training = False
onnx_bytes = fetch(OPENPILOT_MODEL)
onnx_model = onnx.load(onnx_bytes)
run_onnx = get_run_onnx(onnx_model)
print("loaded model")
@@ -30,51 +30,103 @@ def compile():
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
print("created tensors")
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
run_onnx_jit = TinyJit(lambda **kwargs:
next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True)
for i in range(3):
GlobalCounters.reset()
print(f"run {i}")
inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
ret = run_onnx_jit(**inputs).numpy()
# copy i == 1 so use of JITBEAM is okay
if i == 1: test_val = np.copy(ret)
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
np.testing.assert_equal(test_val, ret)
np.testing.assert_equal(test_val, ret, "JIT run failed")
print("jit run validated")
# checks from compile2
kernel_count = 0
gated_read_image_count = 0
for ei in run_onnx_jit.captured.jit_cache:
if isinstance(ei.prg, CompiledRunner):
kernel_count += 1
gated_read_image_count += ei.prg.p.src.count("?read_image")
print(f"kernel_count: {kernel_count} gated_read_image_count: {gated_read_image_count}")
assert kernel_count <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
assert gated_read_image_count <= allowed_gated_read_image, \
f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
with open(OUTPUT, "wb") as f:
pickle.dump(run_onnx_jit, f)
mdl_sz = os.path.getsize(onnx_bytes)
mdl_sz = os.path.getsize(onnx_file)
pkl_sz = os.path.getsize(OUTPUT)
print(f"mdl size is {mdl_sz/1e6:.2f}M")
print(f"pkl size is {pkl_sz/1e6:.2f}M")
print("**** compile done ****")
return test_val
def test(test_val=None):
with open(OUTPUT, "rb") as f:
run = pickle.load(f)
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
def test_vs_compile(run, new_inputs, test_val=None):
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
# create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
# run 20 times
for _ in range(20):
st = time.perf_counter()
# Need to cast non-image inputs from numpy, this is only realistic way to run it
inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}}
out = run(**inputs)
mt = time.perf_counter()
val = out['outputs'].numpy()
val = out.numpy()
et = time.perf_counter()
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
print(out, val.shape, val.dtype)
if test_val is not None: np.testing.assert_equal(test_val, val)
print("**** test done ****")
if __name__ == "__main__":
test_val = compile() if not getenv("RUN") else None
test(test_val)
# test that changing the numpy changes the model outputs
for v in new_inputs_numpy.values(): v *= 2
out = run(**inputs)
changed_val = out.numpy()
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
return val
def test_vs_onnx(new_inputs, test_val, onnx_file):
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
onnx_model = onnx.load(onnx_file)
if getenv("ORT"):
# test with onnxruntime
import onnxruntime as ort
onnx_session = ort.InferenceSession(onnx_file)
onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
new_torch_out = onnx_output[0]
print("got ort outputs")
else:
# test with torch
from test.models.test_onnx import run_onnx_torch
# NOTE: we have to correct the order here
new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
print("got torch outputs")
np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
print("test vs onnx passed")
if __name__ == "__main__":
onnx_file = fetch(OPENPILOT_MODEL)
test_val = compile(onnx_file) if not getenv("RUN") else None
with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
# same randomness as compile
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)


@@ -1,2 +0,0 @@
#!/bin/bash
NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 examples/openpilot/compile2.py

examples/self_tokenize.py (new file, 35 lines)

@@ -0,0 +1,35 @@
import os, pathlib
from examples.llama3 import Tokenizer
from tabulate import tabulate
from tinygrad import fetch
from tinygrad.helpers import flatten
# llama 3 tokenizer
tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix())
def read_code(base_path):
ret = []
for path, _, files in os.walk(os.path.join(base_path, "tinygrad")):
for name in files:
if not name.endswith(".py"): continue
if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue
fullpath = os.path.join(path, name)
code = pathlib.Path(fullpath).read_text()
ret += [(fullpath.split("tinygrad/", 1)[1], code)]
return ret
if __name__ == "__main__":
ret = read_code(".")
table = []
for name,code in ret:
table.append([name, len(tokenizer.encode(name+"\x00"+code))])
print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow"))
code_str = '\x00'.join(flatten(ret))
print(f"code has {len(code_str)} chars")
newline_count = code_str.count('\n')
print(f"code has {newline_count} newlines")
encoded = tokenizer.encode(code_str)
print(f"code has {len(encoded)} tokens")


@@ -17,7 +17,7 @@ canvas { display: none; }
* { text-align: center; font-family: monospace; }
</style>
<title>tinygrad has WebGPU</title>
<script src="./net.js"></script>
<script src="../../net.js"></script>
<link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
</head>
<body>
@@ -61,7 +61,7 @@ canvas { display: none; }
const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());
const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
const reorderChannelsAndRemoveAlpha = (data) => {
const out = [];


@@ -1,9 +1,10 @@
import os
from extra.export_model import compile_net, jit_model
from extra.export_model import compile_net, jit_model, dtype_to_js_type
from extra.f16_decompress import u32_to_f16
from examples.stable_diffusion import StableDiffusion
from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
from tinygrad.tensor import Tensor
from tinygrad import Device
from tinygrad import Device, dtypes
from tinygrad.helpers import fetch
from typing import NamedTuple, Any, List
import requests
@@ -29,7 +30,7 @@ def convert_f32_to_f16(input_file, output_file):
rest_float32_values.tofile(f)
def split_safetensor(fn):
_, json_len, metadata = safe_load_metadata(fn)
_, data_start, metadata = safe_load_metadata(fn)
text_model_offset = 3772703308
chunk_size = 536870912
@@ -51,12 +52,12 @@ def split_safetensor(fn):
part_offset = offset - last_offset
if (part_offset >= chunk_size):
part_end_offsets.append(8+json_len+offset)
part_end_offsets.append(data_start+offset)
last_offset = offset
text_model_start = int(text_model_offset/2)
net_bytes = bytes(open(fn, 'rb').read())
part_end_offsets.append(text_model_start+8+json_len)
part_end_offsets.append(text_model_start+data_start)
cur_pos = 0
for i, end_pos in enumerate(part_end_offsets):
@@ -65,7 +66,7 @@ def split_safetensor(fn):
cur_pos = end_pos
with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
f.write(net_bytes[text_model_start+8+json_len:])
f.write(net_bytes[text_model_start+data_start:])
return part_end_offsets
@@ -95,7 +96,8 @@ if __name__ == "__main__":
sub_steps = [
Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)
]
prg = ""
@@ -116,19 +118,23 @@ if __name__ == "__main__":
weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _, _, _) in statements])
input_names = [name for _,name in special_names.items() if "input" in name]
output_names = [name for _,name in special_names.items() if "output" in name]
input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])
return f"""\n var {step.name} = function() {{
{kernel_code}
return {{
"setup": async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor[0]);
const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;
{bufs}
{exported_bufs}
{gpu_write_bufs}
const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
@@ -147,8 +153,8 @@ if __name__ == "__main__":
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}


@@ -165,10 +165,6 @@
import ClipTokenizer from './clip_tokenizer.js';
window.clipTokenizer = new ClipTokenizer();
</script>
<script type="module">
import { f16tof32GPU } from 'https://unpkg.com/f16-to-f32-gpu@0.1.0/src/index.js';
window.f16tof32GPU = f16tof32GPU;
</script>
<script src="./net.js"></script>
</head>
<body>
@@ -214,6 +210,8 @@
<canvas id="canvas" width="512" height="512"></canvas>
<script>
let f16decomp = null;
function initDb() {
return new Promise((resolve, reject) => {
let db;
@@ -416,7 +414,7 @@
const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
const decodeChunkSize = 67107840;
const decodeChunkSize = 8388480;
const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
console.log(allToDecomp + " bytes to decompress");
@@ -440,7 +438,8 @@
let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
let chunkEndF16 = chunkStartF16 + decodeChunkSize;
let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
let result = await f16tof32GPU(chunk);
let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
let result = await f16decomp(uint32Chunk);
let resultUint8 = new Uint8Array(result.buffer);
let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
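
(Sanity check on the new chunk size, added for clarity: 8388480 bytes per chunk / 4 bytes per uint32 = 2097120 elements, which matches the input shape of the f16tof32 Step added to the compile script earlier in this diff; since each packed f16 expands to an f32, a decompressed chunk occupies twice as many bytes, consistent with the decodeChunkSize * i * 2 offset used for chunkStartF32.)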
@@ -483,6 +482,7 @@
}
const device = await getDevice();
f16decomp = await f16tof32().setup(device, safetensorParts),
safetensorParts = await getAndDecompressF16Safetensors(device, progress);
modelDlTitle.innerHTML = "Compiling model"


@@ -12,7 +12,7 @@ if __name__ == "__main__":
yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
state_dict = safe_load(get_weights_location(yolo_variant))
load_state_dict(yolo_infer, state_dict)
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,256,256))
prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
dirname = Path(__file__).parent
safe_save(state, (dirname / "net.safetensors").as_posix())
with open(dirname / f"net.js", "w") as text_file:


@@ -95,6 +95,7 @@
</head>
<body>
<h2>YOLOv8 tinygrad WebGPU</h2>
<h2 id="wgpu-error" style="display: none; color: red;">Error: WebGPU is not supported in this browser</h2>
<div class="video-container">
<video id="video" muted autoplay playsinline></video>
<canvas id="canvas"></canvas>
@@ -107,7 +108,7 @@
</div>
<script>
let net = null;
const modelInputSize = 256
const modelInputSize = 416;
let lastCalledTime;
let fps = 0, accumFps = 0, frameCounter = 0;
@@ -117,6 +118,7 @@
const offscreenCanvas = document.createElement('canvas');
const fpsMeter = document.getElementById('fps-meter');
const loadingContainer = document.getElementById('div-loading');
const wgpuError = document.getElementById('wgpu-error');
offscreenCanvas.width = modelInputSize;
offscreenCanvas.height = modelInputSize;
const offscreenContext = offscreenCanvas.getContext('2d');
@@ -147,7 +149,7 @@
lastCalledTime = now;
accumFps += 1/delta;
if (frameCounter++ >= 30) {
if (frameCounter++ >= 10) {
fps = accumFps/frameCounter;
frameCounter = 0;
accumFps = 0;
@@ -206,7 +208,12 @@
async function detectObjectsOnFrame(offscreenContext) {
if (!net) {
net = await loadNet(await getDevice());
let device = await getDevice();
if (!device) {
wgpuError.style.display = "block";
loadingContainer.style.display = "none";
}
net = await loadNet(device);
loadingContainer.style.display = "none";
}
let start = performance.now();
@@ -239,7 +246,7 @@
}
const getDevice = async () => {
if (!navigator.gpu) error("WebGPU not supported.");
if (!navigator.gpu) return false;
const adapter = await navigator.gpu.requestAdapter();
return await adapter.requestDevice();
};

Binary file not shown.


@@ -79,6 +79,9 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
return '\n'.join(cprog)
def dtype_to_js_type(dtype: DType) -> str:
return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
@@ -92,10 +95,12 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
return f"""
{web_utils["getTensorBuffer"]}
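
(For illustration, a hedged sketch of what dtype_to_js_type in extra/export_model.py returns for a few common dtypes; imports as used elsewhere in this diff.)

from tinygrad import dtypes
from extra.export_model import dtype_to_js_type

print(dtype_to_js_type(dtypes.float32))  # "Float32Array"
print(dtype_to_js_type(dtypes.int32))    # "Int32Array"
print(dtype_to_js_type(dtypes.uint8))    # "Uint8Array"
# the returned names select the matching JS TypedArray for GPU buffer reads and writes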

extra/f16_decompress.py (new file, 16 lines)

@@ -0,0 +1,16 @@
from tinygrad import Tensor
def bit_extract(x: Tensor, e: int, s: int) -> Tensor:
mask = (1 << (e - s + 1)) - 1
return (x >> s) & mask
def u16_to_f16(x: Tensor) -> Tensor:
sign = bit_extract(x, 15, 15).float()
exponent = bit_extract(x, 14, 10).float()
fraction = bit_extract(x, 9, 0).float()
return sign.where(-1, 1) * exponent.where((exponent - 15.0).exp2() * (1 + fraction / 1024.0), 6.103515625e-5 * (fraction / 1024.0))
def u32_to_f16(oo: Tensor) -> Tensor:
f1 = u16_to_f16(oo>>16)
f2 = u16_to_f16(oo&0xFFFF)
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
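
(For reference, a round-trip check adapted from the __main__ self-test of the older f16 decompress helper removed later in this diff: random f16 values are written to disk, their bytes are viewed as uint32, and u32_to_f16 is expected to recover them.)

import numpy as np
from tinygrad import Tensor, Device, dtypes
from extra.f16_decompress import u32_to_f16

Tensor.manual_seed(2)
a = Tensor.randn(100, dtype=dtypes.float16)
oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()  # 100 halves -> 50 packed uint32
np.testing.assert_allclose(a.numpy(), u32_to_f16(oo).numpy().astype(np.float16))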


@@ -1,40 +0,0 @@
import numpy as np
from tinygrad import Device, dtypes, Tensor
# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
def bit_extract(x, s, e) -> Tensor:
# extract the top bits we don't want
top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1))
x = (x - top_bits) / (1<<e)
return x.contiguous()
def u16_to_f16(x):
sign = bit_extract(x, 15, 15).float()
exponent = bit_extract(x, 14, 10).float()
fraction = bit_extract(x, 9, 0).float()
return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))
def u32_to_f16(oo):
oo1 = (oo/0x10000).floor().contiguous()
# TODO: this is wrong and unextractable until we do this math in u32
oo2 = (oo-(oo1*0x10000)).floor().contiguous()
f1 = u16_to_f16(oo1)
f2 = u16_to_f16(oo2)
return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
if __name__ == "__main__":
# random float16
Tensor.manual_seed(2)
a = Tensor.randn(100, dtype=dtypes.float16)
# this converts it to u32 on disk
oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()
# convert to 2xf16 using tinygrad math ops
f16 = u32_to_f16(oo)
ref = a.numpy()
out = f16.numpy().astype(np.float16)
print(ref-out)
np.testing.assert_allclose(ref, out)


@@ -1,14 +1,12 @@
from __future__ import annotations
from typing import List, Dict, Union
import importlib
from functools import lru_cache
from typing import List, Dict, Union, Callable, Any, Sequence
import importlib, functools
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import getenv, DEBUG, CI, OSX
from tinygrad.dtype import ConstType, DType
from tinygrad import Tensor, dtypes
from tinygrad.helpers import getenv, DEBUG, all_same
from tinygrad.dtype import DType, ConstType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
@@ -17,178 +15,134 @@ except ImportError:
def tensor_dtype_to_np_dtype(tensor_dtype:int) -> np.dtype: return TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]
cache_misses = 0
@lru_cache(None)
def _cached_to_python_const(t:Tensor, tobytes): return t.data().tobytes() if tobytes else t.tolist()
@functools.lru_cache(None)
def _cached_to_python_const(t:Tensor):
if t.dtype is dtypes.uint8: return t.data().tobytes()
if 0 in t.shape: return []
return t.tolist()
# Tensor -> python value cache for parameters
def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
def to_python_const(t) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
if not isinstance(t, Tensor): return t
global cache_misses
ret = _cached_to_python_const(t, tobytes)
ret = _cached_to_python_const(t)
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
print(f"Cache miss for {t}, {tobytes=}")
print(f"Cache miss for {t}")
cache_misses = info.misses
return ret
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
# TODO: use dtypes.float16 for FLOAT16
DTYPE_MAP: Dict[TensorProto.DataType, DType] = {
TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16,
TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool,
TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64,
TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float, TensorProto.FLOAT8E4M3FNUZ:dtypes.float,
TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
# TODO: use real float16
# src: onnx/mapping.py
DTYPE_MAP: Dict[TensorProto.DataType | int, DType] = {
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float,
TensorProto.FLOAT8E4M3FNUZ:dtypes.float, TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
}
def dtype_parse(onnx_dtype: TensorProto.DataType | int) -> DType:
if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
# src: onnx/onnx_ml_pb2.pyi
ATTRIBUTE_MAP: Dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
}
def attribute_parse(onnx_attribute: AttributeProto):
if onnx_attribute.type not in ATTRIBUTE_MAP:
raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
def buffer_parse(inp: TensorProto) -> Tensor:
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
return Tensor(dat, dtype=dtype_parse(inp.data_type), requires_grad=False).reshape(tuple(inp.dims))
if len(inp.raw_data) > 0:
return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).copy().reshape(tuple(inp.dims)),
dtype=dtype_parse(inp.data_type), requires_grad=False)
raise NotImplementedError(f"buffer with data type {TensorProto.DataType.Name(inp.data_type)} is not supported")
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
def type_parse(type_proto: TypeProto):
ret = []
while True:
attr = type_proto.WhichOneof('value')
if attr == 'tensor_type':
if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
elif not ret:
return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
else:
ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
return tuple(ret)
elif attr == 'sequence_type':
type_proto = getattr(type_proto, attr).elem_type
ret.append(1)
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type not in DTYPE_MAP:
raise NotImplementedError(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
if len(inp.raw_data) > 0:
data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
return Tensor(data.reshape(tuple(inp.dims)), requires_grad=False)
return Tensor(None, requires_grad=False)
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
if a.type == AttributeProto.FLOAT: return float(a.f)
elif a.type == AttributeProto.INT: return int(a.i)
elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
elif a.type == AttributeProto.GRAPH: raise NotImplementedError(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
else: raise RuntimeError(f"can't parse {a.type} {a}")
tensors: Dict[str, Tensor] = {}
# get weights and biases
for inp in onnx_model.graph.initializer:
tensors[inp.name] = buffer_parse(inp)
# preparse the attributes
attribute_dict = {}
domain = ""
for num,n in enumerate(onnx_model.graph.node):
attribute_dict[num] = {x.name:attribute_parse(x) for x in n.attribute}
if n.domain: domain = n.domain
# model initialization data
model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}
# model descriptions
# TODO: need a better way of controlling training vs non-training
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
onnx_model_version = onnx_model.opset_import[0].version
# mapping from onnx ops to tensor.py ops
tensor_methods = {
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
}
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
type_proto = model_input.type
if type_proto.HasField("optional_type"):
if user_input is None: return Tensor(None)
type_proto = type_proto.optional_type.elem_type
if type_proto.HasField("sequence_type"):
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
# TODO: need true float16 for dtype checking
# if not all(t.dtype is dtype for t in sequence): raise RuntimeError(f"{model_input.name} received wrong dtype, expected {dtype}")
return sequence
if type_proto.HasField("tensor_type"):
dtype = dtype_parse(type_proto.tensor_type.elem_type)
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
# TODO: need true float16 for dtype checking
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} received dtype {inp.dtype}, expected {dtype}")
for d,onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
# NOTE: dim is a variable dimension when `dim_param` is specified, e.g. dim {dim_param: "N"} is a variable dim
if onnx_dim.dim_param is None and onnx_dim.dim_value != user_input.shape[d]:
raise RuntimeError(f"{model_input.name} received value {user_input.shape[d]} on dim {d}, expected {onnx_dim.dim_value}")
return tensor
type_field_names = [field.name for field,_ in type_proto.ListFields()]
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
input_tensors: Dict[str,Tensor|List[Tensor]] = {}
intermediate_tensors: Dict[str,Tensor] = {}
output_tensor_names = [x.name for x in onnx_model.graph.output]
# get inputs
input_tensors: Dict[str, Tensor | List[Tensor]] = {}
for model_input in onnx_model.graph.input:
name = model_input.name
if name in tensors: continue
shape = type_parse(model_input.type)
if name in inputs:
if isinstance(inputs[name], Tensor):
input_tensors[name] = inputs[name]
elif isinstance(inputs[name], list):
input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
# TODO: this is just to make training tests pass, need a principled way to handle training vs non-training
elif domain == "ai.onnx.preview.training":
input_tensors[name] = Tensor(inputs[name], requires_grad=True)
else:
input_tensors[name] = Tensor(inputs[name], requires_grad=False)
if shape: # if only input_tensor is not variable type
ts = input_tensors[name]
input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
else:
raise RuntimeError(f"no data for {name} with shape {shape}")
if model_input.name in inputs: input_tensors[model_input.name] = prepare_input(inputs[model_input.name], model_input)
elif model_input.name not in model_parameters: raise RuntimeError(f"Please provide input data for {model_input.name}")
def fetch_tensor(x: str):
if x in tensors: return tensors[x]
if x in model_parameters: return model_parameters[x]
if x in intermediate_tensors: return intermediate_tensors[x]
if x != "": return input_tensors[x]
return None
for num,n in enumerate(onnx_model.graph.node):
inp: List[Tensor] = []
if debug >= 3: print("inputs:")
for x in n.input:
t = fetch_tensor(x)
if debug >= 3: print(f"\t{x} - {t}")
inp.append(t)
opt: Dict = attribute_dict[num]
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
inp = [fetch_tensor(x) for x in n.input]
opt = model_attributes[num]
if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
if debug >= 3: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {t}" for i,(x,t) in enumerate(zip(n.input, inp))))
if n.op_type in tensor_methods:
ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)
# NOTE some ops live here because they require access to some local variables
# have to use n.output for cases when num_outputs is absent
if n.op_type in onnx_ops.tensor_methods:
ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
elif n.op_type == "Split":
axis = opt.get("axis", 0)
split = None if len(inp) == 1 else to_python_const(inp[1])
if split is None:
split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
for i in range(inp[0].shape[axis] % len(n.output)):
split[i] += 1
i, ret = 0, []
arg = [None] * inp[0].ndim
for s in split:
arg[axis] = (i,i+s)
ret.append(inp[0].shrink(arg=tuple(arg)))
i = i+s
ret = tuple(ret)
# need to check onnx_model_version
elif n.op_type == "Slice":
if onnx_model_version < 10:
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
starts, ends = inp[1:3]
axes = list(range(inp[0].ndim)) if len(inp) <= 3 else to_python_const(inp[3].cast(dtypes.int32))
steps = inp[4].cast(dtypes.int32).tolist() if len(inp) > 4 else [1]*inp[0].ndim
starts, ends = to_python_const(starts), to_python_const(ends)
arg = [(0,x,1) for x in inp[0].shape]
for i, axis in enumerate(axes):
axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
if starts[i] < 0: starts[i] += inp[0].shape[axis]
if ends[i] < 0: ends[i] += inp[0].shape[axis]
starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
arg[axis] = (starts[i], ends[i], steps[i])
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
else: ret = inp[0][tuple([slice(s,e,st) for s,e,st in arg])]
# need to call backward on intermediate_tensors
axis, n_outputs = opt.get('axis', 0), opt.get('num_outputs') or len(n.output)
sz = inp[0].shape[axis]
sizes = to_python_const(inp[1]) if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
ret = inp[0].split(sizes, axis)
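# illustrative arithmetic for the default sizes above (numbers not from the diff): with sz=10 and n_outputs=3
# the list comprehension gives [4, 3, 3] -- each chunk gets sz // n_outputs elements and the first sz % n_outputs chunks get one extra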
elif n.op_type == "Gradient":
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)}, outputs and inputs have to match"
y = opt["y"]
@@ -209,16 +163,12 @@ def get_run_onnx(onnx_model: ModelProto):
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise NotImplementedError(f"op_type {n.op_type} not supported")
# finalization after running the op
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected at most {len(ret)} outputs, got {len(n.output)}: {n.output}"
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
if debug >= 2: print("outputs:")
for i in range(len(n.output)):
if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
intermediate_tensors[n.output[i]] = ret[i]
if num == ONNXLIMIT:
output_tensor_names = n.output
break
if len(n.output) > len(ret): raise RuntimeError(f"expected at most {len(ret)} outputs, got {len(n.output)}: {n.output}")
for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))
return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
if num == ONNXLIMIT: return {name:intermediate_tensors[name] for name in n.output}
return {x.name:intermediate_tensors[x.name] for x in onnx_model.graph.output}
return run_onnx
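For context, a minimal sketch of how the runner built above is typically used; the model path, the "input" name, and the shape are illustrative assumptions, not values taken from this change:
import onnx
import numpy as np
from extra.onnx import get_run_onnx

onnx_model = onnx.load("model.onnx")   # hypothetical model file
run_onnx = get_run_onnx(onnx_model)    # build the runner once, reuse it per inference
out = run_onnx({"input": np.random.randn(1, 3, 224, 224).astype(np.float32)})  # keys must match the graph input names
print({name: t.shape for name, t in out.items()})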

View File

@@ -3,13 +3,9 @@ from typing import Union, Tuple, Optional, List, Any, cast
from tinygrad.tensor import Tensor, _broadcast_shape
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from extra.onnx import DTYPE_MAP, to_python_const
from extra.onnx import dtype_parse, to_python_const
import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
"Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}
# **************** Free Ops ****************
def Identity(x: Tensor): return x
@@ -21,12 +17,16 @@ def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0): return Sum(*data_0) / len(data_0)
# NOTE: does not support saturate
def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
def Cast(x: Tensor, to: int, saturate=1): return x.cast(dtype_parse(to))
def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)
# **************** Simple Ops ****************
@@ -91,6 +91,14 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)
def Slice(data: Tensor, starts:Tensor, ends:Tensor, axes:Optional[Tensor]=None, steps:Optional[Tensor]=None):
if axes is None: axes = list(range(data.ndim))
if steps is None: steps = [1] * data.ndim
starts, ends, axes, steps = (to_python_const(x) for x in (starts, ends, axes, steps))
slices = [slice(0,x,1) for x in data.shape]
for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
return data[tuple(slices)]
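# illustrative example (values not from the diff): for a 1-D tensor of length 10, starts=[2], ends=[8], steps=[2]
# builds slices == [slice(2, 8, 2)] and returns data[2:8:2], i.e. the elements at indices 2, 4, 6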
def Squeeze(data: Tensor, axes):
if isinstance(axes, Tensor): axes = to_python_const(axes)
axes = [data._resolve_dim(x) for x in axes]
@@ -389,11 +397,10 @@ def CenterCropPad(t: Tensor, shape: Tensor, axes=None):
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))
def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = to_python_const(depth)
depth = int(to_python_const(depth))
# Scalar or Rank 1 tensor containing exactly one element
depth, indices = depth[0] if isinstance(depth, list) else depth, (indices < 0).where(indices+depth, indices),
if axis < 0: axis += indices.ndim + 1
return (indices[:,None] == Tensor.arange(int(depth)).reshape((int(depth),) + (1,)*(indices.ndim-axis))).where(values[1], values[0])
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
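# illustrative semantics (values not from the diff): indices=[0, 2], depth=3, values=[off, on]
# yields [[on, off, off], [off, off, on]] -- negative indices were already wrapped by indices+depth above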
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
@@ -405,7 +412,7 @@ def Compress(inp: Tensor, condition: Tensor, axis=None):
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]
def EyeLike(x: Tensor, dtype=None, k=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=DTYPE_MAP[dtype] if dtype else x.dtype)
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.size(0)-k) for d in x.shape))
def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
@@ -424,7 +431,7 @@ def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream, True)))
img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream)))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
@@ -461,8 +468,7 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)
def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(-1).expand(*x.shape, vocab_size)) @ weight
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight
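# a one-hot row has a single 1 at the token id, so the matmul with `weight` selects that row of the table -- an embedding lookup expressed as a gather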
# bert embedding layer
if epsilon is None: epsilon = 1e-12

View File

@@ -5,9 +5,9 @@ from test.external.process_replay.process_replay import _pmap
LOGOPS = os.getenv("LOGOPS", "/tmp/sops")
def extract_ast(*args) -> bool:
def extract_ast(*args) -> None:
open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
return args[-1]
return None
if __name__ == "__main__":
_pmap("kernel", extract_ast)

View File

@@ -1,5 +1,6 @@
#!/bin/bash
export PAGE_SIZE=1
export PYTHONPATH=.
export LOGOPS=/tmp/ops
export RUN_PROCESS_REPLAY=1
rm $LOGOPS
@@ -24,5 +25,5 @@ JIT=2 BIG=1 MPS=1 python -m pytest test/test_speed_v_torch.py
# extract, sort and uniq
extra/optimization/extract_dataset.py
sort -u /tmp/ops > /tmp/sops
sort -u /tmp/ops > /tmp/sops
ls -lh /tmp/ops /tmp/sops

View File

@@ -29,7 +29,7 @@ setup(name='tinygrad',
'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
'linting': [
"pylint",
"mypy==1.11.2",
"mypy==1.13.0",
"typing-extensions",
"pre-commit",
"ruff",

View File

@@ -70,11 +70,16 @@ backend_test.exclude('BFLOAT16') # not supported in numpy
# TODO: fix these with true onnx float16
backend_test.exclude('to_FLOAT16')
backend_test.exclude('cast_no_saturate')
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')
backend_test.exclude('test_pow_types_int*')
backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_matmulinteger_*')
backend_test.exclude('test_dequantizelinear_int4_cpu')
backend_test.exclude('test_dequantizelinear_uint4_cpu')
# we don't support indexes
backend_test.exclude('test_nonzero_*')
@@ -117,7 +122,6 @@ backend_test.exclude('test_affine_grid_3d_expanded_cpu')
backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')
# unsupported (strange) ops
backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_det_*')

View File

@@ -54,6 +54,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
# try recreate
try:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
for x in args[:-1]: logging.info(x)

View File

@@ -87,7 +87,7 @@ class TestKernelSpeed(unittest.TestCase):
# NOTE: tiny7 was slower than tiny12
# TODO: why are convs so slow?!?
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=36, amd_tflops=24)
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=24)
# theoretical is nv_tflops=165, amd_tflops=123
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=120, amd_tflops=80)
@@ -95,7 +95,7 @@ class TestKernelSpeed(unittest.TestCase):
# theoretical is nv_gbs=1008, amd_gbs=960
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=780)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=760)
if __name__ == '__main__':
unittest.main()

View File

@@ -22,7 +22,7 @@ def consec(shape, start=1):
def set_(reference: Tensor, shape, strides, offset):
if reference.lazydata.base.realized is None: reference.realize()
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
return strided
@@ -1062,9 +1062,9 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(a[0, one], a[zero, 1])
# indexing by a scalar should slice (not copy)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int32)])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
# scalar indexed with scalar
r = Tensor.randn()
@@ -1105,6 +1105,20 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_allclose(9.9, r, rtol=1e-7)
'''
def test_getitem_casted_scalars_folding(self):
Tensor.manual_seed(0)
# cast of const is just another const, don't need extra kernels for this
a = Tensor.randn(2, 3)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
def test_getitem_scalars_simple_folding(self):
a = Tensor.randn(2, 3)
zero = Tensor(0, dtype=dtypes.int64)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
def test_basic_advanced_combined(self):
# From the NumPy indexing example
x = Tensor.arange(0, 12).reshape(4, 3)
@@ -1555,4 +1569,4 @@ class TestNumpy(unittest.TestCase):
'''
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -158,6 +158,37 @@ class TestReduceOpsConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)
def test_bool_zero_max(self):
_check_ast_count(0, Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)))
np.testing.assert_equal(Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)).numpy(), False)
def test_zero_size_ops(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]: # lambda x:x.max() NOTE: numpy gives "reduction operation maximum which has no identity"
_check_ast_count(0, reduceop(Tensor.empty(1, 0)))
np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty(shape)))
def test_zero_size_ops_view(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
_check_ast_count(0, reduceop(Tensor.empty(1, 0, 4).permute((1, 2, 0)).contiguous()))
np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty((shape))))
def test_zero_size_ops_realized(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
_check_ast_count(0, reduceop((Tensor.randn(0, 1)+1).realize()))
np.testing.assert_equal(reduceop((Tensor.randn(shape:=(0, 1))+1).realize()).numpy(), reduceop(np.empty(shape)))
def test_zero_size_realize_folded(self):
# non contiguous folded output doesn't realize
_check_ast_count(0, Tensor.empty(1, 0).sum())
# contiguous folded const can still schedule
a = Tensor.empty(1, 0).sum().contiguous()
_check_ast_count(2, a+2)
self.assertIsNotNone(a.lazydata.base.realized)
np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
# otherwise we just fuse it
_check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())
np.testing.assert_equal((Tensor.empty(1, 0).sum()+2).numpy(), 2)
def test_const_prod(self):
_check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))

View File

@@ -3,7 +3,7 @@ import unittest
from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as strat, settings
from hypothesis import given, strategies as strat, settings, HealthCheck
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.schedule import create_schedule
@@ -86,7 +86,7 @@ def universal_test_unary(a, dtype, op):
def universal_test_cast(a, in_dtype, dtype):
tensor_value = Tensor([a], dtype=in_dtype).cast(dtype)
numpy_value = np.array([a]).astype(_to_np_dtype(dtype))
numpy_value = np.array([a], dtype=_to_np_dtype(in_dtype)).astype(_to_np_dtype(dtype))
np.testing.assert_equal(tensor_value.numpy(), numpy_value)
def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
@@ -178,6 +178,28 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)
@settings(suppress_health_check=[HealthCheck.filter_too_much])
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
overflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)
@unittest.expectedFailure
def test_unsafe_cast_float_to_int_failure(self):
val = float(dtypes.max(dtypes.int32) - 1)

View File

@@ -398,6 +398,14 @@ class TestJit(unittest.TestCase):
for i in range(5):
np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))
def test_jitted_clone(self):
def f(a): return a.clone().realize()
jf = TinyJit(f)
for _ in range(5):
a = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
ja = jf(a)
np.testing.assert_allclose(a.numpy(), ja.numpy(), atol=1e-4, rtol=1e-5)
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
def test_jitted_transfers(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -483,5 +491,52 @@ class TestJitInsideJit(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
g(Tensor([1])).realize()
class TestCopyInsideJit(unittest.TestCase):
def test_copy_inside_jit(self):
@TinyJit
def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
for _ in range(5):
# create a Tensor in CLANG
a = Tensor.rand(16,16,device="CLANG").realize()
b = Tensor.rand(16,16).realize()
out = add(a,b)
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
class TestJitPrune(unittest.TestCase):
def test_simple_prune(self):
weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x
w2_noprune = TinyJit(w2)
w2_prune = TinyJit(w2, prune=True)
for _ in range(3):
a = Tensor.rand(16).realize()
out = w2_noprune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
assert len(w2_noprune.captured.jit_cache) == 2
for _ in range(3):
a = Tensor.rand(16).realize()
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
assert len(w2_prune.captured.jit_cache) == 1
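# pruning drops the kernel whose output does not depend on the jit inputs (the (weights*2).contiguous() prelude),
# which is why the pruned cache above holds 1 kernel instead of 2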
def test_prune_w_copy_correct(self):
weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
w2_noprune = TinyJit(w2)
w2_prune = TinyJit(w2, prune=True)
for _ in range(3):
a = Tensor.rand(16, device="CLANG").realize()
out = w2_noprune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
for _ in range(3):
a = Tensor.rand(16, device="CLANG").realize()
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
if __name__ == '__main__':
unittest.main()

View File

@@ -63,12 +63,12 @@ class TestLazyBuffer(unittest.TestCase):
def test_const_dtype(self):
lb: LazyBuffer = Tensor([1], dtype=dtypes.int).lazydata
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int
assert lb.const_like(1).const_arg == 1
assert type(lb.const_like(1).const_arg) is int
lb: LazyBuffer = Tensor([1], dtype=dtypes.float).lazydata
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
assert lb.const_like(1).const_arg == 1.0
assert type(lb.const_like(1).const_arg) is float
def test_forced_realized_alu(self):
a = Tensor.randn(2, 2).realize()

View File

@@ -108,6 +108,25 @@ class TestLinearizer(unittest.TestCase):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1]}, {u}"
@unittest.expectedFailure
def test_const_alu_indexing(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)*UOp.const(dtypes.float, -1)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1*-1], opts=[])
def test_const_alu_indexing_one_const_fine(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1], opts=[])
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

View File

@@ -670,8 +670,8 @@ class TestMultiTensor(unittest.TestCase):
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0)
assert all([lb is lb.base and lb.buffer.base.size == 4 * 16 for lb in t.lazydata.lbs])
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs])
def test_clone(self):
t = Tensor.rand(16, 16).shard(devices_2, axis=None)

View File

@@ -36,7 +36,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
try:
assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
if np.issubdtype(tinygrad_output.dtype, np.floating):
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
else:
np.testing.assert_equal(tinygrad_output, torch_output)
except Exception as e:
raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")
@@ -71,6 +74,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
np.random.seed(0)
np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
for i in range(len(ts)):
# NOTE: torch defaults to int64 for python int inputs
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst
@@ -312,8 +318,7 @@ class TestOps(unittest.TestCase):
def _test_cmp(self, fxn, reverse=True):
# test different dtypes
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
if is_dtype_supported(dtypes.long):
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
# test broadcasting
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
@@ -500,11 +505,18 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
if is_dtype_supported(dtypes.uint64):
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
np.testing.assert_equal(x.numpy(), 2**64 - 1)
# 1 // 0 is device dependent, but it should not raise
Tensor([1]).idiv(1).realize()
if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
# ... because it might be in a where branch where the output is well defined
t = Tensor([-1, 0, 1, 2])
np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
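# worked through: the mask (t > 0) is [F, F, T, T], so the masked-out 1//(-1) and 1//0 never reach the output;
# the surviving entries are 1//1 == 1 and 1//2 == 0, hence [-1, 0, 1, 0]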
def test_scalar_div(self):
helper_test_op([(45,65)], lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1)
@@ -548,6 +560,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x**1.2, low=-30, high=-27)
a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True)
helper_test_op([], lambda: b**1.1, lambda: a**1.1)
def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: x**-1.0)
@@ -561,6 +574,25 @@ class TestOps(unittest.TestCase):
# TODO: fix backward, should be nan
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)
def test_pow_int(self):
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)
for base in ([1, 2, 3], [-1, -2, -3]):
for exponent in ([2, 3, 4], [-2, -3, -4]):
_test(base, exponent)
# NOTE: torch 0 ** -1 is 0
_test([0, 0, 0], [0, 1, 2])
np.testing.assert_equal((Tensor(11) ** Tensor(7)).item(), 11 ** 7)
np.testing.assert_equal((Tensor([11]) ** Tensor(7)).item(), 11 ** 7)
# TODO: fix non-precise int pow
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)
# pow to a const int
helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int),
lambda: Tensor([2]) ** Tensor(-2), forward_only=True)
def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt())
helper_test_op([()], lambda x: x.sqrt())
@@ -784,11 +816,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish)
helper_test_op([()], torch.nn.functional.mish, Tensor.mish)
def test_multinomial(self):
# NOTE: this is random, so it has a very large atol
helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1).type(torch.int32),
lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)
def test_small_cumsum(self):
helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
def test_simple_cumsum(self):
@@ -832,17 +859,27 @@ class TestOps(unittest.TestCase):
def test_argmax(self):
# check if it returns the first index for multiple occurrences
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]])
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]])
np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
# regression test for bitwise_not then argmax
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]])
# NOTE: torch does not support this on bool
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]])
def test_argmin(self):
# check if it returns the first index for multiple occurrences
self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]])
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]])
np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
@@ -850,6 +887,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[-2**31, 0]])
# NOTE: torch does not support this on bool
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])
def test_einsum(self):
# matrix transpose
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
@@ -1085,11 +1128,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[True, False]])
def test_max(self):
helper_test_op([(45,3)], lambda x: x.max())
@@ -1098,11 +1140,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_any(self):
@@ -1198,17 +1239,18 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.stack(*x.std_mean(correction=5)))
helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x, keepdim=True, correction=0)),
lambda x: Tensor.stack(*x.std_mean(keepdim=True, correction=0)))
helper_test_op([(1,0,3,0,5)], lambda x: torch.stack(torch.std_mean(x, axis=(1,3))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,3))))
helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))
@unittest.skip("TODO: this fails because of loaded nan in mul folding")
def test_std_mean_loaded_nan(self):
helper_test_op([(1,0,3,0,5)], lambda x: torch.stack(torch.std_mean(x, axis=(1,3))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,3))))
def test_softmax(self):
# exceed per kernel buffer limit with backward
forward_only = (Device.DEFAULT == "WEBGPU")
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
def test_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
@@ -2033,6 +2075,20 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))
def test_max_pool2d_ceil_mode(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True))
def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
# a sliding window that starts in the end padding region is ignored
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
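# illustrative arithmetic: with size 5, kernel 3, stride 3, padding 1, ceil mode suggests ceil((5+2-3)/3)+1 = 3 windows,
# but the third window would start entirely inside the right padding, so it is dropped and the output size is 2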
def test_avg_pool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
@@ -2062,11 +2118,53 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)
def test_avg_pool2d_ceil_mode(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5)
def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self):
# a sliding window that starts in the end padding region is ignored
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))
def test_avg_pool2d_ceil_mode_include_pad(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True), rtol=1e-5)
def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self):
# a sliding window that starts in the end padding region is ignored
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))
def test_global_avg_pool2d(self):
helper_test_op([(32,2,111,28)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)
# TODO: linearizer block error
@unittest.expectedFailure
def test_avg_pool3d_failure(self):
with Context(NOOPT=0):
helper_test_op([(1,1,16,16,16)],
lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), rtol=1e-5, forward_only=True)
def test_avg_pool3d_noopt(self):
with Context(NOOPT=1):
helper_test_op([(1,1,16,16,16)],
lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), rtol=1e-5, forward_only=True)
def test_interpolate_linear(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
helper_test_op([(2,3)+in_sz],
@@ -2191,7 +2289,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())
@unittest.skip("this test is broken #862")
def test_max_inf(self):
def test_max_nan(self):
n = Tensor([1, float("nan")]).max().numpy()
assert math.isnan(n.item()), f"{n.item()} is not nan"
@@ -2429,45 +2527,39 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
lambda x,y: x.cross_entropy(y, label_smoothing=ls))
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss(self):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d(self):
helper_test_op([(32,10,3,3,3), (32,3,3,3)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
self.helper_test_exception([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_ignore_index(self):
logits = [[2.0, 0.5, -1.0],
[1.5, 2.5, -0.5],
@@ -2475,7 +2567,7 @@ class TestOps(unittest.TestCase):
targets = [0, 1, 2]
helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
torch.clip(y,0).type(torch.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
forward_only=True, vals=[logits, targets])
def test_one_hot(self):
@@ -2497,8 +2589,7 @@ class TestOps(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
@@ -2508,7 +2599,6 @@ class TestOps(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)
@@ -2533,7 +2623,6 @@ class TestOpsUint8(unittest.TestCase):
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)
@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_min(self):
helper_test_op(None,
lambda x: x.type(torch.uint8).min(),

View File

@@ -2,6 +2,7 @@ import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import GlobalCounters
from tinygrad.ops import PatternMatcher, UPat, UOp
class TestPickle(unittest.TestCase):
@@ -24,10 +25,29 @@ class TestPickle(unittest.TestCase):
pickle.dumps(sym)
def test_pickle_realized_tensor(self):
print("** init")
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
t_values = t.numpy()
del t # free buffers
print("** post pickle")
init = GlobalCounters.kernel_count
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
np.testing.assert_equal(t_values, t2.numpy())
# expect at most one COPY kernel
self.assertLessEqual(GlobalCounters.kernel_count-init, 1)
def test_pickle_realized_tensor_alt(self):
print("** init")
t = Tensor.rand(10, 10).to("CLANG").realize()
st = pickle.dumps(t)
t_values = t.numpy()
del t # free buffers
print("** post pickle")
init = GlobalCounters.kernel_count
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t_values, t2.numpy())
self.assertEqual(GlobalCounters.kernel_count-init, 0)
def test_pickle_unrealized_tensor(self):
t = Tensor.ones(10, 10)

View File

@@ -13,12 +13,12 @@ from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from tinygrad.engine.lazy import LazyBuffer
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass
@@ -1344,6 +1344,7 @@ class TestSchedule(unittest.TestCase):
Tensor.ones(5, 5).contiguous().schedule()
self.assertEqual(GlobalCounters.mem_used-base, 0)
@unittest.skip("TODO: this is consistently creating non reproducible failures")
def test_schedule_mem_used_with_inputs(self):
base = GlobalCounters.mem_used
x = Tensor.ones(256).contiguous().realize()
@@ -1924,7 +1925,7 @@ class TestView(unittest.TestCase):
a = UOp(Ops.VIEW, dtypes.float, (UOp.new_buffer(Device.DEFAULT, 121, dtypes.float), UOp(Ops.EMPTY, dtypes.float)), st)
b = a.pad(pad_arg:=((0, 0), (0, 0), (18, 0)))
self.assertEqual(b.st, st.pad(pad_arg))
self.assertIs(b, b.const_like(0))
self.assertIs(b.base.src[1], UOp.const(dtypes.float, 0))
def test_partial_mask(self):
# partial masked out does not degrade into CONST
@@ -1955,8 +1956,8 @@ class TestBigGraph(unittest.TestCase):
def test_sink_childless_const_alt_expanded(self):
# this is a real STORE of CONST (post expand)
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
y = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
out = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 2, dtypes.int), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
big_graph = big_graph_rewrite(out.sink(), realizes:={})
self.assertIs(big_graph, out.sink())
self.assertEqual(len(realizes), 1)

View File

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Device, dtypes, Tensor
from tinygrad.device import Buffer
from tinygrad.engine.lazy import view_supported_devices
from tinygrad.ops import view_supported_devices
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
class TestSubBuffer(unittest.TestCase):

View File

@@ -573,7 +573,7 @@ class TestZeroShapeTensor(unittest.TestCase):
a = t.reshape(0)
assert a.shape == (0,)
np.testing.assert_equal(a.numpy(), np.zeros((0,)))
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
# cannot reshape from size 0 to size 1
a = t.reshape(())

View File

@@ -405,6 +405,14 @@ class TestUOpMethod(unittest.TestCase):
self.assertEqual(const._device, None)
with self.assertRaises(AssertionError): const.device
def test_const_arg(self):
var = UOp.variable("a", 1, 10)
with self.assertRaises(AssertionError): UOp.const(dtypes.int, var).const_arg
const = UOp.const(dtypes.int, 1)
self.assertEqual(const.const_arg, 1)
tensor_const = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), const), ShapeTracker.from_shape(()))
self.assertEqual(tensor_const.const_arg, 1)
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)

View File

@@ -13,7 +13,7 @@ def time_tensor_numpy(out:Tensor):
N = 4096
class TestZeroCopy(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "METAL"}, "device isn't zero copy")
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "METAL"}, "device isn't zero copy")
def test_zero_copy_from_default_to_cpu(self):
demo = Tensor.rand(1).realize()
t1 = time_tensor_numpy(demo)

View File

@@ -1,6 +1,7 @@
import unittest
from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
from tinygrad.tensor import Tensor, Device
from tinygrad import dtypes
import json
class MockMultiInputModel:
@@ -45,6 +46,24 @@ class TextModelExport(unittest.TestCase):
for i, exported_output in enumerate(prg["outputs"]):
assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Testing WebGPU specific model export behavior")
class TextModelExportWebGPU(unittest.TestCase):
def test_exported_input_output_dtypes(self):
class MyModel:
def forward(self, *inputs): return tuple([(inp+2).cast(inp.dtype) for inp in inputs])
model = MyModel()
# [:-1] because "ulong" and "long" are not supported
inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
prg, _, _, _ = export_model(model, "webgpu", *inputs)
expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
for i, expected_buffer_type in enumerate(expected_buffer_types):
dt = inputs[i].dtype
expected_arr_prefix = f"{expected_buffer_type}{dt.itemsize*8}"
# test input buffers
self.assertIn(f"new {expected_arr_prefix}Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
# test output buffers
self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,15 @@
import unittest
from extra.f16_decompress import u32_to_f16
from tinygrad.tensor import Tensor
from tinygrad.device import Device, is_dtype_supported
from tinygrad import dtypes
import numpy as np
class TestF16Decompression(unittest.TestCase):
def test_u32_to_f16(self):
a = Tensor.randn(50, dtype=dtypes.float16, device=None if is_dtype_supported(dtypes.float16) else "CLANG:0")
f16_as_u32 = a.bitcast(dtypes.uint32) if is_dtype_supported(dtypes.float16) else a.bitcast(dtypes.uint32).to(Device.DEFAULT)
f16 = u32_to_f16(f16_as_u32)
ref = a.numpy()
out = f16.numpy().astype(np.float16)
np.testing.assert_allclose(out, ref)

View File

@@ -134,15 +134,23 @@ class TestSafetensors(unittest.TestCase):
for k in f.keys():
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())
def test_huggingface_enet_safetensors(self):
# test a real file
fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
def _test_huggingface_enet_safetensors(self, fn):
state_dict = safe_load(fn)
assert len(state_dict.keys()) == 244
assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570
def test_huggingface_enet_safetensors(self):
# test a real file
fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
self._test_huggingface_enet_safetensors(fn)
def test_huggingface_enet_safetensors_fromurl(self):
# test tensor input
t = Tensor.from_url("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
self._test_huggingface_enet_safetensors(t)
def test_metadata(self):
metadata = {"hello": "world"}
safe_save({}, temp('metadata.safetensors'), metadata)
@@ -353,10 +361,10 @@ class TestPathTensor(unittest.TestCase):
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
def test_path_tensor_with_device(self):
t = Tensor(self.test_file, device="CPU")
t = Tensor(self.test_file, device="CLANG")
self.assertEqual(t.shape, (100,))
self.assertEqual(t.dtype, dtypes.uint8)
self.assertEqual(t.device, "CPU")
self.assertEqual(t.device, "CLANG")
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
def test_path_tensor_empty_file(self):
@@ -381,8 +389,8 @@ class TestPathTensor(unittest.TestCase):
def test_path_tensor_copy_to_device(self):
t = Tensor(self.test_file)
t_cpu = t.to("CPU")
self.assertEqual(t_cpu.device, "CPU")
t_cpu = t.to("CLANG")
self.assertEqual(t_cpu.device, "CLANG")
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
if __name__ == "__main__":

View File

@@ -1,8 +1,66 @@
import unittest, tarfile, io, os, pathlib
import unittest, tarfile, io, os, pathlib, tempfile
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract
class TestTarExtractFile(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.test_files = {
'file1.txt': b'Hello, World!',
'file2.bin': b'\x00\x01\x02\x03\x04',
'empty_file.txt': b''
}
self.tar_path = os.path.join(self.test_dir, 'test.tar')
with tarfile.open(self.tar_path, 'w') as tar:
for filename, content in self.test_files.items():
file_path = os.path.join(self.test_dir, filename)
with open(file_path, 'wb') as f:
f.write(content)
tar.add(file_path, arcname=filename)
# Create invalid tar file
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
with open(self.invalid_tar_path, 'wb') as f:
f.write(b'This is not a valid tar file')
def tearDown(self):
for filename in self.test_files:
os.remove(os.path.join(self.test_dir, filename))
os.remove(self.tar_path)
os.remove(self.invalid_tar_path)
os.rmdir(self.test_dir)
def test_tar_extract_returns_dict(self):
result = tar_extract(self.tar_path)
self.assertIsInstance(result, dict)
def test_tar_extract_correct_keys(self):
result = tar_extract(self.tar_path)
self.assertEqual(set(result.keys()), set(self.test_files.keys()))
def test_tar_extract_content_size(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
self.assertEqual(len(result[filename]), len(content))
def test_tar_extract_content_values(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))
def test_tar_extract_empty_file(self):
result = tar_extract(self.tar_path)
self.assertEqual(len(result['empty_file.txt']), 0)
def test_tar_extract_non_existent_file(self):
with self.assertRaises(FileNotFoundError):
tar_extract('non_existent_file.tar')
def test_tar_extract_invalid_file(self):
with self.assertRaises(tarfile.ReadError):
tar_extract(self.invalid_tar_path)
class TestTarExtractPAX(unittest.TestCase):
tar_format = tarfile.PAX_FORMAT
max_link_len = 1000_000

View File

@@ -444,6 +444,8 @@ class TestSymbolic(unittest.TestCase):
def test_div_mod_recombine_with_gcd(self):
b = Variable("b", 0, 100)
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
with self.assertRaises(AssertionError):
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
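# both cases rest on the identity x == (x % m) + (x // m) * m; the first folds back to (16*b + 2),
# while the (30*b + 1) % 18 variant is not yet recombined by the simplifier (hence the expected AssertionError)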
@@ -525,6 +527,15 @@ class TestSymbolic(unittest.TestCase):
# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")
def test_symbolic_div(self):
# from symbolic arange
a = Variable("a", 1, 10)
denominator = ((a*-2)+1)
numerator = (((((a*2)+-1)*2)+1)*a)
self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10
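
A plain-Python check of the identity the `div_mod_recombine_with_gcd` rewrite relies on (illustrative, not part of the diff): `(x % c) + (x // c) * c == x` for non-negative `x` and positive `c`, here with the same constants as the test.

```python
# (16*b + 2) % 18 + ((16*b + 2) // 18) * 18 simplifies back to 16*b + 2
for b in range(0, 101):
    x = 16 * b + 2
    assert x % 18 + (x // 18) * 18 == x
```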

View File

@@ -32,6 +32,11 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, -6)
self.assertEqual(uop.vmax, 8)
def test_vmin_vmax_variable_inside_special(self):
uop = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10))))
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 10)
def test_vmin_vmax_multiplication_0_inf(self):
# vmin and vmax for multiplication with a variable
x = UOp.const(dtypes.float, 0.0)

View File

@@ -88,5 +88,11 @@ class TestVerifyAST(unittest.TestCase):
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)
def test_flat_const_always_valid(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
helper_test_verify_ast(st)
if __name__ == '__main__':
unittest.main()

View File

@@ -47,7 +47,7 @@ async function runTest() {
console.log(`error from page ${message}`),
);
const res = await page.goto("http://localhost:8000/examples/index.html");
const res = await page.goto("http://localhost:8000/examples/webgpu/efficientnet/index.html");
if (res.status() !== 200) throw new Error("Failed to load page");
const textSelector = await page.waitForSelector("#result");

View File

@@ -63,9 +63,7 @@ class Kernel:
print(self.ast)
raise e
@functools.lru_cache(None)
def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]
self.vars: List[Variable] = self.ast.variables()
# NOTE: this requires a specific order with the [::-1], this is likely a bug
@@ -735,7 +733,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
st = uop.arg
# everything else inherits shape
else:
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
st = src_sts[0]
if not all_same(shapes:=[x.shape for x in src_sts]):
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
raise AssertionError(f"found implicit expand {sizes} {shapes}")

View File

@@ -124,17 +124,15 @@ powers_of_two = {2**i:i for i in range(64)}
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops and Ops.SHR in ops:
pat += [
(UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None),
(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two else None)
]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
@@ -241,8 +239,6 @@ arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
@@ -288,9 +284,8 @@ sym = symbolic_flat+PatternMatcher([
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# parentless reduce # TODO: add MUL
(acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# ** self folding **
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP
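
For reference, the power-of-two rewrites added earlier in this file (MOD→AND, MUL→SHL, IDIV→SHR) rest on these integer identities, checked here in plain Python for non-negative operands (illustrative only):

```python
for x in range(256):
    for y in range(6):
        c = 2 ** y
        assert x % c == x & (c - 1)   # MOD  -> AND
        assert x * c == x << y        # MUL  -> SHL
        assert x // c == x >> y       # IDIV -> SHR
```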

View File

@@ -43,6 +43,7 @@ Device = _Device()
# **************** Buffer + Allocators ****************
@dataclass(frozen=True, eq=True)
class BufferSpec:
# TODO: move device, size, dtype here?

View File

@@ -52,7 +52,7 @@ class PtrDType(DType):
def vec(self, sz:int) -> DType:
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
@@ -99,7 +99,7 @@ class dtypes:
@staticmethod
@functools.lru_cache(None)
def max(dtype:DType):
if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
return float("inf") if dtypes.is_float(dtype) else True
@staticmethod
def finfo(dtype:DType) -> Tuple[int, int]:
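
The new `dtypes.max` expression folds the signed and unsigned cases into one formula, `max = 2**bits - 1 + min`. A quick arithmetic check (standalone sketch; `int_min`/`int_max` are illustrative helpers, not tinygrad API):

```python
def int_min(bits, signed): return -(2 ** (bits - 1)) if signed else 0
def int_max(bits, signed): return 2 ** bits - 1 + int_min(bits, signed)

assert (int_min(8, True),  int_max(8, True))  == (-128, 127)
assert (int_min(8, False), int_max(8, False)) == (0, 255)
assert int_max(32, True) == 2**31 - 1
```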

View File

@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
if any(b in depends for b in ei.bufs):
if isinstance(ei.prg, CompiledRunner):
depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
if isinstance(ei.prg, (BufferCopy, BufferXfer)):
depends.add(cast(Buffer, ei.bufs[0]))
pruned, onetime = partition(jit_cache,
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")
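
A simplified, hypothetical standalone sketch of the dependency walk this hunk extends: any item touching a buffer in `depends` propagates its outputs into `depends` too, where a compiled kernel's outputs come from `prg.p.outs` and a BufferCopy/BufferXfer's output is `bufs[0]`.

```python
def collect_depends(items, input_buffers):
    # items: iterable of (kind, bufs, outs) with kind in {"kernel", "copy"} (illustrative shape)
    depends = set(input_buffers)
    for kind, bufs, outs in items:
        if any(b in depends for b in bufs):
            depends.update(outs if kind == "kernel" else [bufs[0]])
    return depends
```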

View File

@@ -3,7 +3,7 @@ from typing import Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.ops import MathTrait, resolve, UOp, sint, GroupOp, Ops, view_supported_devices
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
@@ -21,7 +21,6 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]
if enable_cache: lazycache[cache_key] = ret
return ret
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
class LazyBuffer(MathTrait):
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
@@ -111,16 +110,20 @@ class LazyBuffer(MathTrait):
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
return self.base.cast(dtype, bitcast).view(self.st)
cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,))
def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
@property
def const_arg(self) -> ConstType:
assert self.base.op is Ops.CONST and isinstance(self.base.arg, get_args(ConstType)), f"const_arg called on {self}"
return self.base.arg
def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, srcs=(self,), enable_cache=False)
def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
# no COPY
@@ -132,13 +135,13 @@ class LazyBuffer(MathTrait):
# const doesn't have to be copied (issues with disk tensor)
if self.is_unrealized_const():
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg).view(self.st)
# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)
# copy the base and apply the shapetracker on the new device
return self.base._copy(device)._view(self.st)
return self.base._copy(device).view(self.st)
def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)
@@ -146,7 +149,7 @@ class LazyBuffer(MathTrait):
srcs: List[LazyBuffer] = []
for s in (self,)+in_srcs:
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
srcs.append(root._view(s.base.contiguous_child[1]))
srcs.append(root.view(s.base.contiguous_child[1]))
else:
srcs.append(s)
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
@@ -181,15 +184,6 @@ class LazyBuffer(MathTrait):
def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)
# const folding
# TODO: fold this for symbolic?
if self.is_unrealized_unmasked_const() and all_int(self.shape):
if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
@@ -213,15 +207,15 @@ class LazyBuffer(MathTrait):
# *** movement ops ***
def _view(self, new_st:ShapeTracker) -> LazyBuffer:
def view(self, new_st:ShapeTracker) -> LazyBuffer:
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return self.const_with_shape(0, new_st.shape)
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
def reshape(self, arg:Tuple[sint, ...]): return self.view(self.st.reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self.view(self.st.expand(arg))
def permute(self, arg:Tuple[int, ...]): return self.view(self.st.permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self.view(self.st.stride(arg))

View File

@@ -186,12 +186,12 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
out, arg = si.outputs[0], si.ast.arg
out = si.outputs[0]
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")
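
A small sketch of the copy-lowering choice above (names taken from this diff; `pick_copy_runner` is a hypothetical helper): BufferXfer is used only when the allocator implements `_transfer` and both devices share the same backend prefix, otherwise the generic BufferCopy.

```python
def pick_copy_runner(out_dev: str, in_dev: str, has_transfer: bool) -> str:
    same_backend = out_dev.split(":")[0] == in_dev.split(":")[0]
    return "BufferXfer" if has_transfer and same_backend else "BufferCopy"

assert pick_copy_runner("NV:0", "NV:1", True) == "BufferXfer"
assert pick_copy_runner("NV:0", "CLANG", True) == "BufferCopy"
```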

View File

@@ -3,6 +3,7 @@ from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
from tinygrad.dtype import ConstType, ImageDType, dtypes
@@ -51,6 +52,7 @@ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2
def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# view is passthrough
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
return ret
@@ -64,25 +66,24 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
# hack the underlying buffer too
buf.buffer.dtype = dtype
buf.buffer.options = None
if buf.is_realized:
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.base.buf_uop)
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
# base is a VIEW of (BUFFER, (optional) op)
# TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
# ASSIGN uses the target buffer, otherwise we create a new buffer
else:
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
# keep track of scheduled ops
buf.buffer.ref(1)
ctx.lazybufs[ubuf] = buf
ctx.allbufs[ubuf] = ret
ctx.lazybufs[buf_uop] = buf
ctx.allbufs[buf_uop] = ret
if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
cache[buf] = ret
buffers[ret.buf_uop] = buf.buffer
return ret
# **** AST graph rewrite
@@ -106,23 +107,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(r:UOp, v:UOp, src:UOp) -> UOp:
swizzle_st, src_st = unwrap(v.st), unwrap(src.st)
assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
output_shape = swizzle_st.reduce(r.axis_arg)
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
return src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
if not (swizzles := [x for x in root.src if x.base is not x]): return None
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
new_shape, new_input_shape = swizzle_shapes[0]
new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
ret = root.replace(src=new_src)
assert all_same([(x.shape, prod(x.src[0].shape)) for x in swizzles]), f"swizzles must have the same size {swizzles}"
new_input_st = ShapeTracker.from_shape(swizzles[0].src[0].shape)
ret = root.replace(src=tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, new_input_st) for x in root.src))
# update the ASSIGN offset to match the new shape
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+ShapeTracker.from_shape(new_input_shape),)
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(swizzles[0].shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -178,13 +174,16 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
to_si = PatternMatcher([
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),
])
# ** fusion
lazy = PatternMatcher([
# gather the metadata for this kernel
(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
@@ -338,9 +337,29 @@ def _as_const(u:UOp, val:ConstType) -> UOp:
st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)
def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:
# remove reduce on unmasked const
if all_int(x.shape) and x.is_unrealized_unmasked_const():
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return UOp.const(reduce.dtype, ret)
return None
ops_folding = PatternMatcher([
# op with size 0 is zero
(UPatScheduled(), lambda ctx,b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda ctx,reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_reduceop),
# CONST doesn't need COPY
(UPat(Ops.COPY, src=(UPat.var("x"),)), lambda ctx,x:x if x.is_unrealized_const() else None),
])
# ** this decides which ops get realized
@@ -366,7 +385,7 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
return to_cast.view(unwrap(view.st))
def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST)
new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and x.base.src[1].op is not Ops.CONST)
return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
do_realize = PatternMatcher([
@@ -380,6 +399,8 @@ do_realize = PatternMatcher([
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# ASSIGN only needs the buffer
(UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda ctx,dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
])
# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
@@ -402,7 +423,7 @@ break_sched = PatternMatcher([
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
])
@track_rewrites(named=True)
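
The arithmetic behind `simplify_reduceop` above, shown on concrete numbers (illustrative only): reducing an unmasked constant `c` over axes whose sizes multiply to `prshape` collapses to a single constant.

```python
# e.g. a (4,5) tensor filled with 3, reduced over both axes
c, prshape = 3, 4 * 5
assert c * prshape == 60         # Ops.ADD -> c * prshape  (sum)
assert c ** prshape == 3 ** 20   # Ops.MUL -> c ** prshape (prod)
# Ops.MAX is a passthrough: the result is just c
```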

View File

@@ -1,5 +1,5 @@
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -35,16 +35,21 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}
def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
R = TypeVar('R')
def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]:
@functools.wraps(func)
def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
return wrapper
@accept_filename
def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Dict[str, Any]]:
"""
Loads a .safetensor file from disk, returning the data, metadata length, and metadata.
"""
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:8].bitcast(dtypes.int64).item()
assert isinstance(json_len, int)
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
data_start = int.from_bytes(t[0:8].data(), "little") + 8
return t, data_start, json.loads(t[8:data_start].data().tobytes())
def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
"""
Loads a .safetensor file from disk, returning the state_dict.
@@ -52,14 +57,10 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
state_dict = nn.state.safe_load("test.safetensor")
```
"""
t, json_len, metadata = safe_load_metadata(fn)
ret = {}
for k,v in metadata.items():
if k == "__metadata__": continue
dtype = safe_dtypes[v['dtype']]
sz = (v['data_offsets'][1]-v['data_offsets'][0])
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
return ret
t, data_start, metadata = safe_load_metadata(fn)
data = t[data_start:]
return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
for k, v in metadata.items() if k != "__metadata__" }
def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
"""
@@ -157,6 +158,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
else: v.replace(state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]
@accept_filename
def tar_extract(t: Tensor) -> Dict[str, Tensor]:
"""
Extracts files from a tar archive and returns them as dictionary of names (keys) and tensors (values).
@@ -170,7 +172,8 @@ def tar_extract(t: Tensor) -> Dict[str, Tensor]:
# torch support!
def torch_load(fn:str) -> Dict[str, Tensor]:
@accept_filename
def torch_load(t:Tensor) -> Dict[str, Tensor]:
"""
Loads a torch .pth file from disk.
@@ -178,8 +181,6 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
state_dict = nn.state.torch_load("test.pth")
```
"""
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
offsets: Dict[Union[str, int], int] = {}
lens: Dict[Union[str, int], int] = {}
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
@@ -220,8 +221,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return deserialized_objects.get(pid, pid)
if zipfile.is_zipfile(fn):
myzip = zipfile.ZipFile(fn, 'r')
fobj = io.BufferedReader(TensorIO(t))
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v
if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
myzip = zipfile.ZipFile(fobj, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
if n.startswith(f'{base_name}/data/'):
@@ -229,8 +233,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
elif tarfile.is_tarfile(fn):
with tarfile.open(fn, "r") as tar:
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
with tarfile.open(fileobj=fobj, mode="r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))
for i in range(TorchPickle(f).load()): # num_storages
@@ -245,14 +249,13 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
else:
with open(fn, "rb") as f:
pkl = TorchPickle(f)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
f.seek(rwd)
return TorchPickle(f).load()
pkl = TorchPickle(fobj)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
fobj.seek(rwd)
return TorchPickle(fobj).load()
def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
"""
@@ -287,6 +290,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
raise ValueError(f"GGML type '{ggml_type}' is not supported!")
@accept_filename
def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
"""
Loads a gguf file from a tensor.
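
For context on the `safe_load_metadata` change above: the .safetensors layout it reads is an 8-byte little-endian u64 header length, the JSON header, then tensor data, with each tensor's `data_offsets` relative to `data_start = 8 + header_len`. A minimal standalone sketch (`read_header` is a hypothetical helper, not the tinygrad function):

```python
import json, struct

def read_header(path):
    with open(path, "rb") as f:
        n = struct.unpack("<Q", f.read(8))[0]   # little-endian u64 header length
        meta = json.loads(f.read(n))            # per-tensor dtype/shape/data_offsets
    return 8 + n, meta                          # data_start, metadata
```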

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal, get_args
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
@@ -25,8 +25,7 @@ class SimpleMathTrait:
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def neg(self):
dtype: Optional[DType] = getattr(self, 'dtype', None)
assert dtype is not None, "MathTraits __neg__ requires a dtype"
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
@@ -162,6 +161,9 @@ class GroupOp:
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
# some BUFFER ops can be processed with only a view
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
@@ -259,6 +261,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
@property
def nbytes(self) -> int: return self.size*self.dtype.itemsize
# *** uop evaluation ***
@@ -288,6 +292,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
return ret.arg
@property
def const_arg(self) -> ConstType:
match self.base.op:
case Ops.CONST: ret = self.base.arg
case Ops.VIEW: ret = self.base.src[1].const_arg
case op: raise AssertionError(f"const_arg called on {op}")
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
return ret
@property
def axis_arg(self) -> Tuple[int, ...]:
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
@@ -323,11 +335,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
def r(self, op:Ops, axis:Tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x), None if self.st is None or self.st.contiguous else self.st)
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))
# *** from LazyBuffer ***
@@ -336,17 +348,19 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def const_with_shape(dtype:DType, val:ConstLike, shape:Tuple[sint,...]) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0)
def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.src[1].op is Ops.CONST
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views)
# *** uop movement ops ***
@property
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
# instant folding rules
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
@@ -428,9 +442,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
if self.op is Ops.IDIV:
if s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
# don't know exact bounds, but know the sign
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
@@ -445,7 +463,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
if self.op is Ops.CONST: return self.arg, self.arg
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
return dtypes.min(self.dtype), dtypes.max(self.dtype)
@@ -483,7 +501,7 @@ python_alu: Dict[Ops, Callable] = {
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}
def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
@@ -1033,7 +1051,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1
def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
symbolic_simple = PatternMatcher([
# ** self folding **
@@ -1046,6 +1064,8 @@ symbolic_simple = PatternMatcher([
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat.var("x").maximum(UPat.var("x")), lambda x: x),

View File

@@ -67,8 +67,9 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), lambda ctx,b,v,g: f"select({ctx[v]}, {ctx[b]}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), \
lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0].dtype)}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
@@ -81,7 +82,8 @@ class WGSLRenderer(CStyleLanguage):
def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if dt.itemsize < 4 else buffer_map[dt.base]}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
if not local_size: local_size = [1]

View File

@@ -1,10 +1,10 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
from tinygrad import Variable, dtypes
from tinygrad.ops import sint, Variable as VariableT
from tinygrad.ops import Variable as VariableT
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner
@@ -13,6 +13,14 @@ class HCQGraph(MultiGraphRunner):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))
# Replace input buffers with variables.
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}
for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), Variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in jit_cache:
@@ -23,11 +31,11 @@ class HCQGraph(MultiGraphRunner):
# Fill initial arguments.
self.ji_args: Dict[int, HCQArgsState] = {}
kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
kargs_alloc: Dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, start=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.ji_args[j] = ji.prg._prg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@@ -53,8 +61,14 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)
for j,ji in enumerate(jit_cache):
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
if is_exec_prg:
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))
# Get dependencies based on input and output buffers.
@@ -101,7 +115,6 @@ class HCQGraph(MultiGraphRunner):
last_j[enqueue_queue] = j
# Build hardware queues.
self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}
# Create variable timeline signals for each device.
@@ -128,8 +141,7 @@ class HCQGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)
# TODO: For now sints is only for copies, should refactor to support exec as well.
enqueue_queue.copy(self._buf_addr_as_sint(j, 0, dest._buf), self._buf_addr_as_sint(j, 1, src._buf), dest.nbytes)
enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))
# Encode finish profile timestamp (if needed).
@@ -161,12 +173,7 @@ class HCQGraph(MultiGraphRunner):
**{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}
# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
if (var:=self.input_replace_to_var.get((j,i))) is not None: hcq_var_vals[var] = input_rawbuffers[input_idx]._buf.va_addr
else: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)
# Update var_vals
for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr
for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals)
@@ -191,10 +198,6 @@ class HCQGraph(MultiGraphRunner):
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]
def _buf_addr_as_sint(self, j:int, i:int, buf:HCQBuffer) -> sint:
if (j, i) not in self.input_replace: return buf.va_addr
return self.input_replace_to_var.setdefault((j, i), Variable(f"input_{j}_{i}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])
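
A minimal sketch of the bump-allocator pattern now used for kernarg space (hypothetical standalone version; the real one is `BumpAllocator` in `tinygrad.runtime.support.hcq`): allocations are handed out by advancing a pointer rounded up to the requested alignment.

```python
class BumpAlloc:
    def __init__(self, size: int, start: int = 0):
        self.end, self.ptr = start + size, start
    def alloc(self, size: int, alignment: int = 1) -> int:
        self.ptr = (self.ptr + alignment - 1) // alignment * alignment  # round up to alignment
        assert self.ptr + size <= self.end, "bump arena exhausted"
        ret, self.ptr = self.ptr, self.ptr + size
        return ret

arena = BumpAlloc(1 << 16, start=0x1000)
a = arena.alloc(40, 16)   # 16-byte aligned slots, like the 16-aligned kernargs above
b = arena.alloc(40, 16)
assert a == 0x1000 and b == 0x1030
```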

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Tuple, List, Any, Optional
from typing import Tuple, List, Any, Optional, cast
import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, array, contextlib, sys
assert sys.platform != 'win32'
from dataclasses import dataclass
@@ -80,6 +80,8 @@ class AMDComputeQueue(HWQueue):
return self
def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.bind_args_state(args_state)
self.acquire_mem(gli=0, gl2=0)
if prg.enable_private_segment_sgpr:
@@ -262,15 +264,15 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
if options.host: return self.dev._gpu_alloc(size, host=True)
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, uncached=True)
return self.dev._gpu_alloc(size, cpu_access=options.cpu_access)
def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if hasattr(buf, '_base') else buf)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if buf._base is not None else buf)
MAP_FIXED, MAP_NORESERVE = 0x10, 0x400
@@ -289,45 +291,48 @@ class AMDDevice(HCQCompiled):
signals_pool:List[int] = []
gpus:List[pathlib.Path] = []
def _gpu_map(self, mem):
if self.gpu_id in getattr(mem, "mapped_gpu_ids", []): return
mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_id])
c_gpus = (ctypes.c_int32 * len(mem.mapped_gpu_ids))(*mem.mapped_gpu_ids)
stm = kfd.AMDKFD_IOC_MAP_MEMORY_TO_GPU(self.kfd, handle=mem.handle, device_ids_array_ptr=ctypes.addressof(c_gpus),
n_devices=len(mem.mapped_gpu_ids))
assert stm.n_success == len(mem.mapped_gpu_ids)
def _gpu_map(self, mem:HCQBuffer):
if self.gpu_id in getattr(mem.meta, "mapped_gpu_ids", []): return
mem.meta.__setattr__("mapped_gpu_ids", getattr(mem.meta, "mapped_gpu_ids", []) + [self.gpu_id])
c_gpus = (ctypes.c_int32 * len(mem.meta.mapped_gpu_ids))(*mem.meta.mapped_gpu_ids)
stm = kfd.AMDKFD_IOC_MAP_MEMORY_TO_GPU(self.kfd, handle=mem.meta.handle, device_ids_array_ptr=ctypes.addressof(c_gpus),
n_devices=len(mem.meta.mapped_gpu_ids))
assert stm.n_success == len(mem.meta.mapped_gpu_ids)
def _gpu_alloc(self, size:int, flags:int, uncached=False, public=False, map_to_gpu=True):
flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
if uncached: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_COHERENT | kfd.KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED
if public: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_PUBLIC
if flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR:
buf = addr = libc.mmap(0, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|mmap.MAP_ANONYMOUS, -1, 0)
else:
buf, addr = 0, libc.mmap(0, size, 0, mmap.MAP_PRIVATE|mmap.MAP_ANONYMOUS|MAP_NORESERVE, -1, 0)
def _gpu_alloc(self, size:int, host=False, uncached=False, cpu_access=False) -> HCQBuffer:
flags = kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
if uncached: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_COHERENT | kfd.KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED | kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT
else: flags |= (kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR if host else kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
if cpu_access or host: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_PUBLIC
if host: buf = addr = libc.mmap(0, size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | mmap.MAP_ANONYMOUS, -1, 0)
else: buf, addr = 0, libc.mmap(0, size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE, -1, 0)
assert addr != 0xffffffffffffffff
try: mem = kfd.AMDKFD_IOC_ALLOC_MEMORY_OF_GPU(self.kfd, va_addr=addr, size=size, base=addr, length=size, gpu_id=self.gpu_id,
flags=flags, mmap_offset=buf)
flags=flags, mmap_offset=buf, cpu_addr=addr)
except OSError as e:
if e.errno == errno.EINVAL and (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) and public:
if e.errno == errno.EINVAL and (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) and cpu_access:
raise MemoryError("Cannot allocate host-visible VRAM. Ensure the resizable BAR option is enabled on your system.") from e
if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory: no memory is available.") from e
raise
if not (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR):
buf = libc.mmap(mem.va_addr, mem.size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|MAP_FIXED, self.drm_fd, mem.mmap_offset)
if not host:
buf = libc.mmap(mem.va_addr, mem.size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_FIXED, self.drm_fd, mem.mmap_offset)
assert addr == buf == mem.va_addr
if map_to_gpu: self._gpu_map(mem)
return mem
def _gpu_free(self, mem):
if len(gpus:=getattr(mem, "mapped_gpu_ids", [])):
self._gpu_map(hcqbuf:=HCQBuffer(mem.va_addr, mem.size, meta=mem))
return hcqbuf
def _gpu_free(self, mem:HCQBuffer):
if len(gpus:=getattr(mem.meta, "mapped_gpu_ids", [])):
c_gpus = (ctypes.c_int32 * len(gpus))(*gpus)
stm = kfd.AMDKFD_IOC_UNMAP_MEMORY_FROM_GPU(self.kfd, handle=mem.handle, device_ids_array_ptr=ctypes.addressof(c_gpus), n_devices=len(gpus))
stm = kfd.AMDKFD_IOC_UNMAP_MEMORY_FROM_GPU(self.kfd, handle=mem.meta.handle, device_ids_array_ptr=ctypes.addressof(c_gpus), n_devices=len(gpus))
assert stm.n_success == len(gpus)
libc.munmap(mem.va_addr, mem.size)
kfd.AMDKFD_IOC_FREE_MEMORY_OF_GPU(self.kfd, handle=mem.handle)
kfd.AMDKFD_IOC_FREE_MEMORY_OF_GPU(self.kfd, handle=mem.meta.handle)
def __init__(self, device:str=""):
if AMDDevice.kfd == -1:
@@ -350,10 +355,10 @@ class AMDDevice(HCQCompiled):
kfd.AMDKFD_IOC_ACQUIRE_VM(AMDDevice.kfd, drm_fd=self.drm_fd, gpu_id=self.gpu_id)
if AMDDevice.event_page is None:
AMDDevice.signals_page = self._gpu_alloc(16 * 65536, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
AMDDevice.event_page = self._gpu_alloc(0x8000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
AMDDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
AMDDevice.event_page = self._gpu_alloc(0x8000, uncached=True)
AMDDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, AMDDevice.signals_page.size, 16)]
kfd.AMDKFD_IOC_CREATE_EVENT(AMDDevice.kfd, event_page_offset=AMDDevice.event_page.handle)
kfd.AMDKFD_IOC_CREATE_EVENT(AMDDevice.kfd, event_page_offset=AMDDevice.event_page.meta.handle)
else:
self._gpu_map(AMDDevice.signals_page)
self._gpu_map(AMDDevice.event_page)
@@ -365,7 +370,7 @@ class AMDDevice(HCQCompiled):
# <gfx103 requires alignment of 1024, >=gfx11 requires 256
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024)
self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len
self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.scratch = self._gpu_alloc(self.scratch_len)
self.has_scratch_base_registers = self.target >= 110000
engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine']
waves = wave_scratch_len // (256 if self.target >= 110000 else 1024)
@@ -396,11 +401,10 @@ class AMDDevice(HCQCompiled):
AMDSignal, AMDComputeQueue, AMDCopyQueue)
def _alloc_queue(self, queue_type, ring_size, ctx_save_restore_size=None, eop_buffer_size=None, ctl_stack_size=0) -> AMDQueueDesc:
gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
ring = self._gpu_alloc(ring_size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
cwsr_ctx = self._gpu_alloc(round_up(ctx_save_restore_size + self.debug_memory_size, mmap.PAGESIZE),
kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) if ctx_save_restore_size else None
eop_buffer = self._gpu_alloc(eop_buffer_size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) if eop_buffer_size else None
gart = self._gpu_alloc(0x1000, uncached=True)
ring = self._gpu_alloc(ring_size, uncached=True)
cwsr_ctx = self._gpu_alloc(round_up(ctx_save_restore_size + self.debug_memory_size, mmap.PAGESIZE)) if ctx_save_restore_size else None
eop_buffer = self._gpu_alloc(eop_buffer_size) if eop_buffer_size else None
queue = kfd.AMDKFD_IOC_CREATE_QUEUE(AMDDevice.kfd, ring_base_address=ring.va_addr, ring_size=ring.size, gpu_id=self.gpu_id,
queue_type=queue_type, queue_percentage=kfd.KFD_MAX_QUEUE_PERCENTAGE, queue_priority=kfd.KFD_MAX_QUEUE_PRIORITY,
eop_buffer_address=eop_buffer.va_addr if eop_buffer else 0, eop_buffer_size=eop_buffer.size if eop_buffer else 0, ctl_stack_size=ctl_stack_size,
@@ -411,7 +415,7 @@ class AMDDevice(HCQCompiled):
self.doorbells_base = queue.doorbell_offset & (~0x1fff) # doorbell is two pages
self.doorbells = libc.mmap(0, 0x2000, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, AMDDevice.kfd, self.doorbells_base)
return AMDQueueDesc(ring=to_mv(ring.va_addr, ring_size).cast("I"),
return AMDQueueDesc(ring=to_mv(cast(int, ring.va_addr), ring_size).cast("I"),
read_ptr=to_mv(queue.read_pointer_address, 8).cast("Q"), write_ptr=to_mv(queue.write_pointer_address, 8).cast("Q"),
doorbell=to_mv(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8).cast("Q"))
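The AMD hunks above move every driver allocation behind a plain HCQBuffer whose meta field carries the KFD-specific record, so the free and map paths read handles through mem.meta instead of ad-hoc attributes. A minimal sketch of that shape, assuming an invented AllocInfo record and handle value in place of the real kfd structs:

from dataclasses import dataclass, field
from typing import Any, List

@dataclass
class HCQBuffer:  # mirrors the runtime class defined later in this commit: va_addr/size plus driver metadata in meta
  va_addr: int
  size: int
  meta: Any = None

@dataclass
class AllocInfo:  # hypothetical stand-in for the KFD allocation record kept in meta
  handle: int
  mapped_gpu_ids: List[int] = field(default_factory=list)

def gpu_alloc(va_addr:int, size:int, handle:int) -> HCQBuffer:
  # the device hands back a generic HCQBuffer; everything KFD-specific lives in .meta
  return HCQBuffer(va_addr, size, meta=AllocInfo(handle=handle))

def gpu_free(mem:HCQBuffer) -> None:
  # free paths now go through mem.meta, as _gpu_free does above
  print(f"unmap from {len(mem.meta.mapped_gpu_ids)} gpus, free handle {mem.meta.handle:#x}")

gpu_free(gpu_alloc(va_addr=0x7000_0000_0000, size=0x1000, handle=0x42))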

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import os, pathlib, struct, ctypes, tempfile, functools
from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T, init_c_struct_t
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
@@ -45,10 +45,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)
def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
def to_struct(*t: int, _type: type = ctypes.c_ulong):
class Struct(ctypes.Structure): pass
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
return Struct(*t)
def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)
def wait_check(cbuf: Any):
msg(cbuf, "waitUntilCompleted")
@@ -112,9 +109,8 @@ class MetalProgram:
if lib[:4] == b"MTLB":
# binary metal library
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
error_library_creation = objc_instance()
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
error_check(error_library_creation)
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
error_check(error_lib)
else:
# metal source. rely on OS caching
try: self.library = metal_src_to_library(self.dev, lib.decode())
@@ -137,7 +133,7 @@ class MetalProgram:
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "setComputePipelineState:", self.pipeline_state)
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(encoder, "endEncoding")
msg(command_buffer, "commit")
@@ -178,9 +174,7 @@ class MetalAllocator(LRUAllocator):
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
def _as_buffer(self, src:MetalBuffer) -> memoryview:
self.dev.synchronize()
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
return memoryview(array).cast("B")[src.offset:]
return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
def _copyin(self, dest:MetalBuffer, src:memoryview): self._as_buffer(dest)[:] = src
def _copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self._as_buffer(src)
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
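The Metal hunk collapses to_struct into a one-liner built on init_c_struct_t from tinygrad.helpers. A rough, self-contained equivalent, with make_c_struct_t standing in for the real helper (whose exact field packing and caching are not reproduced here):

import ctypes

def make_c_struct_t(fields):  # stand-in for tinygrad.helpers.init_c_struct_t
  class CStruct(ctypes.Structure): _fields_ = list(fields)
  return CStruct

def to_struct(*t:int, _type:type=ctypes.c_ulong):
  # one field per value, named field0..fieldN, matching the helper above
  return make_c_struct_t(tuple((f"field{i}", _type) for i in range(len(t))))(*t)

grid = to_struct(8, 8, 1)  # e.g. an MTLSize-like triple for dispatchThreadgroups
assert (grid.field0, grid.field1, grid.field2) == (8, 8, 1)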

View File

@@ -3,7 +3,7 @@ import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys
assert sys.platform != 'win32'
from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
from tinygrad.ops import sint
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
@@ -71,8 +71,6 @@ def make_qmd_struct_type():
qmd_struct_t = make_qmd_struct_type()
assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4
def nvmethod(subc, mthd, size, typ=2): return (typ << 28) | (size << 16) | (subc << 13) | (mthd >> 2)
class NVSignal(HCQSignal):
def __init__(self, base_addr:Optional[int]=None, **kwargs):
super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8)
@@ -88,18 +86,19 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))
def nvm(self, subchannel, mthd, *args, typ=2): self.q((typ << 28) | (len(args) << 16) | (subchannel << 13) | (mthd >> 2), *args)
def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
if compute_class: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_OBJECT, 1), compute_class)
if copy_class: self.q(nvmethod(4, nv_gpu.NVC6C0_SET_OBJECT, 1), copy_class)
if local_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_WINDOW_A, 2), *data64(local_mem_window))
if shared_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, 2), *data64(shared_mem_window))
if local_mem: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, 2), *data64(local_mem))
if local_mem_tpc_bytes: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, 3), *data64(local_mem_tpc_bytes), 0xff)
if compute_class: self.nvm(1, nv_gpu.NVC6C0_SET_OBJECT, compute_class)
if copy_class: self.nvm(4, nv_gpu.NVC6C0_SET_OBJECT, copy_class)
if local_mem_window: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_WINDOW_A, *data64(local_mem_window))
if shared_mem_window: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, *data64(shared_mem_window))
if local_mem: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, *data64(local_mem))
if local_mem_tpc_bytes: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, *data64(local_mem_tpc_bytes), 0xff)
return self
def wait(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value), (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.active_qmd = None
return self
@@ -117,14 +116,9 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def _submit_to_gpfifo(self, dev:NVDevice, gpfifo:GPFifo):
if dev == self.binded_device: cmdq_addr = self.hw_page.va_addr
else:
if dev.cmdq_wptr + len(self._q) * 4 > dev.cmdq_page.size:
assert (gpfifo.ring[gpfifo.controls.GPGet] & 0xFFFFFFFFFC) >= dev.cmdq_page.va_addr + len(self._q) * 4 or \
gpfifo.controls.GPGet == gpfifo.controls.GPPut, "cmdq overrun"
dev.cmdq_wptr = 0
dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self._q)] = array.array('I', self._q)
cmdq_addr = dev.cmdq_page.va_addr+dev.cmdq_wptr
dev.cmdq_wptr += len(self._q) * 4
cmdq_addr = dev.cmdq_allocator.alloc(len(self._q) * 4)
cmdq_wptr = (cmdq_addr - dev.cmdq_page.va_addr) // 4
dev.cmdq[cmdq_wptr : cmdq_wptr + len(self._q)] = array.array('I', self._q)
gpfifo.ring[gpfifo.put_value % gpfifo.entries_count] = (cmdq_addr//4 << 2) | (len(self._q) << 42) | (1 << 41)
gpfifo.controls.GPPut = (gpfifo.put_value + 1) % gpfifo.entries_count
@@ -133,11 +127,13 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
class NVComputeQueue(NVCommandQueue):
def memory_barrier(self):
self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, (1 << 12) | (1 << 4) | (1 << 0))
self.active_qmd = None
return self
def exec(self, prg:NVProgram, args_state:NVArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.bind_args_state(args_state)
ctypes.memmove(qmd_addr:=(args_state.ptr + round_up(prg.constbufs[0][1], 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
assert qmd_addr < (1 << 40), f"large qmd addr {qmd_addr:x}"
@@ -148,8 +144,8 @@ class NVComputeQueue(NVCommandQueue):
qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr)
if self.active_qmd is None:
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_PCAS_A, 0x1), qmd_addr >> 8)
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 0x1), 9)
self.nvm(1, nv_gpu.NVC6C0_SEND_PCAS_A, qmd_addr >> 8)
self.nvm(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 9)
else:
self.active_qmd.dependent_qmd0_pointer = qmd_addr >> 8
self.active_qmd.dependent_qmd0_action = 1
@@ -168,9 +164,9 @@ class NVComputeQueue(NVCommandQueue):
self.bind_sints(value, struct=self.active_qmd, start_field=f'release{i}_payload', fmt='Q')
return self
self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q(nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0)
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.nvm(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 0x0)
self.active_qmd = None
return self
@@ -178,14 +174,14 @@ class NVComputeQueue(NVCommandQueue):
class NVCopyQueue(NVCommandQueue):
def copy(self, dest:sint, src:sint, copy_size:int):
self.q(nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest))
self.q(nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
self.nvm(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, *data64(src), *data64(dest))
self.nvm(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, copy_size)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
return self
def signal(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, 3), *data64(signal.value_addr), value)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x14)
self.nvm(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, *data64(signal.value_addr), value)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x14)
return self
def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)
@@ -267,14 +263,14 @@ class NVProgram(HCQProgram):
class NVAllocator(HCQAllocator['NVDevice']):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
if options.host: return self.dev._gpu_alloc(size, host=True, tag="user host memory")
return self.dev._gpu_alloc(size, cpu_access=options.cpu_access, tag=f"user memory ({options})")
def _free(self, opaque, options:BufferSpec):
def _free(self, opaque:HCQBuffer, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if hasattr(buf, '_base') else buf)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if buf._base is not None else buf)
@dataclass
class GPFifo:
@@ -292,8 +288,12 @@ class NVDevice(HCQCompiled[NVSignal]):
gpus_info: Union[List, ctypes.Array] = []
signals_page: Any = None
signals_pool: List[int] = []
low_uvm_vaddr: int = 0x1000000000 # 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
uvm_vaddr: int = 0x2000000000 # 0x2000000000+
# TODO: Need a proper allocator for va addresses
# 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
# VA space is 48bits.
low_uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=0x1000000000, start=0x1000000000, wrap=False)
uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=(1 << 48) - 1, start=0x2000000000, wrap=False)
host_object_enumerator: int = 0x1000
def _new_gpu_fd(self):
@@ -311,79 +311,70 @@ class NVDevice(HCQCompiled[NVSignal]):
os.close(fd_dev)
return res
def _gpu_alloc(self, size:int, contig=False, huge_page=False, va_addr=None, map_to_cpu=False, map_flags=0, tag=""):
size = round_up(size, align:=((2 << 20) if huge_page else (4 << 10)))
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, alignment=align, offset=0, limit=size-1, format=6, size=size,
attr=(((nv_gpu.NVOS32_ATTR_PAGE_SIZE_HUGE << 23) if huge_page else 0) |
((nv_gpu.NVOS32_ATTR_PHYSICALITY_CONTIGUOUS if contig else nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS) << 27)),
attr2=((nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC << 0) | (nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_YES << 2) |
((nv_gpu.NVOS32_ATTR2_PAGE_SIZE_HUGE_2MB << 20) if huge_page else 0)),
flags=(nv_gpu.NVOS32_ALLOC_FLAGS_ALIGNMENT_FORCE | nv_gpu.NVOS32_ALLOC_FLAGS_PERSISTENT_VIDMEM | nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED |
nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED))
mem_handle = rm_alloc(self.fd_ctl, nv_gpu.NV1_MEMORY_USER, self.root, self.nvdevice, alloc_params).hObjectNew
def _gpu_alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, map_flags=0, tag="") -> HCQBuffer:
# Uncached memory is "system". Use huge pages only for gpu memory.
page_size = (4 << 10) if uncached or host else ((2 << 20) if size >= (8 << 20) else (4 << 10))
size = round_up(size, page_size)
va_addr = self._alloc_gpu_vaddr(size, alignment=page_size, force_low=cpu_access)
if va_addr is None: va_addr = self._alloc_gpu_vaddr(size, alignment=align, force_low=map_to_cpu)
if map_to_cpu: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags)
return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=map_to_cpu, tag=tag)
if host:
va_addr = libc.mmap(va_addr, size, mmap.PROT_READ | mmap.PROT_WRITE, MAP_FIXED | mmap.MAP_SHARED | mmap.MAP_ANONYMOUS, -1, 0)
def _gpu_system_alloc(self, size:int, va_addr=None, map_to_cpu=False, map_flags=0, tag=""):
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, type=13,
attr=(nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS << 27) | (nv_gpu.NVOS32_ATTR_LOCATION_PCI << 25),
attr2=(nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC << 0) | (nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_NO << 2),
flags=(nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED |
nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED), format=6, size=size, alignment=(4<<10), offset=0, limit=size-1)
mem_handle = rm_alloc(self.fd_ctl, nv_gpu.NV1_MEMORY_SYSTEM, self.root, self.nvdevice, alloc_params).hObjectNew
flags = (nv_gpu.NVOS02_FLAGS_PHYSICALITY_NONCONTIGUOUS << 4) | (nv_gpu.NVOS02_FLAGS_COHERENCY_CACHED << 12) \
| (nv_gpu.NVOS02_FLAGS_MAPPING_NO_MAP << 30)
if va_addr is None: va_addr = self._alloc_gpu_vaddr(size, force_low=True)
if map_to_cpu: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags, system=True)
NVDevice.host_object_enumerator += 1
made = nv_gpu.nv_ioctl_nvos02_parameters_with_fd(params=nv_gpu.NVOS02_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, flags=flags,
hObjectNew=NVDevice.host_object_enumerator, hClass=nv_gpu.NV01_MEMORY_SYSTEM_OS_DESCRIPTOR, pMemory=va_addr, limit=size-1), fd=-1)
nv_iowr(self.fd_dev, nv_gpu.NV_ESC_RM_ALLOC_MEMORY, made)
return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=map_to_cpu, tag=tag)
if made.params.status != 0: raise RuntimeError(f"host alloc returned {get_error_str(made.params.status)}")
mem_handle = made.params.hObjectNew
else:
attr = ((nv_gpu.NVOS32_ATTR_PHYSICALITY_CONTIGUOUS if contiguous else nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS) << 27) \
| (nv_gpu.NVOS32_ATTR_PAGE_SIZE_HUGE if page_size > 0x1000 else 0) << 23 | ((nv_gpu.NVOS32_ATTR_LOCATION_PCI if uncached else 0) << 25)
def _gpu_host_alloc(self, size, tag=""):
va_base = self._alloc_gpu_vaddr(aligned_sz:=round_up(size, 4 << 10))
mapped_addr = libc.mmap(va_base, aligned_sz, mmap.PROT_READ|mmap.PROT_WRITE, MAP_FIXED|mmap.MAP_SHARED|mmap.MAP_ANONYMOUS, -1, 0)
assert mapped_addr == va_base, f"Not mmaped at correct address {va_base=} != {mapped_addr=}"
attr2 = ((nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_NO if uncached else nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_YES) << 2) \
| ((nv_gpu.NVOS32_ATTR2_PAGE_SIZE_HUGE_2MB if page_size > 0x1000 else 0) << 20) | nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC
NVDevice.host_object_enumerator += 1
flags = ((nv_gpu.NVOS02_FLAGS_PHYSICALITY_NONCONTIGUOUS << 4) | (nv_gpu.NVOS02_FLAGS_COHERENCY_CACHED << 12) |
(nv_gpu.NVOS02_FLAGS_MAPPING_NO_MAP << 30))
made = nv_gpu.nv_ioctl_nvos02_parameters_with_fd(params=nv_gpu.NVOS02_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, flags=flags,
hObjectNew=NVDevice.host_object_enumerator, hClass=nv_gpu.NV01_MEMORY_SYSTEM_OS_DESCRIPTOR, pMemory=va_base, limit=aligned_sz-1), fd=-1)
nv_iowr(self.fd_dev, nv_gpu.NV_ESC_RM_ALLOC_MEMORY, made)
fl = nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED | nv_gpu.NVOS32_ALLOC_FLAGS_ALIGNMENT_FORCE \
| nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | (nv_gpu.NVOS32_ALLOC_FLAGS_PERSISTENT_VIDMEM if not uncached else 0)
if made.params.status != 0: raise RuntimeError(f"_map_to_gpu returned {get_error_str(made.params.status)}")
return self._gpu_uvm_map(va_base, aligned_sz, made.params.hObjectNew, has_cpu_mapping=True, tag=tag)
alloc_func = nv_gpu.NV1_MEMORY_SYSTEM if uncached else nv_gpu.NV1_MEMORY_USER
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, alignment=page_size, offset=0, limit=size-1, format=6, size=size,
type=nv_gpu.NVOS32_TYPE_NOTIFIER if uncached else nv_gpu.NVOS32_TYPE_IMAGE, attr=attr, attr2=attr2, flags=fl)
mem_handle = rm_alloc(self.fd_ctl, alloc_func, self.root, self.nvdevice, alloc_params).hObjectNew
def _gpu_free(self, mem):
if mem.hMemory > NVDevice.host_object_enumerator: # not a host object, clear phys mem.
made = nv_gpu.NVOS00_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, hObjectOld=mem.hMemory)
if cpu_access: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags, system=uncached)
return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=cpu_access or host, tag=tag)
def _gpu_free(self, mem:HCQBuffer):
if mem.meta.hMemory > NVDevice.host_object_enumerator: # not a host object, clear phys mem.
made = nv_gpu.NVOS00_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, hObjectOld=mem.meta.hMemory)
nv_iowr(self.fd_ctl, nv_gpu.NV_ESC_RM_FREE, made)
if made.status != 0: raise RuntimeError(f"_gpu_free returned {get_error_str(made.status)}")
self._debug_mappings.pop((mem.va_addr, mem.size))
uvm.free(self.fd_uvm, base=mem.va_addr, length=mem.size)
if mem.has_cpu_mapping: libc.munmap(mem.va_addr, mem.size)
self._debug_mappings.pop((cast(int, mem.va_addr), mem.size))
uvm.free(self.fd_uvm, base=cast(int, mem.va_addr), length=mem.size)
if mem.meta.has_cpu_mapping: libc.munmap(cast(int, mem.va_addr), mem.size)
def _gpu_uvm_map(self, va_base, size, mem_handle, create_range=True, has_cpu_mapping=False, tag="") -> nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS:
def _gpu_uvm_map(self, va_base, size, mem_handle, create_range=True, has_cpu_mapping=False, tag="") -> HCQBuffer:
if create_range: uvm.create_external_range(self.fd_uvm, base=va_base, length=size)
attrs = (nv_gpu.struct_c__SA_UvmGpuMappingAttributes*256)(nv_gpu.struct_c__SA_UvmGpuMappingAttributes(gpuUuid=self.gpu_uuid, gpuMappingType=1))
# NOTE: va_addr is set to make rawbufs compatible with HCQBuffer protocol.
self._debug_mappings[(va_base, size)] = tag
return uvm.map_external_allocation(self.fd_uvm, base=va_base, length=size, rmCtrlFd=self.fd_ctl, hClient=self.root, hMemory=mem_handle,
gpuAttributesCount=1, perGpuAttributes=attrs, va_addr=va_base, size=size, mapped_gpu_ids=[self.gpu_uuid], has_cpu_mapping=has_cpu_mapping)
return HCQBuffer(va_base, size, meta=uvm.map_external_allocation(self.fd_uvm, base=va_base, length=size, rmCtrlFd=self.fd_ctl, hClient=self.root,
hMemory=mem_handle, gpuAttributesCount=1, perGpuAttributes=attrs, mapped_gpu_ids=[self.gpu_uuid], has_cpu_mapping=has_cpu_mapping))
def _gpu_map(self, mem):
if self.gpu_uuid in mem.mapped_gpu_ids: return
mem.mapped_gpu_ids.append(self.gpu_uuid)
self._gpu_uvm_map(mem.va_addr, mem.size, mem.hMemory, create_range=False, tag="p2p mem")
def _gpu_map(self, mem:HCQBuffer):
if self.gpu_uuid in mem.meta.mapped_gpu_ids: return
mem.meta.mapped_gpu_ids.append(self.gpu_uuid)
self._gpu_uvm_map(mem.va_addr, mem.size, mem.meta.hMemory, create_range=False, tag="p2p mem")
def _alloc_gpu_vaddr(self, size, alignment=(4 << 10), force_low=False):
if force_low:
NVDevice.low_uvm_vaddr = (res_va:=round_up(NVDevice.low_uvm_vaddr, alignment)) + size
assert NVDevice.low_uvm_vaddr < 0x2000000000, "Exceed low vm addresses"
else: NVDevice.uvm_vaddr = (res_va:=round_up(NVDevice.uvm_vaddr, alignment)) + size
return res_va
return NVDevice.low_uvm_vaddr_allocator.alloc(size, alignment) if force_low else NVDevice.uvm_vaddr_allocator.alloc(size, alignment)
def _setup_nvclasses(self):
classlist = memoryview(bytearray(100 * 4)).cast('I')
@@ -441,14 +432,14 @@ class NVDevice(HCQCompiled[NVSignal]):
except RuntimeError as e: raise RuntimeError(str(e) + f". Make sure GPUs #{self.gpu_minor} & #{dev.gpu_minor} have P2P enabled between them.") from e
if NVDevice.signals_page is None:
NVDevice.signals_page = self._gpu_system_alloc(16 * 65536, map_to_cpu=True)
NVDevice.signals_page = self._gpu_alloc(16 * 65536, cpu_access=True, uncached=True)
NVDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, NVDevice.signals_page.size, 16)]
else: self._gpu_map(NVDevice.signals_page)
channel_params = nv_gpu.NV_CHANNEL_GROUP_ALLOCATION_PARAMETERS(engineType=nv_gpu.NV2080_ENGINE_TYPE_GRAPHICS)
channel_group = rm_alloc(self.fd_ctl, nv_gpu.KEPLER_CHANNEL_GROUP_A, self.root, self.nvdevice, channel_params).hObjectNew
gpfifo_area = self._gpu_alloc(0x200000, contig=True, huge_page=True, map_to_cpu=True, map_flags=0x10d0000, tag="gpfifo")
gpfifo_area = self._gpu_alloc(0x200000, contiguous=True, cpu_access=True, map_flags=0x10d0000, tag="gpfifo")
ctxshare_params = nv_gpu.NV_CTXSHARE_ALLOCATION_PARAMETERS(hVASpace=vaspace, flags=nv_gpu.NV_CTXSHARE_ALLOCATION_FLAGS_SUBCONTEXT_ASYNC)
ctxshare = rm_alloc(self.fd_ctl, nv_gpu.FERMI_CONTEXT_SHARE_A, self.root, channel_group, ctxshare_params).hObjectNew
@@ -458,9 +449,9 @@ class NVDevice(HCQCompiled[NVSignal]):
rmctrl.gpfifo_schedule(self.fd_ctl, self.root, channel_group, bEnable=1)
self.cmdq_page: nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS = self._gpu_alloc(0x200000, map_to_cpu=True, huge_page=True, tag="cmdq")
self.cmdq: memoryview = to_mv(self.cmdq_page.va_addr, 0x200000).cast("I")
self.cmdq_wptr: int = 0 # in bytes
self.cmdq_page:HCQBuffer = self._gpu_alloc(0x200000, cpu_access=True, tag="cmdq")
self.cmdq_allocator = BumpAllocator(size=self.cmdq_page.size, start=cast(int, self.cmdq_page.va_addr), wrap=True)
self.cmdq: memoryview = to_mv(cast(int, self.cmdq_page.va_addr), 0x200000).cast("I")
self.num_gpcs, self.num_tpc_per_gpc, self.num_sm_per_tpc, self.max_warps_per_sm, self.sm_version = self._query_gpu_info('num_gpcs',
'num_tpc_per_gpc', 'num_sm_per_tpc', 'max_warps_per_sm', 'sm_version')
@@ -473,10 +464,10 @@ class NVDevice(HCQCompiled[NVSignal]):
self._setup_gpfifos()
def _new_gpu_fifo(self, gpfifo_area, ctxshare, channel_group, offset=0, entries=0x400, enable_debug=False) -> GPFifo:
notifier = self._gpu_system_alloc(48 << 20)
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(hObjectError=notifier.hMemory, hObjectBuffer=gpfifo_area.hMemory,
notifier = self._gpu_alloc(48 << 20, uncached=True)
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(hObjectError=notifier.meta.hMemory, hObjectBuffer=gpfifo_area.meta.hMemory,
gpFifoOffset=gpfifo_area.va_addr+offset, gpFifoEntries=entries, hContextShare=ctxshare,
hUserdMemory=(ctypes.c_uint32*8)(gpfifo_area.hMemory), userdOffset=(ctypes.c_uint64*8)(entries*8+offset))
hUserdMemory=(ctypes.c_uint32*8)(gpfifo_area.meta.hMemory), userdOffset=(ctypes.c_uint64*8)(entries*8+offset))
gpfifo = rm_alloc(self.fd_ctl, nv_gpu.AMPERE_CHANNEL_GPFIFO_A, self.root, channel_group, params).hObjectNew
comp = rm_alloc(self.fd_ctl, self.compute_class, self.root, gpfifo, None).hObjectNew
rm_alloc(self.fd_ctl, nv_gpu.AMPERE_DMA_COPY_B, self.root, gpfifo, None)
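The nvm helper in the NV queue above folds the old nvmethod header into the push itself: the first dword packs the type, the argument count, the subchannel and the method offset. A small sketch of just that packing, using the bit layout visible in nvmethod; the decode helper and the example method number are illustrative only:

def nv_header(subchannel:int, mthd:int, nargs:int, typ:int=2) -> int:
  # same packing as nvmethod/nvm: typ[31:28] | count[27:16] | subch[15:13] | (mthd>>2)[12:0]
  return (typ << 28) | (nargs << 16) | (subchannel << 13) | (mthd >> 2)

def nv_decode(word:int) -> dict:
  # illustrative inverse, only to show where each field lives
  return {"typ": word >> 28, "count": (word >> 16) & 0xfff, "subch": (word >> 13) & 0x7, "mthd": (word & 0x1fff) << 2}

hdr = nv_header(subchannel=4, mthd=0x400, nargs=4)  # e.g. a 4-dword method on the copy subchannel
assert nv_decode(hdr) == {"typ": 2, "count": 4, "subch": 4, "mthd": 0x400}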

View File

@@ -4,11 +4,11 @@ assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Tuple, List, Any, cast, Optional
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
from tinygrad.runtime.autogen import kgsl, adreno, libc
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, fromimport
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
@@ -86,7 +86,7 @@ class QCOMComputeQueue(HWQueue):
return self
def _build_gpu_command(self, dev:QCOMDevice, hw_addr=None):
to_mv((hw_page_addr:=hw_addr or dev._alloc_cmd_buf(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
to_mv((hw_page_addr:=hw_addr or dev.cmd_buf_allocator.alloc(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
obj = kgsl.struct_kgsl_command_object(gpuaddr=hw_page_addr, size=len(self._q) * 4, flags=kgsl.KGSL_CMDLIST_IB)
submit_req = kgsl.struct_kgsl_gpu_command(cmdlist=ctypes.addressof(obj), numcmds=1, context_id=dev.ctx,
cmdsize=ctypes.sizeof(kgsl.struct_kgsl_command_object))
@@ -105,6 +105,8 @@ class QCOMComputeQueue(HWQueue):
dev.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(dev.fd, __payload=submit_req).timestamp
def exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
self.bind_args_state(args_state)
def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x
global_size_mp = [cast_int(g*l) for g,l in zip(global_size, local_size)]
@@ -147,7 +149,7 @@ class QCOMComputeQueue(HWQueue):
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
*data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev._border_color_base()))
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev.border_color_buf.va_addr))
if args_state.prg.tex_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
@@ -179,17 +181,13 @@ class QCOMArgsState(HCQArgsState):
for cnst_val, cnst_off, cnst_sz in prg.consts_info: to_mv(self.ptr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
for i, b in enumerate(cast(List[QCOMBuffer], bufs)):
if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc)
elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo)
else: self.update_buffer(i, b)
for i, v in enumerate(vals): self.update_var(i, v)
for i, b in enumerate(bufs):
if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo
to_mv(self.ptr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
self.bind_sints_to_ptr(b.va_addr, ptr=self.ptr + self.buf_info[i].offset + (0 if self.buf_info[i].type is BUFTYPE_BUF else 16), fmt='Q')
def update_buffer(self, index:int, buf:HCQBuffer):
if self.buf_info[index].type is not BUFTYPE_BUF: self.args_view[self.buf_info[index].offset//8 + 2] = buf.va_addr
else: self.args_view[self.buf_info[index].offset//8] = buf.va_addr
def update_var(self, index:int, val:int): self.args_view[self.args_info[index].offset//8] = val
for i, v in enumerate(vals): self.bind_sints_to_ptr(v, ptr=self.ptr + self.args_info[i].offset, fmt='I')
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -198,7 +196,7 @@ class QCOMProgram(HCQProgram):
self._parse_lib()
self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferSpec(cpu_access=True, nolru=True))
to_mv(self.lib_gpu.va_addr, self.image_size)[:] = self.image
to_mv(cast(int, self.lib_gpu.va_addr), self.image_size)[:] = self.image
self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
self.pvtmem_size_total: int = self.pvtmem_size_per_item * 128 * 2
@@ -269,15 +267,13 @@ class QCOMProgram(HCQProgram):
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))
class QCOMBuffer(HCQBuffer):
def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
self.va_addr, self.size, self.info, self.mapped = va_addr, size, info, mapped
class QCOMTextureInfo:
def __init__(self, pitch:int, real_stride:int, desc:List[int], ibo:List[int]):
self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo
# Texture specific definitions
self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride
class QCOMAllocator(HCQAllocator):
class QCOMAllocator(HCQAllocatorBase):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
# Recalculate real size for texture
if options.image is not None:
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
@@ -286,22 +282,18 @@ class QCOMAllocator(HCQAllocator):
granularity = 128 if options.image.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
size = pitch * imgh
if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
else: texture = self.dev._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE)
texture.pitch, texture.real_stride = pitch, real_stride
buf = HCQBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
if options.image is not None:
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
*data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
return texture
return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
return buf
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
while src_off < src_size:
@@ -309,17 +301,18 @@ class QCOMAllocator(HCQAllocator):
src_off, dest_off = src_off+src_stride, dest_off+dest_stride
def _copyin(self, dest:HCQBuffer, src:memoryview):
if (qd:=cast(QCOMBuffer, dest)).pitch is not None: self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)
stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
self._do_copy(mv_address(src), dest.va_addr, src.nbytes, stride, stride, pitch)
def _copyout(self, dest:memoryview, src:HCQBuffer):
self.dev.synchronize()
if (qs:=cast(QCOMBuffer, src)).pitch is not None: self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
self._do_copy(src.va_addr, mv_address(dest), src.size, stride, pitch, stride)
def _as_buffer(self, src:HCQBuffer) -> memoryview:
self.dev.synchronize()
return to_mv(src.va_addr, src.size)
return to_mv(cast(int, src.va_addr), src.size)
def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
@@ -333,32 +326,35 @@ class QCOMDevice(HCQCompiled):
def __init__(self, device:str=""):
self.fd = os.open('/dev/kgsl-3d0', os.O_RDWR)
QCOMDevice.dummy_addr = self._gpu_alloc(0x1000).va_addr
QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr)
QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
QCOMDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, self.signals_page.size, 16)]
info, self.ctx, self.cmd_buf, self.cmd_buf_ptr, self.last_cmd = self._info(), self._ctx_create(), self._gpu_alloc(16 << 20), 0,0
flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \
| kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)
self.ctx = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=flags).drawctxt_id
self.cmd_buf = self._gpu_alloc(16 << 20)
self.cmd_buf_allocator = BumpAllocator(size=self.cmd_buf.size, start=cast(int, self.cmd_buf.va_addr), wrap=True)
self.border_color_buf = self._gpu_alloc(0x1000, fill_zeroes=True)
self.last_cmd:int = 0
# Set max power
struct.pack_into('IIQQ', pwr:=memoryview(bytearray(0x18)), 0, 1, self.ctx, mv_address(_:=memoryview(array.array('I', [1]))), 4)
kgsl.IOCTL_KGSL_SETPROPERTY(self.fd, type=kgsl.KGSL_PROP_PWR_CONSTRAINT, value=mv_address(pwr), sizebytes=pwr.nbytes)
# Load info about qcom device
info = kgsl.struct_kgsl_devinfo()
kgsl.IOCTL_KGSL_DEVICE_GETPROPERTY(self.fd, type=kgsl.KGSL_PROP_DEVICE_INFO, value=ctypes.addressof(info), sizebytes=ctypes.sizeof(info))
QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF)
if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}")
super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self),
QCOMSignal, QCOMComputeQueue, None)
def _ctx_create(self):
cr = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=(kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT |
kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC | kgsl.KGSL_CONTEXT_PRIORITY(8) |
kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)))
# Set power to maximum.
struct.pack_into('IIQQ', pwr:=memoryview(bytearray(0x18)), 0, 1, cr.drawctxt_id, mv_address(_:=memoryview(array.array('I', [1]))), 4)
kgsl.IOCTL_KGSL_SETPROPERTY(self.fd, type=kgsl.KGSL_PROP_PWR_CONSTRAINT, value=mv_address(pwr), sizebytes=pwr.nbytes)
return cr.drawctxt_id
def _info(self):
info = kgsl.struct_kgsl_devinfo()
kgsl.IOCTL_KGSL_DEVICE_GETPROPERTY(self.fd, type=kgsl.KGSL_PROP_DEVICE_INFO, value=ctypes.addressof(info), sizebytes=ctypes.sizeof(info))
return info
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False):
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
if uncached: flags |= kgsl.KGSL_CACHEMODE(kgsl.KGSL_CACHEMODE_UNCACHED)
@@ -366,19 +362,11 @@ class QCOMDevice(HCQCompiled):
va_addr = libc.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, self.fd, alloc.id * 0x1000)
if fill_zeroes: ctypes.memset(va_addr, 0, size)
return QCOMBuffer(va_addr=va_addr, size=size, info=alloc)
return HCQBuffer(va_addr=va_addr, size=size, meta=alloc)
def _gpu_free(self, mem):
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.info.id)
libc.munmap(mem.va_addr, mem.info.mmapsize)
def _alloc_cmd_buf(self, sz: int):
self.cmd_buf_ptr = (cur_ptr:=self.cmd_buf_ptr if self.cmd_buf_ptr + sz < self.cmd_buf.size else 0) + sz
return self.cmd_buf.va_addr + cur_ptr
def _border_color_base(self):
if not hasattr(self, '_border_color_gpu'): self._border_color_gpu = self._gpu_alloc(0x1000, fill_zeroes=True)
return self._border_color_gpu.va_addr
def _gpu_free(self, mem:HCQBuffer):
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta.id)
libc.munmap(mem.va_addr, mem.meta.mmapsize)
def _ensure_stack_size(self, sz):
if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)
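The QCOM rewrite drops _alloc_cmd_buf in favour of the shared BumpAllocator over the 16 MiB cmd_buf mapping. A sketch of the resulting write path, with a plain bytearray standing in for the GPU-visible buffer; base_addr and the packet words are made up, not real adreno values:

import array

cmd_buf = bytearray(16 << 20)           # stands in for the cmd_buf GPU mapping
cmd_buf_ptr, base_addr = 0, 0x10000000  # base_addr plays the role of cmd_buf.va_addr

def write_cmds(words) -> int:
  # bump-allocate space in the ring (wrapping to 0 like BumpAllocator with wrap=True),
  # copy the packet in, and return the GPU address handed to IOCTL_KGSL_GPU_COMMAND
  global cmd_buf_ptr
  size = len(words) * 4
  if cmd_buf_ptr + size > len(cmd_buf): cmd_buf_ptr = 0
  off, cmd_buf_ptr = cmd_buf_ptr, cmd_buf_ptr + size
  cmd_buf[off:off+size] = array.array('I', words).tobytes()
  return base_addr + off

gpuaddr = write_cmds([0x70B00003, 0x0, 0x0, 0x0])  # placeholder packet, not a real adreno opcode
assert gpuaddr == base_addr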

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any
from typing import List, Optional, Dict, Tuple, cast, Type, Union, TypeVar, Generic, Any
import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv, round_up
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
from tinygrad.ops import sym_infer, sint, Variable
@@ -14,6 +14,15 @@ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
QueueType = TypeVar('QueueType', bound='HWQueue')
class BumpAllocator:
def __init__(self, size:int, start:int=0, wrap:bool=True): self.size, self.ptr, self.start_off, self.wrap = size, 0, start, wrap
def alloc(self, size:int, alignment:int=1) -> int:
if round_up(self.ptr, alignment) + size > self.size:
if not self.wrap: raise RuntimeError("Out of memory")
self.ptr = 0
self.ptr = (res:=round_up(self.ptr, alignment)) + size
return res + self.start_off
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
@@ -121,6 +130,9 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
Implementing this method is optional but recommended for performance gains.
"""
def bind_args_state(self, args_state:ArgsStateType):
for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)
def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None):
self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)
@@ -224,24 +236,20 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Optional[Type[HWQueue
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))
class HCQArgsState(Generic[ProgramType]):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[sint, ...]=()):
self.ptr, self.prg = ptr, prg
self.bind_data:List[Tuple[Tuple[sint, ...], int, str]] = []
def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
class CLikeArgsState(HCQArgsState[ProgramType]):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), prefix:Optional[List[int]]=None):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[sint, ...]=(), prefix:Optional[List[int]]=None):
super().__init__(ptr, prg, bufs, vals=vals)
if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)
self.bufs = to_mv(self.ptr + len(prefix or []) * 4, len(bufs) * 8).cast('Q')
self.vals = to_mv(self.ptr + len(prefix or []) * 4 + len(bufs) * 8, len(vals) * 4).cast('I')
self.bufs[:] = array.array('Q', [b.va_addr for b in bufs])
self.vals[:] = array.array('I', vals)
def update_buffer(self, index:int, buf:HCQBuffer): self.bufs[index] = buf.va_addr
def update_var(self, index:int, val:int): self.vals[index] = val
self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
class HCQProgram(Generic[DeviceType]):
def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
@@ -257,7 +265,7 @@ class HCQProgram(Generic[DeviceType]):
Returns:
Arguments state with the given buffers and values set for the program.
"""
return self.args_state_t(kernargs_ptr or self.dev._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
return self.args_state_t(kernargs_ptr or self.dev.kernargs_allocator.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)
def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
@@ -333,7 +341,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
@@ -349,7 +357,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
self.kernargs_ptr:int = self.kernargs_page.va_addr
self.kernargs_allocator:BumpAllocator = BumpAllocator(self.kernargs_page.size, start=cast(int, self.kernargs_page.va_addr), wrap=True)
self.devices.append(self)
def synchronize(self):
@@ -363,14 +371,6 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
self.sig_prof_records = []
def _alloc_kernargs(self, alloc_size:int) -> int:
"""
Allocates space for arguments passed to the kernel.
"""
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
return res
def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
@@ -445,12 +445,13 @@ class HCQCompiled(Compiled, Generic[SignalType]):
def _wrap_timeline_signal(self):
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
self.timeline_signal.value = 0
cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)
# Protocol for hcq compatible allocators for allocated buffers to contain VA address and its size.
class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
class HCQBuffer:
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:Optional[HCQBuffer]=None):
self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base
class HCQAllocator(LRUAllocator, Generic[DeviceType]):
class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
"""
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
@@ -463,8 +464,12 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]):
self.b_timeline, self.b_next = [0] * len(self.b), 0
super().__init__()
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def map(self, buf:HCQBuffer): pass
def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
def _copyin(self, dest:HCQBuffer, src:memoryview):
assert self.dev.hw_copy_queue_t is not None
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
@@ -525,9 +530,3 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]):
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1
def map(self, buf:HCQBuffer): pass
def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
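The BumpAllocator added at the top of this file replaces the hand-rolled kernargs_ptr and cmdq_wptr bookkeeping: callers ask for an aligned offset from a fixed start address, and with wrap=True the pointer simply restarts at zero once the region is exhausted. A usage sketch, reproducing the class as added (round_up is assumed to behave like tinygrad.helpers.round_up):

def round_up(x:int, a:int) -> int: return ((x + a - 1) // a) * a

class BumpAllocator:
  def __init__(self, size:int, start:int=0, wrap:bool=True): self.size, self.ptr, self.start_off, self.wrap = size, 0, start, wrap
  def alloc(self, size:int, alignment:int=1) -> int:
    if round_up(self.ptr, alignment) + size > self.size:
      if not self.wrap: raise RuntimeError("Out of memory")
      self.ptr = 0
    self.ptr = (res:=round_up(self.ptr, alignment)) + size
    return res + self.start_off

# kernargs-style use: carve aligned chunks out of a page that starts at a GPU VA
ka = BumpAllocator(size=0x1000, start=0x200000000, wrap=True)
assert ka.alloc(0x180) == 0x200000000
assert ka.alloc(0x40, alignment=0x100) == 0x200000200  # bumped up to the next 0x100 boundary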

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from typing import Tuple, List, Optional, Dict, Set, cast, Sequence
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
@@ -95,10 +95,11 @@ class View:
for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
def __lt__(self, o:View): return self.t < o.t
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
"""(idx, valid)"""
if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
@@ -267,11 +268,9 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, new_shape: Tuple[sint, ...]) -> View:
if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
if 0 in self.shape:
assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
return View.create(new_shape)
# TODO: this resolve might be wrong
assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
# NOTE: does not check multiple of symbolic shape
assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
if 0 in self.shape: return View.create(new_shape)
# NOTE: can the mask ever be (0,0)?
# TODO: this resolve may not be needed, but it's hard because vars need to be sorted
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
@@ -299,16 +298,13 @@ class View:
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
if self.shape == new_shape: return self
# TODO: this resolve shouldn't be needed
assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
if 0 in self.shape:
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
return View.create(new_shape)
assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
# check for the same size
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
if 0 in self.shape: return View.create(new_shape)
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None
# after the asserts, it's okay to check contiguous
@@ -330,14 +326,12 @@ class View:
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
strides.append(new_stride)
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
if resolve(acc != merged_dim): break
else:
strides += [0,] * (len(new_shape) - len(strides))
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
if new_mask is not None:
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
if resolve(acc != merged_dim): return None
if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
new_strides = (0,) * (len(new_shape) - len(strides)) + tuple(strides[::-1])
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
return None

View File

@@ -89,11 +89,12 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret
def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
def _align_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
# unsqueeze left to make every shape same length
max_dim = max(len(shape) for shape in shapes)
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
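# Worked example of the two helpers above (shapes chosen only for illustration):
#   _align_left((3, 1, 5), (4, 1))      -> ((3, 1, 5), (1, 4, 1))  # left-pad with 1s to equal rank
#   _broadcast_shape((3, 1, 5), (4, 1)) -> (3, 4, 5)               # per-dim max, any 0 collapses the dim
#   _broadcast_shape((3, 0, 5), (4, 1)) -> (3, 0, 5)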
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:Tuple[int, ...]):
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
@@ -281,7 +282,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))
def item(self) -> ConstType:
"""
@@ -366,14 +367,13 @@ class Tensor(SimpleMathTrait):
assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
if axis is not None:
if axis < 0: axis += len(self.shape)
axis = self._resolve_dim(axis)
if splits is None:
if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
sz = ceildiv(total, len(devices))
splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
boundaries = tuple(itertools.accumulate(splits))
bounds = tuple(zip((0,) + boundaries, boundaries))
bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
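# The bounds computation above in one pass, e.g. splits=(2, 2, 1) over a length-5 axis:
#   itertools.accumulate((2, 2, 1), initial=0) -> 0, 2, 4, 5
#   itertools.pairwise(...)                    -> (0, 2), (2, 4), (4, 5)
# the same (start, end) pairs the old boundaries/zip construction produced.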
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
@@ -947,7 +947,8 @@ class Tensor(SimpleMathTrait):
print(t.expand(4, -1).numpy())
```
"""
return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
return self._broadcast_to(new_shape)
def permute(self, order, *args) -> Tensor:
"""
@@ -1117,7 +1118,7 @@ class Tensor(SimpleMathTrait):
case list() | tuple() | Tensor():
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
index = (index < 0).where(size, 0) + index # treat negative index values
index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
case int() | UOp(): # sint
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
@@ -1162,8 +1163,7 @@ class Tensor(SimpleMathTrait):
for dim, tensor in zip(dims, tensors):
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
a = Tensor.arange(x.shape[dim], device=self.device, requires_grad=False).reshape((x.shape[dim],) + (1,)*(x.ndim - dim - 1))
masks.append(i == a)
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))
# reduce masks to 1 mask
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
@@ -1225,7 +1225,7 @@ class Tensor(SimpleMathTrait):
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
index = index.to(self.device)
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
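For reference, a NumPy sketch (an illustration, not the tinygrad implementation) of the one-hot gather trick: the index builds a `(..., k)` one-hot mask, the mask selects entries of `x` along the gathered dimension, and the trailing sum collapses the one-hot axis.

```python
import numpy as np

x = np.array([[10., 20., 30.],
              [40., 50., 60.]])
index = np.array([[2, 0, 1],
                  [1, 1, 2]])
one_hot = index[..., None] == np.arange(x.shape[-1])   # (2, 3, 3) boolean mask
gathered = (x[:, None, :] * one_hot).sum(-1)           # select then reduce the one-hot axis
print(gathered)                                        # [[30. 10. 20.] [50. 50. 60.]]
print(np.take_along_axis(x, index, axis=-1))           # same result via the reference op
```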
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""
@@ -1289,7 +1289,7 @@ class Tensor(SimpleMathTrait):
```
"""
repeats = argfix(repeats, *args)
base_shape = _pad_left(self.shape, repeats)[0]
base_shape = _align_left(self.shape, repeats)[0]
unsqueezed_shape = flatten([[1, s] for s in base_shape])
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
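A NumPy sketch of the repeat strategy (names local, for illustration only): left-align the shape with the repeat counts, insert a length-1 axis before every dimension, expand those axes to the repeat counts, then merge back to `r*s` per dimension.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
repeats = (2, 2)                                  # already aligned with x.shape here
base_shape = x.shape
unsqueezed = x.reshape([d for s in base_shape for d in (1, s)])                        # (1, 2, 1, 3)
expanded = np.broadcast_to(unsqueezed, [d for r, s in zip(repeats, base_shape) for d in (r, s)])
tiled = expanded.reshape([r*s for r, s in zip(repeats, base_shape)])                   # (4, 6)
print(np.array_equal(tiled, np.tile(x, repeats)))  # True: matches the reference tiling
```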
@@ -1885,7 +1885,7 @@ class Tensor(SimpleMathTrait):
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
```
"""
return (-self).argmax(axis=axis, keepdim=keepdim)
return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
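Why integer argmin goes through `~self` rather than `-self`: bitwise NOT reverses two's-complement order exactly, while negation overflows at the minimum representable value (and wraps for unsigned dtypes). A small NumPy check of that ordering property:

```python
import numpy as np

a = np.array([3, -7, 2**31 - 1, -2**31], dtype=np.int32)
# ~a is largest exactly where a is smallest; -a would wrap at -2**31 in int32
print(int(np.argmax(~a)), int(np.argmin(a)))   # 3 3
```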
def rearrange(self, formula: str, **sizes) -> Tensor:
"""
@@ -2002,11 +2002,27 @@ class Tensor(SimpleMathTrait):
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])
def _ceil_mode_padding2d(self,k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int],
p_:Union[Tuple[int, ...], int]) -> Sequence[int]:
(d_,s_,p_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_,p_)), self.shape[-len(k_):]
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
o_ = [ceildiv(i+2*p - (d*(k-1)+1), s) + 1 for i,d,k,s,p in zip(i_,d_,k_,s_,p_)]
pads = list(self._padding2d(p_, len(k_)))
# we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
# `s*(o-1) + (d*(k-1)+1) - (i+2*p)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
# we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
# `smax(s*(o-1) - (p+i-1), 0)` -> last_sliding_window_start - (left_pad + input_size - zero_offset)
for dim,(o,i,s,p,k,d) in enumerate(zip(o_,i_,s_,p_,k_,d_)): pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+2*p) - smax(s*(o-1) - (p+i-1), 0)
return pads
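A plain-Python worked check of that padding math for one dimension (here `smax` is replaced by the ordinary `max`, since the shapes are concrete):

```python
import math

# o = ceil((i + 2p - (d*(k-1)+1)) / s) + 1 is the ceil-mode output length; the extra
# end padding makes the floor-based _pool produce the same length, minus any window
# that would start entirely inside the end padding.
i, k, s, d, p = 5, 2, 2, 1, 0                     # input 5, kernel 2, stride 2, no dilation/padding
o = math.ceil((i + 2*p - (d*(k - 1) + 1)) / s) + 1                                  # 3 (floor mode would give 2)
extra = s*(o - 1) + (d*(k - 1) + 1) - (i + 2*p) - max(s*(o - 1) - (p + i - 1), 0)   # 1
print(o, extra)   # pad one extra element on the right so floor pooling also yields 3 windows
```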
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
"""
Applies average pooling over a tensor.
When `ceil_mode` is set to True, output shape will be determined using ceil division.
When `count_include_pad` is set to False, zero padding will not be included in the averaging calculation.
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
See: https://paperswithcode.com/method/average-pooling
@@ -2016,17 +2032,30 @@ class Tensor(SimpleMathTrait):
print(t.avg_pool2d().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(ceil_mode=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(padding=1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
```
"""
padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
reg_pads, ceil_pads = self._padding2d(padding,len(k_)), self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding)
def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
if not count_include_pad:
pads = ceil_pads if ceil_mode else reg_pads
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
if not ceil_mode: return pool(self, reg_pads).mean(axis)
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
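To make the last branch concrete, here is a NumPy sketch (illustration only) of ceil-mode averaging on a length-5 row with kernel 2 and stride 2: the overhanging last window is summed over the ceil-mode padding, but its divisor only counts positions inside the regularly padded input.

```python
import numpy as np

row = np.array([1., 2., 3., 4., 5.])
k, s = 2, 2
# floor mode: windows [1,2], [3,4]              -> means 1.5, 3.5
# ceil  mode: windows [1,2], [3,4], [5 | pad]   -> last window divides by 1, not k
floor_out = [row[i:i+k].mean() for i in range(0, 4, s)]
ceil_out = [row[0:2].mean(), row[2:4].mean(), row[4:5].sum() / 1]
print(floor_out, ceil_out)   # [1.5, 3.5] [1.5, 3.5, 5.0]
```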
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
"""
Applies max pooling over a tensor.
When `ceil_mode` is set to True, output shape will be determined using ceil division.
NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.
See: https://paperswithcode.com/method/max-pooling
@@ -2036,11 +2065,15 @@ class Tensor(SimpleMathTrait):
print(t.max_pool2d().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.max_pool2d(ceil_mode=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.max_pool2d(padding=1).numpy())
```
"""
padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
k_ = make_tuple(kernel_size, 2)
pads = self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding) if ceil_mode else self._padding2d(padding, len(k_))
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
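A small sketch of why ceil-mode max pooling pads with the dtype minimum: the overhanging last window then simply reports the real edge value instead of a spurious zero.

```python
# Ceil-mode max pooling on a length-5 row, kernel 2, stride 2, illustrated in plain Python.
row = [1., 9., 3., 4., 5.]
k, s, pad_value = 2, 2, float("-inf")            # stand-in for dtypes.min(self.dtype)
padded = row + [pad_value]                       # one ceil-mode pad element on the right
out = [max(padded[i:i+k]) for i in range(0, len(padded) - k + 1, s)]
print(out)   # [9.0, 4.0, 5.0]
```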
def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -2336,10 +2369,14 @@ class Tensor(SimpleMathTrait):
index, dim = index.to(self.device), self._resolve_dim(dim)
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.ndim must all be equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or se >= ind) and sr >= ind for d,(se,ind,sr) in enumerate(zip(self.shape, index.shape, src.shape))), \
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
mask = (index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)).transpose(-1, dim)
src = src.unsqueeze(-1).expand((None,)*src.ndim + (self.shape[dim],)).transpose(-1, dim).shrink(tuple((0,s) for s in mask.shape))
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
@@ -2928,15 +2965,15 @@ class Tensor(SimpleMathTrait):
return self / (1 + self.abs())
# ***** broadcasted elementwise ops *****
def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
if self.shape == shape: return self
if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
padded, _ = _pad_left(self.shape, shape)
# for each dimension, check either from_ is 1, or it does not change
if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
return F.Expand.apply(self.reshape(padded), shape=shape)
def _broadcast_to(self, new_shape:Tuple[sint, ...]) -> Tensor:
if self.shape == new_shape: return self
if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
# first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
shape, _ = _align_left(self.shape, new_shape)
# for each dimension, check either dim is 1, or it does not change
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return F.Expand.apply(self.reshape(shape), shape=new_shape)
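The check mirrors the Array API broadcasting rules: left-pad the smaller shape with 1s, then every dimension must either already match the target or be 1. A plain-Python sketch with local stand-ins for `_align_left`:

```python
def align_left(shape, target):
    # pad the shorter shape with leading 1s so both have the same rank
    return (1,) * (len(target) - len(shape)) + tuple(shape)

def can_broadcast(shape, target):
    return len(shape) <= len(target) and \
        all(s == t or s == 1 for s, t in zip(align_left(shape, target), target))

print(can_broadcast((3, 1), (2, 3, 4)))   # True:  (3, 1) -> (1, 3, 1) -> (2, 3, 4)
print(can_broadcast((3, 2), (2, 3, 4)))   # False: trailing 2 cannot expand to 4
```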
def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
x: Tensor = self
@@ -2955,11 +2992,10 @@ class Tensor(SimpleMathTrait):
if reverse: x, y = y, x
# broadcast
out_shape = _broadcast_shape(x.shape, y.shape)
return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)
def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
@@ -3116,7 +3152,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1
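The simplification works because `x ^ -1` flips every bit of a two's-complement integer regardless of width, which is exactly what the old per-dtype mask `(1 << 8*itemsize) - 1` computed. A quick NumPy check:

```python
import numpy as np

a = np.array([0, 1, -7, 127, -128], dtype=np.int8)
print(np.array_equal(a ^ np.int8(-1), ~a))     # True: xor with -1 flips every bit
b = np.array([0, 1, 200, 255], dtype=np.uint8)
print(np.array_equal(b ^ np.uint8(255), ~b))   # True: 0xFF is -1 reinterpreted as uint8
```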
def lshift(self, x:int):
"""
@@ -3160,8 +3196,9 @@ class Tensor(SimpleMathTrait):
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
if x == 0: return 1 + self * 0
# rewrite pow 0.5 to sqrt
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
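These identities reduce constant exponents to reciprocals, squarings, and a square root; the added `.cast(self.dtype)` calls in this hunk just keep integer tensors integer after the float-producing paths. The identities themselves check out in plain Python:

```python
import math

x = 3.0
assert math.isclose(x**5, (x**(5 // 2))**2 * x)     # odd n: square the half-power, multiply once more
assert math.isclose(x**6, (x**(6 // 2))**2)         # even n: just square the half-power
assert math.isclose(x**-2, (1.0 / x)**2)            # negative exponent via reciprocal
assert math.isclose(x**2.5, x**2 * math.sqrt(x))    # half-integer exponent via sqrt
print("pow identities hold for x =", x)
```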
@@ -3178,7 +3215,8 @@ class Tensor(SimpleMathTrait):
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""
@@ -3354,6 +3392,11 @@ class Tensor(SimpleMathTrait):
if not Tensor.training or p == 0: return self
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
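A NumPy sketch of what the helper computes (illustration only): compare against an `arange` whose trailing length-1 dims place the class axis `dim` positions from the end, so broadcasting produces the one-hot mask.

```python
import numpy as np

def one_hot_along_dim(x, num_classes, dim=-1):
    offset = x.ndim - (dim if dim >= 0 else dim + x.ndim) - 1
    return x == np.arange(num_classes).reshape((num_classes,) + (1,) * offset)

idx = np.array([[0, 2], [1, 1]])
print(one_hot_along_dim(idx[..., None], num_classes=3).astype(int))
# last axis is one-hot: [[[1 0 0] [0 0 1]] [[0 1 0] [0 1 0]]]
```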
def one_hot(self, num_classes:int=-1) -> Tensor:
"""
Converts `self` to a one-hot tensor.
@@ -3366,7 +3409,7 @@ class Tensor(SimpleMathTrait):
```
"""
if num_classes == -1: num_classes = (self.max()+1).item()
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
@@ -3442,8 +3485,8 @@ class Tensor(SimpleMathTrait):
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
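A NumPy sketch of the masked target construction (label smoothing left at 0 and the sign folded in): one-hot rows for `Y` are zeroed wherever `Y` equals `ignore_index`, so ignored positions contribute nothing and the mean divides by the number of kept samples rather than `Y.numel()`.

```python
import numpy as np

log_probs = np.log(np.full((3, 4), 0.25))        # uniform predictions over 4 classes
Y, ignore_index = np.array([2, 0, -1]), -1
loss_mask = Y != ignore_index
y = (Y.reshape(-1, 1) == np.arange(4)) * loss_mask.reshape(-1, 1)   # masked one-hot targets
unreduced = -(log_probs * y).sum(-1)             # [1.386, 1.386, 0.0]
print(unreduced.sum() / loss_mask.sum())         # mean over the 2 kept samples, not all 3
```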
@@ -3590,8 +3633,15 @@ class Tensor(SimpleMathTrait):
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
# NOTE: values within the int32 range but outside the unsigned dtype range will wrap around
return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
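A NumPy sketch of the wrap-around behavior the two-step float -> int32 -> uint8 cast produces (an illustration of the semantics, not the backend code): values that fit in int32 but not in uint8 wrap modulo 256 rather than being clamped.

```python
import numpy as np

x = np.array([-1.0, 255.5, 256.0, 1000.0], dtype=np.float32)
two_step = x.astype(np.int32).astype(np.uint8)   # truncate toward zero, then wrap to the narrow dtype
print(two_step)   # [255 255   0 232]  ->  -1 % 256, 255, 256 % 256, 1000 % 256
```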
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""