mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00

Merge branch 'master' into retinanet_mlperf
This commit is contained in:
16  .github/workflows/benchmark.yml  vendored
@@ -197,11 +197,11 @@ jobs:
     # - name: Run LLaMA 7B on 6 GPUs
     #   run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
     - name: Run LLaMA-3 8B BEAM
-      run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
+      run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
     - name: Run LLaMA-3 8B on 4 GPUs
-      run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+      run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
     - name: Run LLaMA-3 8B on 6 GPUs
-      run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+      run: NV=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
     - name: Run LLaMA-2 70B
       run: NV=1 RUN_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
     - name: Run Mixtral 8x7B
@@ -380,11 +380,11 @@ jobs:
     # - name: Run LLaMA 7B on 6 GPUs
     #   run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
     - name: Run LLaMA-3 8B BEAM
-      run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
+      run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
     - name: Run LLaMA-3 8B on 4 GPUs
-      run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+      run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
     - name: Run LLaMA-3 8B on 6 GPUs
-      run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+      run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
     - name: Run LLaMA-2 70B
       run: AMD=1 RUN_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
     - name: Run Mixtral 8x7B
@@ -508,10 +508,6 @@ jobs:
         rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
     - name: reset process replay
       run: test/external/process_replay/reset.py
-    - name: openpilot compile 0.9.4
-      run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py | tee openpilot_compile_0_9_4.txt
-    - name: openpilot compile 0.9.7
-      run: PYTHONPATH=. NOLOCALS=1 FLOAT16=1 IMAGE=2 QCOM=1 taskset -c 4-7 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_compile_0_9_7.txt
     - name: validate openpilot 0.9.7
       run: PYTHONPATH=. FLOAT16=0 IMAGE=2 QCOM=1 taskset -c 4-7 python3 test/external/external_benchmark_openpilot.py https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx | tee openpilot_image_0_9_7.txt
     - name: benchmark openpilot 0.9.4
23  .github/workflows/test.yml  vendored
@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  DOWNLOAD_CACHE_VERSION: '7'
+  DOWNLOAD_CACHE_VERSION: '8'
   RUN_PROCESS_REPLAY: 1
   GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
   PYTHONPATH: .
@@ -293,22 +293,15 @@ jobs:
         PYTHONPATH="." GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py --durations=20
         PYTHONPATH="." GPU=1 IMAGE=2 python3 test/models/test_end2end.py TestEnd2End.test_linear_mnist
     - if: ${{ matrix.task == 'optimage' }}
-      name: Test openpilot model compile and size
+      name: Test openpilot model kernel count and gate usage
       run: |
-        PYTHONPATH="." DEBUG=2 ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=1 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
-        python -c 'import os; assert os.path.getsize("/tmp/output.thneed") < 100_000_000'
-    - if: ${{ matrix.task == 'optimage' }}
-      name: Test openpilot model correctness (float32)
-      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py
-    - if: ${{ matrix.task == 'optimage' }}
-      name: Test openpilot compile3
-      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py
+        PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_GATED_READ_IMAGE=13 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'optimage' }}
       name: Test openpilot alt model correctness (float32)
-      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'optimage' }}
       name: Test openpilot fastvits model correctness (float32)
-      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile2.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - if: ${{ matrix.task == 'onnx' }}
       name: Test ONNX (GPU)
       run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
@@ -387,7 +380,7 @@ jobs:
         WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \
           test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \
           test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \
-          --durations=20
+          test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20
     - name: Run process replay tests
       run: |
         export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
@@ -439,7 +432,7 @@ jobs:
     - name: Test Beam Search
       run: PYTHONPATH="." METAL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
     - name: Fuzz Test linearizer
-      run: PYTHONPATH="." METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=24 FUZZ_MAX_SIZE=1000000 python test/external/fuzz_linearizer.py
+      run: PYTHONPATH="." METAL=1 DEPTH=4 FUZZ_N=50 FUZZ_MAX_SIZE=1000000 python test/external/fuzz_linearizer.py
     # - name: Fuzz Test models schedule
     #   run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
     - name: Run TRANSCENDENTAL math
@@ -528,7 +521,7 @@ jobs:
       if: matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv'
       run: |
         cd ${{ github.workspace }}/gpuocelot/ocelot/build
-        sudo ninja install -d explain
+        sudo cp libgpuocelot.so /usr/lib/libgpuocelot.so
     - name: Install packages (amd)
       if: matrix.backend == 'amd'
       run: |
examples/llama3.py
@@ -220,7 +220,7 @@ if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument("--download_model", action="store_true", help="Download a model")
   parser.add_argument("--model", type=Path, help="Model path")
-  parser.add_argument("--size", choices=["1B", "8B", "70B"], default="8B", help="Model size")
+  parser.add_argument("--size", choices=["1B", "8B", "70B"], default="1B", help="Model size")
   parser.add_argument("--shard", type=int, default=1, help="Shard the model across multiple devices")
   parser.add_argument("--quantize", choices=["int8", "nf4", "float16"], help="Quantization method")
   parser.add_argument("--no_api", action="store_true", help="Disable the api and run a cli test interface")
@@ -234,8 +234,8 @@ if __name__ == "__main__":
   parser.add_argument("--profile", action="store_true", help="Output profile data")
   args = parser.parse_args()
 
-  assert (args.model and not args.download_model) or (not args.model and args.download_model), "either download or provide model"
-  if args.download_model:
+  # download_model is the default without a model passed in
+  if args.download_model or not args.model:
     if args.size == "1B":
       fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model", "tokenizer.model", subdir="llama3-1b-instruct")
       args.model = fetch("https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf", "Llama-3.2-1B-Instruct-Q6_K.gguf", subdir="llama3-1b-instruct")
examples/openpilot/compile2.py
@@ -1,211 +0,0 @@
#!/usr/bin/env python3
import os, sys, io, pathlib, json, struct
import numpy as np
sys.path.insert(0, str(pathlib.Path(__file__).parents[1]))

if "FLOAT16" not in os.environ: os.environ["FLOAT16"] = "1"
if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"

OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx"

import onnx
from typing import Tuple, List, Optional, Dict, cast
from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.device import Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule
from tinygrad.ops import Ops
from tinygrad.tensor import _to_np_dtype
Device.DEFAULT = "GPU"

def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
  Tensor.no_grad = True
  Tensor.training = False

  # load the model
  onnx_model = onnx.load(io.BytesIO(onnx_data))
  run_onnx = get_run_onnx(onnx_model)
  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}

  # run the model
  inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
  ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
  schedule = create_schedule([ret.lazydata])

  # filter schedule that don't depend on the inputs
  input_lb = [x.lazydata.base.buffer for x in inputs.values()]
  depends = set(input_lb)
  for si in schedule:
    if any(b in depends for b in si.inputs):
      for out in si.outputs: depends.add(out)

  # run all kernels that don't depend on the inputs
  # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
  schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
  print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")

  # confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
  assert all(si.ast.op is Ops.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
  return schedule, schedule_independent, inputs

def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
  import onnx
  #import pyopencl as cl
  #from extra.thneed import Thneed
  import numpy as np
  onnx_model = onnx.load(io.BytesIO(onnx_data))

  input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
  Tensor.manual_seed(1337)
  new_inputs = {k:Tensor.randn(*shp, requires_grad=False)*8 for k,shp in input_shapes.items()}
  new_np_inputs = {k:v.realize().numpy() for k,v in new_inputs.items()}

  if getenv("ORT"):
    # test with onnxruntime
    import onnxruntime as ort
    onnx_session = ort.InferenceSession(onnx_data)
    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_np_inputs.items()})
    new_torch_out = onnx_output[0]
    print("got ort outputs")
  else:
    # test with torch
    from test.models.test_onnx import run_onnx_torch
    new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
    print("got torch outputs")

  # if you don't have a schedule
  if eis is None:
    run_onnx = get_run_onnx(onnx_model)
    new_tinygrad_out = next(iter(run_onnx(new_inputs).values())).cast(dtypes.float32).numpy()
    np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
    print("classic self-test passed!")
    return

  # set inputs
  for k,v in inputs.items(): v.lazydata.base.realized.copyin(new_np_inputs[k].data)

  # run code (all buffers have been allocated)
  GlobalCounters.reset()
  output = eis[-1].bufs[0]
  for ei in eis: ei.run()

  new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=_to_np_dtype(output.dtype))
  np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
  print("semi-thneed self-test passed!")

if __name__ == "__main__":
  onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes()

  # quick test for ONNX issues
  #thneed_test_onnx(onnx_data, None)
  #exit(0)

  schedule, schedule_independent, inputs = get_schedule(onnx_data)
  schedule, schedule_input = partition(schedule, lambda x: x.ast.op is Ops.SINK)
  print(f"{len(schedule_input)} inputs")

  run_schedule(schedule_independent)
  run_schedule(schedule_input)
  with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
    schedule = memory_planner(schedule)
    for si in schedule:
      for b in si.outputs:
        assert not b.is_allocated(), "output should not be allocated"
    image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
    print(f"**** compiling real kernels {image_count}/{len(schedule)} images ****")
    eis = list(tqdm(lower_schedule(schedule), total=len(schedule)))

  print("kernel count:", len(eis))
  assert len(eis) <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"

  # new simple thneed
  def to_ref(b:Buffer): return struct.pack("Q", id(b)).decode("latin_1")

  seen_buffers = set()
  input_buffers = [x.lazydata.buffer for x in inputs.values()]
  jdat = {"binaries": [], "programs": {}, "kernels": [], "objects": []}
  jdat["inputs"] = {k:to_ref(v.lazydata.buffer) for k,v in inputs.items()}
  jdat["outputs"] = [to_ref(eis[-1].bufs[0])]
  weights = []
  for i,ei in enumerate(eis):
    #print("***", i)
    for b in ei.bufs:
      needs_load = b.is_allocated() and b not in input_buffers
      #print(b, needs_load)
      if b in seen_buffers: continue
      seen_buffers.add(b)
      if isinstance(b.dtype, ImageDType):
        base_dtype = dtypes.float16 if b.dtype.fmt == 'e' else dtypes.float32
        row_pitch = (b.dtype.shape[0]*4*base_dtype.itemsize + 63)//64 * 64
        size = row_pitch * b.dtype.shape[1]
        jdat['objects'].append({
          "id": to_ref(b), "needs_load": needs_load, "size": size, "arg_type": "image2d_t",
          "width": b.dtype.shape[0], "height": b.dtype.shape[1], "row_pitch": row_pitch, "float32": b.dtype.base == dtypes.float32,
        })
        if needs_load:
          t = Tensor.empty(b.dtype.shape, dtype=b.dtype)
          t.lazydata.buffer = b
          data = t.cast(dtypes.float32).pad(((0, row_pitch//(4*base_dtype.itemsize)-b.dtype.shape[0]), (0,0), (0,0))).contiguous().numpy()
          # NOTE: this cast must be done in numpy for platforms that don't support half
          if base_dtype == dtypes.float16: data = data.astype(np.float16)
          weights.append(data.tobytes())
          assert len(weights[-1]) == size, "wrong size buffer"
      else:
        jdat['objects'].append({
          "id": to_ref(b), "arg_type": b.dtype.name + "*", "needs_load": needs_load, "size": b.nbytes,
        })
        if needs_load:
          weights.append(b.as_buffer())
          assert len(weights[-1]) == b.nbytes, "wrong size buffer"

  saved_binaries = set()
  binaries = []
  gated_read_image_count = 0
  GlobalCounters.reset()
  with Context(DEBUG=max(DEBUG.value, 2)):
    for ei in eis:
      prg = cast(CompiledRunner, ei.prg)
      assert len(prg.p.vars) == 0
      if prg.p.function_name not in saved_binaries:
        jdat['binaries'].append({"name":prg.p.function_name, "length":len(prg.lib)})
        binaries.append(prg.lib)
        saved_binaries.add(prg.p.function_name)
      gated_read_image_count += prg.p.src.count("?read_image")
      ei.run()
      jdat['kernels'].append({
        "name": prg.p.function_name,
        "work_dim": len(prg.p.global_size),
        "global_work_size": prg.p.global_size,
        "local_work_size": prg.p.local_size,
        "num_args": len(ei.bufs),
        "args": [to_ref(b) for b in ei.bufs],
        "arg_size": [8]*len(ei.bufs),
      })

  if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
    assert gated_read_image_count <= allowed_gated_read_image, \
      f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"

  output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
  print(f"saving thneed to {output_fn} with {len(weights)} buffers and {len(binaries)} binaries")
  with open(output_fn, "wb") as f:
    j = json.dumps(jdat, ensure_ascii=False).encode('latin_1')
    f.write(struct.pack("I", len(j)))
    f.write(j)
    for w in weights: f.write(w)
    for b in binaries: f.write(b)
    print("saved", f.tell(), "bytes")

  FLOAT16 = getenv("FLOAT16", 0)
  if FLOAT16 == 0:
    try:
      test_vs_onnx(onnx_data, eis, inputs)
    except ModuleNotFoundError as e:
      print(f"TEST NOT HAPPENING {e}")
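The deleted script above fixed the on-disk thneed layout: a little-endian uint32 length prefix, the JSON metadata, then the raw weight buffers in object order, then the compiled kernel binaries. A minimal reader sketch for that container, assuming a file written by the script above (the path is the script's default output; the helper name is illustrative):

import json, struct

def read_thneed(fn="/tmp/output.thneed"):
  # layout written above: uint32 json length, json, weight blobs, kernel binaries
  with open(fn, "rb") as f:
    (json_len,) = struct.unpack("I", f.read(4))
    jdat = json.loads(f.read(json_len).decode("latin_1"))
    # weight blobs follow, in the order their objects were appended with needs_load set
    weights = [f.read(obj["size"]) for obj in jdat["objects"] if obj["needs_load"]]
    # compiled binaries come last, with lengths recorded in the json
    binaries = [f.read(b["length"]) for b in jdat["binaries"]]
  return jdat, weights, binaries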
examples/openpilot/compile3.py
@@ -5,9 +5,10 @@ if "IMAGE" not in os.environ: os.environ["IMAGE"] = "2"
 if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
 if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"
 
-from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
+from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
 from tinygrad.helpers import DEBUG, getenv
 from tinygrad.tensor import _from_np_dtype
+from tinygrad.engine.realize import CompiledRunner
 
 import onnx
 from onnx.helper import tensor_dtype_to_np_dtype
@@ -16,12 +17,11 @@ from extra.onnx import get_run_onnx  # TODO: port to main tinygrad
 OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
 OUTPUT = "/tmp/openpilot.pkl"
 
-def compile():
+def compile(onnx_file):
+  onnx_model = onnx.load(onnx_file)
   Tensor.no_grad = True
   Tensor.training = False
 
-  onnx_bytes = fetch(OPENPILOT_MODEL)
-  onnx_model = onnx.load(onnx_bytes)
   run_onnx = get_run_onnx(onnx_model)
   print("loaded model")
 
@@ -30,51 +30,103 @@ def compile():
   if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
   Tensor.manual_seed(100)
   new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
   new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
   print("created tensors")
 
-  run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
+  run_onnx_jit = TinyJit(lambda **kwargs:
+    next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True)
   for i in range(3):
     GlobalCounters.reset()
     print(f"run {i}")
+    inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
+              **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
     with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
-      ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
+      ret = run_onnx_jit(**inputs).numpy()
     # copy i == 1 so use of JITBEAM is okay
     if i == 1: test_val = np.copy(ret)
   print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
-  np.testing.assert_equal(test_val, ret)
+  np.testing.assert_equal(test_val, ret, "JIT run failed")
   print("jit run validated")
 
+  # checks from compile2
+  kernel_count = 0
+  gated_read_image_count = 0
+  for ei in run_onnx_jit.captured.jit_cache:
+    if isinstance(ei.prg, CompiledRunner):
+      kernel_count += 1
+      gated_read_image_count += ei.prg.p.src.count("?read_image")
+  print(f"kernel_count: {kernel_count} gated_read_image_count: {gated_read_image_count}")
+  assert kernel_count <= getenv("ALLOWED_KERNEL_COUNT", 0) or getenv("ALLOWED_KERNEL_COUNT", 0) == 0, "too many kernels!"
+  if (allowed_gated_read_image:=getenv("ALLOWED_GATED_READ_IMAGE", -1)) != -1:
+    assert gated_read_image_count <= allowed_gated_read_image, \
+      f"too many gated read_image! {gated_read_image_count=}, {allowed_gated_read_image=}"
 
   with open(OUTPUT, "wb") as f:
     pickle.dump(run_onnx_jit, f)
-  mdl_sz = os.path.getsize(onnx_bytes)
+  mdl_sz = os.path.getsize(onnx_file)
   pkl_sz = os.path.getsize(OUTPUT)
   print(f"mdl size is {mdl_sz/1e6:.2f}M")
   print(f"pkl size is {pkl_sz/1e6:.2f}M")
   print("**** compile done ****")
   return test_val
 
-def test(test_val=None):
-  with open(OUTPUT, "rb") as f:
-    run = pickle.load(f)
-  Tensor.manual_seed(100)
-  new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
-                sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
+def test_vs_compile(run, new_inputs, test_val=None):
   new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
 
+  # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory)
+  inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
+            **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
 
   # run 20 times
   for _ in range(20):
     st = time.perf_counter()
-    # Need to cast non-image inputs from numpy, this is only realistic way to run it
-    inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
-              **{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}}
     out = run(**inputs)
     mt = time.perf_counter()
-    val = out['outputs'].numpy()
+    val = out.numpy()
    et = time.perf_counter()
     print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
   print(out, val.shape, val.dtype)
   if test_val is not None: np.testing.assert_equal(test_val, val)
   print("**** test done ****")
 
-if __name__ == "__main__":
-  test_val = compile() if not getenv("RUN") else None
-  test(test_val)
+  # test that changing the numpy changes the model outputs
+  for v in new_inputs_numpy.values(): v *= 2
+  out = run(**inputs)
+  changed_val = out.numpy()
+  np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val)
+  return val
 
+def test_vs_onnx(new_inputs, test_val, onnx_file):
+  new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
+  onnx_model = onnx.load(onnx_file)
 
+  if getenv("ORT"):
+    # test with onnxruntime
+    import onnxruntime as ort
+    onnx_session = ort.InferenceSession(onnx_file)
+    onnx_output = onnx_session.run([onnx_model.graph.output[0].name], {k:v.astype(np.float16) for k,v in new_inputs_numpy.items()})
+    new_torch_out = onnx_output[0]
+    print("got ort outputs")
+  else:
+    # test with torch
+    from test.models.test_onnx import run_onnx_torch
+    # NOTE: we have to correct the order here
+    new_torch_out = run_onnx_torch(onnx_model, {k.name:new_inputs_numpy[k.name] for k in onnx_model.graph.input}).numpy()
+    print("got torch outputs")
 
+  np.testing.assert_allclose(new_torch_out.reshape(test_val.shape), test_val, atol=1e-4, rtol=1e-2)
+  print("test vs onnx passed")
 
+if __name__ == "__main__":
+  onnx_file = fetch(OPENPILOT_MODEL)
+  test_val = compile(onnx_file) if not getenv("RUN") else None
 
+  with open(OUTPUT, "rb") as f: pickle_loaded = pickle.load(f)
 
+  # same randomness as compile
+  Tensor.manual_seed(100)
+  new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
+                sorted(zip(pickle_loaded.captured.expected_names, pickle_loaded.captured.expected_st_vars_dtype_device))}
 
+  test_val = test_vs_compile(pickle_loaded, new_inputs, test_val)
+  if not getenv("FLOAT16"): test_vs_onnx(new_inputs, test_val, onnx_file)
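Since compile3 pickles the pruned TinyJit itself, a consumer only needs to unpickle and call it; the captured JIT records the expected input names, shapes, dtypes and devices. A minimal sketch under the same assumptions as test_vs_compile above (the pickle path is this script's OUTPUT):

import pickle
from tinygrad import Tensor

with open("/tmp/openpilot.pkl", "rb") as f: run = pickle.load(f)
# the captured JIT remembers what it expects: names, shapetrackers, dtypes, devices
Tensor.manual_seed(100)
new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in
              sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))}
out = run(**new_inputs)  # a single float32 Tensor, per the lambda captured in compile()
print(out.numpy().shape)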
@@ -1,2 +0,0 @@
-#!/bin/bash
-NOLOCALS=1 FLOAT16=1 DEBUGCL=1 IMAGE=2 GPU=1 python3 examples/openpilot/compile2.py
35  examples/self_tokenize.py  Normal file
@@ -0,0 +1,35 @@
import os, pathlib
from examples.llama3 import Tokenizer
from tabulate import tabulate
from tinygrad import fetch
from tinygrad.helpers import flatten

# llama 3 tokenizer
tokenizer = Tokenizer(fetch("https://huggingface.co/bofenghuang/Meta-Llama-3-8B/resolve/main/original/tokenizer.model").as_posix())

def read_code(base_path):
  ret = []
  for path, _, files in os.walk(os.path.join(base_path, "tinygrad")):
    for name in files:
      if not name.endswith(".py"): continue
      if 'tinygrad/runtime/autogen' in path.replace('\\', '/'): continue
      fullpath = os.path.join(path, name)
      code = pathlib.Path(fullpath).read_text()
      ret += [(fullpath.split("tinygrad/", 1)[1], code)]
  return ret

if __name__ == "__main__":
  ret = read_code(".")

  table = []
  for name,code in ret:
    table.append([name, len(tokenizer.encode(name+"\x00"+code))])
  print(tabulate([["name", "llm tokens"]]+sorted(table, key=lambda x: -x[1]), headers="firstrow"))

  code_str = '\x00'.join(flatten(ret))
  print(f"code has {len(code_str)} chars")
  newline_count = code_str.count('\n')
  print(f"code has {newline_count} newlines")

  encoded = tokenizer.encode(code_str)
  print(f"code has {len(encoded)} tokens")
@@ -17,7 +17,7 @@ canvas { display: none; }
 * { text-align: center; font-family: monospace; }
 </style>
 <title>tinygrad has WebGPU</title>
-<script src="./net.js"></script>
+<script src="../../net.js"></script>
 <link rel="icon" type="image/x-icon" href="https://raw.githubusercontent.com/tinygrad/tinygrad/master/docs/logo.png">
 </head>
 <body>
@@ -61,7 +61,7 @@ canvas { display: none; }
 
 const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
 
-const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());
+const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("../../net.safetensors")).arrayBuffer());
 
 const reorderChannelsAndRemoveAlpha = (data) => {
   const out = [];
examples/webgpu/stable_diffusion/compile.py
@@ -1,9 +1,10 @@
 import os
-from extra.export_model import compile_net, jit_model
+from extra.export_model import compile_net, jit_model, dtype_to_js_type
+from extra.f16_decompress import u32_to_f16
 from examples.stable_diffusion import StableDiffusion
 from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
 from tinygrad.tensor import Tensor
-from tinygrad import Device
+from tinygrad import Device, dtypes
 from tinygrad.helpers import fetch
 from typing import NamedTuple, Any, List
 import requests
@@ -29,7 +30,7 @@ def convert_f32_to_f16(input_file, output_file):
     rest_float32_values.tofile(f)
 
 def split_safetensor(fn):
-  _, json_len, metadata = safe_load_metadata(fn)
+  _, data_start, metadata = safe_load_metadata(fn)
   text_model_offset = 3772703308
   chunk_size = 536870912
@@ -51,12 +52,12 @@ def split_safetensor(fn):
     part_offset = offset - last_offset
 
     if (part_offset >= chunk_size):
-      part_end_offsets.append(8+json_len+offset)
+      part_end_offsets.append(data_start+offset)
       last_offset = offset
 
   text_model_start = int(text_model_offset/2)
   net_bytes = bytes(open(fn, 'rb').read())
-  part_end_offsets.append(text_model_start+8+json_len)
+  part_end_offsets.append(text_model_start+data_start)
   cur_pos = 0
 
   for i, end_pos in enumerate(part_end_offsets):
@@ -65,7 +66,7 @@ def split_safetensor(fn):
     cur_pos = end_pos
 
   with open(os.path.join(os.path.dirname(__file__), f'./net_textmodel.safetensors'), "wb+") as f:
-    f.write(net_bytes[text_model_start+8+json_len:])
+    f.write(net_bytes[text_model_start+data_start:])
 
   return part_end_offsets
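The json_len to data_start rename reflects the safetensors layout these offsets index into: an 8-byte little-endian header size, the JSON header itself, then the data section, with per-tensor offsets in the header relative to the start of the data section, so data_start = 8 + json_len. A hedged sketch of that arithmetic (file name and helper name illustrative):

import json, struct

def header_info(fn):
  with open(fn, "rb") as f:
    (json_len,) = struct.unpack("<Q", f.read(8))  # 8-byte little-endian header size
    header = json.loads(f.read(json_len))
  data_start = 8 + json_len  # what safe_load_metadata now returns directly
  # tensor "data_offsets" in the header are relative to data_start
  return data_start, {k:v for k,v in header.items() if k != "__metadata__"}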
@@ -95,7 +96,8 @@ if __name__ == "__main__":
   sub_steps = [
     Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
     Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
-    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
+    Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode),
+    Step(name = "f16tof32", input = [Tensor.randn(2097120, dtype=dtypes.uint32)], forward = u32_to_f16)
   ]
 
   prg = ""
@@ -116,19 +118,23 @@ if __name__ == "__main__":
   weights = {id(x.lazydata.base.realized): name for name, x in state.items()}
   kernel_code = '\n\n'.join([f"const {key} = `{fixup_code(code, key)}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _, _, _) in statements])
+  input_names = [name for _,name in special_names.items() if "input" in name]
+  output_names = [name for _,name in special_names.items() if "output" in name]
+  input_buf_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
+  output_buf_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
-  bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
+  exported_bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weights else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weights[_key]}'], '{weights[_key]}'))") + ";" for name,(size,dtype,_key) in bufs.items()])
   gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
-  input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
+  input_writer = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buf_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,_ in enumerate(input_names)])
   return f"""\n var {step.name} = function() {{
 
     {kernel_code}
 
     return {{
       "setup": async (device, safetensor) => {{
-        const metadata = getTensorMetadata(safetensor[0]);
+        const metadata = safetensor ? getTensorMetadata(safetensor[0]) : null;
 
-        {bufs}
+        {exported_bufs}
 
         {gpu_write_bufs}
         const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
@@ -147,8 +153,8 @@ if __name__ == "__main__":
       device.queue.submit([gpuCommands]);
 
       await gpuReadBuffer.mapAsync(GPUMapMode.READ);
-      const resultBuffer = new Float32Array(gpuReadBuffer.size/4);
-      resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
+      const resultBuffer = new {output_buf_types[0]}(gpuReadBuffer.size/{bufs[output_names[0]][1].itemsize});
+      resultBuffer.set(new {output_buf_types[0]}(gpuReadBuffer.getMappedRange()));
       gpuReadBuffer.unmap();
       return resultBuffer;
     }}
@@ -165,10 +165,6 @@
     import ClipTokenizer from './clip_tokenizer.js';
     window.clipTokenizer = new ClipTokenizer();
   </script>
-  <script type="module">
-    import { f16tof32GPU } from 'https://unpkg.com/f16-to-f32-gpu@0.1.0/src/index.js';
-    window.f16tof32GPU = f16tof32GPU;
-  </script>
   <script src="./net.js"></script>
 </head>
 <body>
@@ -214,6 +210,8 @@
   <canvas id="canvas" width="512" height="512"></canvas>
 
   <script>
+    let f16decomp = null;
+
     function initDb() {
       return new Promise((resolve, reject) => {
         let db;
@@ -416,7 +414,7 @@
       const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));
 
       const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
-      const decodeChunkSize = 67107840;
+      const decodeChunkSize = 8388480;
       const numChunks = Math.ceil(allToDecomp/decodeChunkSize);
 
       console.log(allToDecomp + " bytes to decompress");
@@ -440,7 +438,8 @@
         let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
         let chunkEndF16 = chunkStartF16 + decodeChunkSize;
         let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
-        let result = await f16tof32GPU(chunk);
+        let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
+        let result = await f16decomp(uint32Chunk);
         let resultUint8 = new Uint8Array(result.buffer);
         let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
         let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
@@ -483,6 +482,7 @@
       }
 
       const device = await getDevice();
+      f16decomp = await f16tof32().setup(device, safetensorParts),
       safetensorParts = await getAndDecompressF16Safetensors(device, progress);
 
       modelDlTitle.innerHTML = "Compiling model"
@@ -12,7 +12,7 @@ if __name__ == "__main__":
   yolo_infer = YOLOv8(w=0.25, r=2.0, d=0.33, num_classes=80)
   state_dict = safe_load(get_weights_location(yolo_variant))
   load_state_dict(yolo_infer, state_dict)
-  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,256,256))
+  prg, inp_sizes, out_sizes, state = export_model(yolo_infer, Device.DEFAULT.lower(), Tensor.randn(1,3,416,416))
   dirname = Path(__file__).parent
   safe_save(state, (dirname / "net.safetensors").as_posix())
   with open(dirname / f"net.js", "w") as text_file:
@@ -95,6 +95,7 @@
 </head>
 <body>
   <h2>YOLOv8 tinygrad WebGPU</h2>
+  <h2 id="wgpu-error" style="display: none; color: red;">Error: WebGPU is not supported in this browser</h2>
   <div class="video-container">
     <video id="video" muted autoplay playsinline></video>
     <canvas id="canvas"></canvas>
@@ -107,7 +108,7 @@
   </div>
   <script>
     let net = null;
-    const modelInputSize = 256
+    const modelInputSize = 416;
     let lastCalledTime;
     let fps = 0, accumFps = 0, frameCounter = 0;
@@ -117,6 +118,7 @@
     const offscreenCanvas = document.createElement('canvas');
     const fpsMeter = document.getElementById('fps-meter');
     const loadingContainer = document.getElementById('div-loading');
+    const wgpuError = document.getElementById('wgpu-error');
     offscreenCanvas.width = modelInputSize;
     offscreenCanvas.height = modelInputSize;
     const offscreenContext = offscreenCanvas.getContext('2d');
@@ -147,7 +149,7 @@
       lastCalledTime = now;
       accumFps += 1/delta;
 
-      if (frameCounter++ >= 30) {
+      if (frameCounter++ >= 10) {
         fps = accumFps/frameCounter;
         frameCounter = 0;
         accumFps = 0;
@@ -206,7 +208,12 @@
 
     async function detectObjectsOnFrame(offscreenContext) {
       if (!net) {
-        net = await loadNet(await getDevice());
+        let device = await getDevice();
+        if (!device) {
+          wgpuError.style.display = "block";
+          loadingContainer.style.display = "none";
+        }
+        net = await loadNet(device);
         loadingContainer.style.display = "none";
       }
       let start = performance.now();
@@ -239,7 +246,7 @@
     }
 
     const getDevice = async () => {
-      if (!navigator.gpu) error("WebGPU not supported.");
+      if (!navigator.gpu) return false;
       const adapter = await navigator.gpu.requestAdapter();
       return await adapter.requestDevice();
     };
Binary file not shown.
extra/export_model.py
@@ -79,6 +79,9 @@ def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,in
   cprog += [f"void net({inputs}, {outputs}) {{"] + [f"{name}({', '.join(args)});" for (name, args, _global_size, _local_size) in statements] + ["}"]
   return '\n'.join(cprog)
 
+def dtype_to_js_type(dtype: DType) -> str:
+  return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array"
+
 def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names, input_names, output_names) -> Tuple[str,int,int]:
   kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
   kernel_names = ', '.join([name for (name, _args, _global_size, _local_size) in statements])
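dtype_to_js_type composes the JavaScript typed-array name from the dtype family and its bit width, which is what lets the generated setup code read and write non-float32 buffers below. A few illustrative values, assuming tinygrad's standard dtypes:

from tinygrad import dtypes
from extra.export_model import dtype_to_js_type

assert dtype_to_js_type(dtypes.float32) == "Float32Array"
assert dtype_to_js_type(dtypes.int8) == "Int8Array"
assert dtype_to_js_type(dtypes.uint32) == "Uint32Array"
assert dtype_to_js_type(dtypes.bool) == "Int8Array"  # bool is 1 byte and grouped with the signed ints here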
@@ -92,10 +95,12 @@ def export_model_webgpu(functions, statements, bufs, bufs_to_save, weight_names,
   kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], {global_size});" for i, (_name, args, global_size, _local_size) in enumerate(statements) ])
   _bufs = '\n '.join([f"const {name} = " + (f"createEmptyBuf(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, getTensorBuffer(safetensor, metadata['{weight_names[_key]}']))") + ";" for name,(size,dtype,_key) in bufs.items()])
   gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)])
-  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
+  input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names]
+  output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names]
+  input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)])
   gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)])
   outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)])
-  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new Float32Array(gpuReadBuffer{i}.size/4);\n resultBuffer{i}.set(new Float32Array(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
+  output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))])
   output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))]))
   return f"""
 {web_utils["getTensorBuffer"]}
16  extra/f16_decompress.py  Normal file
@@ -0,0 +1,16 @@
from tinygrad import Tensor

def bit_extract(x: Tensor, e: int, s: int) -> Tensor:
  mask = (1 << (e - s + 1)) - 1
  return (x >> s) & mask

def u16_to_f16(x: Tensor) -> Tensor:
  sign = bit_extract(x, 15, 15).float()
  exponent = bit_extract(x, 14, 10).float()
  fraction = bit_extract(x, 9, 0).float()
  return sign.where(-1, 1) * exponent.where((exponent - 15.0).exp2() * (1 + fraction / 1024.0), 6.103515625e-5 * (fraction / 1024.0))

def u32_to_f16(oo: Tensor) -> Tensor:
  f1 = u16_to_f16(oo>>16)
  f2 = u16_to_f16(oo&0xFFFF)
  return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()
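To see the bit fields at work: 0x3C00 has sign 0, exponent 15 and fraction 0, so the normal branch gives 2^(15-15) * (1 + 0/1024) = 1.0, while exponent 0 takes the subnormal branch scaled by 6.103515625e-5 = 2^-14. A small self-check in the spirit of the __main__ test kept in the deleted version that follows (the little-endian packing of two halves per uint32 is carried over from it):

import numpy as np
from tinygrad import Tensor, dtypes
from extra.f16_decompress import u16_to_f16, u32_to_f16

# 0x3C00 is 1.0 in IEEE half; 0x0001 is the smallest subnormal, 2**-24
print(u16_to_f16(Tensor([0x3C00, 0x0001], dtype=dtypes.uint32)).numpy())  # ~[1.0, 5.96e-08]

# two halves packed little-endian into one uint32: the low half comes out first
halves = np.array([1.0, -2.5], dtype=np.float16)
packed = Tensor([int.from_bytes(halves.tobytes(), "little")], dtype=dtypes.uint32)
print(u32_to_f16(packed).numpy())  # ~[1.0, -2.5]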
@@ -1,40 +0,0 @@
import numpy as np
from tinygrad import Device, dtypes, Tensor

# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul
def bit_extract(x, s, e) -> Tensor:
  # extract the top bits we don't want
  top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1))
  x = (x - top_bits) / (1<<e)
  return x.contiguous()

def u16_to_f16(x):
  sign = bit_extract(x, 15, 15).float()
  exponent = bit_extract(x, 14, 10).float()
  fraction = bit_extract(x, 9, 0).float()
  return sign.where(-1, 1) * exponent.where((exponent - 15).exp2() * (1 + fraction / 0x400), 6.103515625e-5 * (fraction / 0x400))

def u32_to_f16(oo):
  oo1 = (oo/0x10000).floor().contiguous()
  # TODO: this is wrong and unextractable until we do this math in u32
  oo2 = (oo-(oo1*0x10000)).floor().contiguous()
  f1 = u16_to_f16(oo1)
  f2 = u16_to_f16(oo2)
  return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten()

if __name__ == "__main__":
  # random float16
  Tensor.manual_seed(2)
  a = Tensor.randn(100, dtype=dtypes.float16)

  # this converts it to u32 on disk
  oo = a.to("disk:/tmp/f16").cast(dtypes.uint32)[:50].to(Device.DEFAULT).realize()

  # convert to 2xf16 using tinygrad math ops
  f16 = u32_to_f16(oo)

  ref = a.numpy()
  out = f16.numpy().astype(np.float16)
  print(ref-out)

  np.testing.assert_allclose(ref, out)
266  extra/onnx.py
@@ -1,14 +1,12 @@
 from __future__ import annotations
-from typing import List, Dict, Union
-import importlib
-from functools import lru_cache
+from typing import List, Dict, Union, Callable, Any, Sequence
+import importlib, functools
 import numpy as np
-from tinygrad import Tensor, dtypes, Device
-from tinygrad.tensor import _to_np_dtype
-from tinygrad.helpers import getenv, DEBUG, CI, OSX
-from tinygrad.dtype import ConstType, DType
+from tinygrad import Tensor, dtypes
+from tinygrad.helpers import getenv, DEBUG, all_same
+from tinygrad.dtype import DType, ConstType
 from tinygrad.device import is_dtype_supported
-from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
+from onnx import AttributeProto, ModelProto, TensorProto, ValueInfoProto
 try:
   from onnx.helper import tensor_dtype_to_np_dtype
 except ImportError:
@@ -17,178 +15,134 @@ except ImportError:
|
||||
def tensor_dtype_to_np_dtype(tensor_dtype:int) -> np.dtype: return TENSOR_TYPE_TO_NP_TYPE[tensor_dtype]
|
||||
|
||||
cache_misses = 0
|
||||
@lru_cache(None)
|
||||
def _cached_to_python_const(t:Tensor, tobytes): return t.data().tobytes() if tobytes else t.tolist()
|
||||
@functools.lru_cache(None)
|
||||
def _cached_to_python_const(t:Tensor):
|
||||
if t.dtype is dtypes.uint8: return t.data().tobytes()
|
||||
if 0 in t.shape: return []
|
||||
return t.tolist()
|
||||
|
||||
# Tensor -> python value cache for parameters
|
||||
def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
|
||||
def to_python_const(t) -> Union[List[ConstType], List[bytes], Union[ConstType, bytes]]:
|
||||
if not isinstance(t, Tensor): return t
|
||||
global cache_misses
|
||||
ret = _cached_to_python_const(t, tobytes)
|
||||
ret = _cached_to_python_const(t)
|
||||
if (info := _cached_to_python_const.cache_info()).misses > cache_misses and DEBUG >= 3:
|
||||
print(f"Cache miss for {t}, {tobytes=}")
|
||||
print(f"Cache miss for {t}")
|
||||
cache_misses = info.misses
|
||||
return ret
|
||||
|
||||
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
|
||||
# TODO: use dtypes.float16 for FLOAT16
|
||||
DTYPE_MAP: Dict[TensorProto.DataType, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8, TensorProto.UINT16:dtypes.uint16,
|
||||
TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64, TensorProto.BOOL:dtypes.bool,
|
||||
TensorProto.FLOAT16:dtypes.float, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32, TensorProto.UINT64:dtypes.uint64,
|
||||
TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float, TensorProto.FLOAT8E4M3FNUZ:dtypes.float,
|
||||
TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
|
||||
# TODO: use real float16
|
||||
# src: onnx/mapping.py
|
||||
DTYPE_MAP: Dict[TensorProto.DataType | int, DType] = {
|
||||
TensorProto.FLOAT:dtypes.float32, TensorProto.UINT8:dtypes.uint8, TensorProto.INT8:dtypes.int8,
|
||||
TensorProto.UINT16:dtypes.uint16, TensorProto.INT16:dtypes.int16, TensorProto.INT32:dtypes.int32, TensorProto.INT64:dtypes.int64,
|
||||
TensorProto.BOOL:dtypes.bool, TensorProto.FLOAT16:dtypes.float32, TensorProto.DOUBLE:dtypes.double, TensorProto.UINT32:dtypes.uint32,
|
||||
TensorProto.UINT64:dtypes.uint64, TensorProto.BFLOAT16:dtypes.bfloat16, TensorProto.FLOAT8E4M3FN:dtypes.float,
|
||||
TensorProto.FLOAT8E4M3FNUZ:dtypes.float, TensorProto.FLOAT8E5M2:dtypes.float, TensorProto.FLOAT8E5M2FNUZ:dtypes.float
|
||||
}
|
||||
def dtype_parse(onnx_dtype: TensorProto.DataType | int) -> DType:
|
||||
if onnx_dtype not in DTYPE_MAP: raise NotImplementedError(f"onnx dtype {TensorProto.DataType.Name(onnx_dtype)} is not supported")
|
||||
return DTYPE_MAP[onnx_dtype] if is_dtype_supported(DTYPE_MAP[onnx_dtype]) else dtypes.float
|
||||
|
||||
# src: onnx/onnx_ml_pb2.pyi
|
||||
ATTRIBUTE_MAP: Dict[AttributeProto.AttributeType, Callable[[AttributeProto], Any]] = {
|
||||
AttributeProto.FLOAT: lambda a: float(a.f), AttributeProto.INT: lambda a: int(a.i),
|
||||
AttributeProto.STRING: lambda a: a.s.decode("utf-8"), AttributeProto.TENSOR: lambda a: buffer_parse(a.t),
|
||||
AttributeProto.FLOATS: lambda a: tuple(float(x) for x in a.floats), AttributeProto.INTS: lambda a: tuple(int(x) for x in a.ints),
|
||||
AttributeProto.STRINGS: lambda a: tuple(x.decode("utf-8") for x in a.strings)
|
||||
}
|
||||
def attribute_parse(onnx_attribute: AttributeProto):
|
||||
if onnx_attribute.type not in ATTRIBUTE_MAP:
|
||||
raise NotImplementedError(f"attribute with type {AttributeProto.AttributeType.Name(onnx_attribute.type)} is not supported")
|
||||
return ATTRIBUTE_MAP[onnx_attribute.type](onnx_attribute)
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
|
||||
return Tensor(dat, dtype=dtype_parse(inp.data_type), requires_grad=False).reshape(tuple(inp.dims))
|
||||
if len(inp.raw_data) > 0:
|
||||
return Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).copy().reshape(tuple(inp.dims)),
|
||||
dtype=dtype_parse(inp.data_type), requires_grad=False)
|
||||
raise NotImplementedError(f"buffer with data type {TensorProto.DataType.Name(inp.data_type)} is not supported")
|
||||
|
||||
onnx_ops = importlib.import_module('extra.onnx_ops')
|
||||
|
||||
ONNXLIMIT = getenv("ONNXLIMIT", -1)
|
||||
|
||||
def get_run_onnx(onnx_model: ModelProto):
|
||||
def type_parse(type_proto: TypeProto):
|
||||
ret = []
|
||||
while True:
|
||||
attr = type_proto.WhichOneof('value')
|
||||
if attr == 'tensor_type':
|
||||
if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
|
||||
elif not ret:
|
||||
return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
|
||||
else:
|
||||
ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
|
||||
return tuple(ret)
|
||||
elif attr == 'sequence_type':
|
||||
type_proto = getattr(type_proto, attr).elem_type
|
||||
ret.append(1)
|
||||
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
|
||||
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
|
||||
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
|
||||
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
|
||||
else: raise AttributeError(f"unknown attr: {attr}, {type_proto}")
|
||||
|
||||
def buffer_parse(inp: TensorProto) -> Tensor:
|
||||
if inp.data_type not in DTYPE_MAP:
|
||||
raise NotImplementedError(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
|
||||
dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
|
||||
if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
|
||||
return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
|
||||
if len(inp.raw_data) > 0:
|
||||
data = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).astype(_to_np_dtype(dtype)).copy()
|
||||
return Tensor(data.reshape(tuple(inp.dims)), requires_grad=False)
|
||||
return Tensor(None, requires_grad=False)
|
||||
|
||||
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
|
||||
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
|
||||
if a.type == AttributeProto.FLOAT: return float(a.f)
|
||||
elif a.type == AttributeProto.INT: return int(a.i)
|
||||
elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
|
||||
elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
|
||||
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
|
||||
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
|
||||
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
|
||||
elif a.type == AttributeProto.GRAPH: raise NotImplementedError(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
|
||||
else: raise RuntimeError(f"can't parse {a.type} {a}")
|
||||
|
||||
tensors: Dict[str, Tensor] = {}

# get weights and biases
for inp in onnx_model.graph.initializer:
tensors[inp.name] = buffer_parse(inp)

# preparse the attributes
attribute_dict = {}
domain = ""
for num,n in enumerate(onnx_model.graph.node):
attribute_dict[num] = {x.name:attribute_parse(x) for x in n.attribute}
if n.domain: domain = n.domain
# model initialization data
model_parameters = {inp.name:buffer_parse(inp) for inp in onnx_model.graph.initializer}
model_attributes = {num:{x.name:attribute_parse(x) for x in n.attribute} for num,n in enumerate(onnx_model.graph.node)}

# model descriptions
# TODO: need a better way of controlling training vs non-training
is_onnx_preview_training = any(n.HasField("domain") and n.domain == "ai.onnx.preview.training" for n in onnx_model.graph.node)
onnx_model_version = onnx_model.opset_import[0].version

# mapping from onnx ops to tensor.py ops
tensor_methods = {
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
}

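# the table above is pure name dispatch: an op string maps to the lowercase Tensor
# method of the same name (a sketch, assuming tinygrad is importable):
from tinygrad import Tensor
x = Tensor([-1.0, 2.0])
assert getattr(Tensor, "Relu".lower())(x).tolist() == x.relu().tolist()  # both [0.0, 2.0]
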
# src: https://onnx.ai/onnx/repo-docs/IR.html#input-output-data-types
# parses and validates inputs based on their shape and dtype specified by model
def prepare_input(user_input:Any, model_input:ValueInfoProto):
type_proto = model_input.type
if type_proto.HasField("optional_type"):
if user_input is None: return Tensor(None)
type_proto = type_proto.optional_type.elem_type
if type_proto.HasField("sequence_type"):
if not isinstance(user_input, Sequence): raise RuntimeError(f"{model_input.name} received {user_input}, expected sequence type")
dtype = dtype_parse(type_proto.sequence_type.elem_type.tensor_type.elem_type)
sequence = [Tensor(i, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(i, Tensor) else i for i in user_input]
if not all_same(tuple(t.shape for t in sequence)): raise RuntimeError(f"shapes for {model_input.name} must be homogeneous")
# TODO: need true float16 for dtype checking
# if not all(t.dtype is dtype for t in sequence): raise RuntimeError(f"{model_input.name} received wrong dtype, expected {dtype}")
return sequence
if type_proto.HasField("tensor_type"):
dtype = dtype_parse(type_proto.tensor_type.elem_type)
tensor = Tensor(user_input, dtype=dtype, requires_grad=is_onnx_preview_training) if not isinstance(user_input, Tensor) else user_input
# TODO: need true float16 for dtype checking
# if dtype is not tensor.dtype: raise RuntimeError(f"{model_input.name} received dtype {inp.dtype}, expected {dtype}")
for d,onnx_dim in enumerate(type_proto.tensor_type.shape.dim):
# NOTE: dim is a variable dimension when `dim_param` is specified, e.g. dim {dim_param: "N"} is a variable dim
if onnx_dim.dim_param is None and onnx_dim.dim_value != user_input.shape[d]:
raise RuntimeError(f"{model_input.name} received value {user_input.shape[d]} on dim {d}, expected {onnx_dim.dim_value}")
return tensor
type_field_names = [field.name for field,_ in type_proto.ListFields()]
raise NotImplementedError(f"{model_input.name} with {type_field_names=} is not supported")

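# sketch of the dim check above (shapes ours): an input declared ["N", 3, 224, 224]
# pins only dims 1..3 since dim 0 carries a dim_param; (8, 3, 224, 224) passes,
# while (8, 1, 224, 224) hits the RuntimeError for dim 1
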
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
input_tensors: Dict[str,Tensor|List[Tensor]] = {}
intermediate_tensors: Dict[str,Tensor] = {}
output_tensor_names = [x.name for x in onnx_model.graph.output]

# get inputs
input_tensors: Dict[str, Tensor | List[Tensor]] = {}
for model_input in onnx_model.graph.input:
name = model_input.name
if name in tensors: continue
shape = type_parse(model_input.type)
if name in inputs:
if isinstance(inputs[name], Tensor):
input_tensors[name] = inputs[name]
elif isinstance(inputs[name], list):
input_tensors[name] = [Tensor(i, requires_grad=False) for i in inputs[name]]
# TODO: this is just to make training tests pass, need a principled way to handle training vs non-training
elif domain == "ai.onnx.preview.training":
input_tensors[name] = Tensor(inputs[name], requires_grad=True)
else:
input_tensors[name] = Tensor(inputs[name], requires_grad=False)
if shape: # if only input_tensor is not variable type
ts = input_tensors[name]
input_shape = ts.shape if isinstance(ts, Tensor) else (1, *[i.shape for i in ts])
assert input_shape == shape, f"wrong shape for input {name}, {input_shape} isn't {shape}"
else:
raise RuntimeError(f"no data for {name} with shape {shape}")
if model_input.name in inputs: input_tensors[model_input.name] = prepare_input(inputs[model_input.name], model_input)
elif model_input.name not in model_parameters: raise RuntimeError(f"Please provide input data for {model_input.name}")

def fetch_tensor(x: str):
if x in tensors: return tensors[x]
if x in model_parameters: return model_parameters[x]
if x in intermediate_tensors: return intermediate_tensors[x]
if x != "": return input_tensors[x]
return None

for num,n in enumerate(onnx_model.graph.node):
inp: List[Tensor] = []
if debug >= 3: print("inputs:")
for x in n.input:
t = fetch_tensor(x)
if debug >= 3: print(f"\t{x} - {t}")
inp.append(t)
opt: Dict = attribute_dict[num]
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
inp = [fetch_tensor(x) for x in n.input]
opt = model_attributes[num]

if debug >= 1: print(f"{num}: op \"{n.op_type}\" input shapes {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
if debug >= 3: print("\tinputs:\n" + "\n".join(f"\t\t{x} - {t}" for i,(x,t) in enumerate(zip(n.input, inp))))

if n.op_type in tensor_methods:
ret = getattr(Tensor, tensor_methods[n.op_type])(*inp, **opt)

# NOTE some ops live here because they require access to some local variables
# have to use n.output for cases when num_outputs is absent
if n.op_type in onnx_ops.tensor_methods:
ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
elif n.op_type == "Split":
axis = opt.get("axis", 0)
split = None if len(inp) == 1 else to_python_const(inp[1])
if split is None:
split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
for i in range(inp[0].shape[axis] % len(n.output)):
split[i] += 1
i, ret = 0, []
arg = [None] * inp[0].ndim
for s in split:
arg[axis] = (i,i+s)
ret.append(inp[0].shrink(arg=tuple(arg)))
i = i+s
ret = tuple(ret)

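# worked example of the uneven split above (numbers ours): shape[axis] == 5 over
# len(n.output) == 3 gives split == [2, 2, 1], i.e. shrink ranges (0,2), (2,4), (4,5)
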
# need to check onnx_model_version
elif n.op_type == "Slice":
if onnx_model_version < 10:
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
starts, ends = inp[1:3]
axes = list(range(inp[0].ndim)) if len(inp) <= 3 else to_python_const(inp[3].cast(dtypes.int32))
steps = inp[4].cast(dtypes.int32).tolist() if len(inp) > 4 else [1]*inp[0].ndim
starts, ends = to_python_const(starts), to_python_const(ends)
arg = [(0,x,1) for x in inp[0].shape]
for i, axis in enumerate(axes):
axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
if starts[i] < 0: starts[i] += inp[0].shape[axis]
if ends[i] < 0: ends[i] += inp[0].shape[axis]
starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
arg[axis] = (starts[i], ends[i], steps[i])
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
else: ret = inp[0][tuple([slice(s,e,st) for s,e,st in arg])]

# need to call backward on intermediate_tensors
axis, n_outputs = opt.get('axis', 0), opt.get('num_outputs') or len(n.output)
sz = inp[0].shape[axis]
sizes = to_python_const(inp[1]) if len(inp) == 2 else [sz // n_outputs + (1 if i < sz % n_outputs else 0) for i in range(n_outputs)]
ret = inp[0].split(sizes, axis)
elif n.op_type == "Gradient":
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
y = opt["y"]
@@ -209,16 +163,12 @@ def get_run_onnx(onnx_model: ModelProto):
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise NotImplementedError(f"op_type {n.op_type} not supported")

# finalization after running the op
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
if debug >= 2: print("outputs:")
for i in range(len(n.output)):
if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
intermediate_tensors[n.output[i]] = ret[i]
if num == ONNXLIMIT:
output_tensor_names = n.output
break
if len(n.output) > len(ret): raise RuntimeError(f"expected output size must be less than {len(ret)}, it's {n.output}")
for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
if debug >= 2: print("\toutputs:\n" + "\n".join(f"\t\t{n.output[i]} - {ret[i]}" for i in range(len(n.output))))

return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
if num == ONNXLIMIT: return {name:intermediate_tensors[name] for name in n.output}
return {x.name:intermediate_tensors[x.name] for x in onnx_model.graph.output}
return run_onnx

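# minimal usage sketch of the closure returned above (model path and input name are
# ours, not from the diff):
import onnx
run = get_run_onnx(onnx.load("model.onnx"))
out = run({"input": Tensor.ones(1, 3, 224, 224)}, debug=1)  # dict of output name -> Tensor
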
@@ -3,13 +3,9 @@ from typing import Union, Tuple, Optional, List, Any, cast
from tinygrad.tensor import Tensor, _broadcast_shape
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.helpers import prod, flatten
from extra.onnx import DTYPE_MAP, to_python_const
from extra.onnx import dtype_parse, to_python_const
import numpy as np

tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan","Relu",
"Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign",
"Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf"}

# **************** Free Ops ****************

def Identity(x: Tensor): return x
@@ -21,12 +17,16 @@ def LessOrEqual(x:Tensor,y:Tensor): return x <= y
def Greater(x:Tensor,y:Tensor): return x > y
def GreaterOrEqual(x:Tensor,y:Tensor): return x >= y
def Equal(x:Tensor,y:Tensor): return x == y
def BitwiseNot(x:Tensor): return ~x
def BitwiseOr(x:Tensor, y:Tensor): return x | y
def BitwiseAnd(x:Tensor, y:Tensor): return x & y
def BitwiseXor(x:Tensor, y:Tensor): return x ^ y
def Max(*data_0): return functools.reduce(Tensor.maximum, data_0)
def Min(*data_0): return functools.reduce(Tensor.minimum, data_0)
def Sum(*data_0): return functools.reduce(Tensor.add, data_0)
def Mean(*data_0): return Sum(*data_0) / len(data_0)
# NOTE: does not support saturate
def Cast(x: Tensor, to: int, saturate=1): return x.cast(DTYPE_MAP[to])
def Cast(x: Tensor, to: int, saturate=1): return x.cast(dtype_parse(to))
def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_type.dtype)

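# e.g. the variadic ops above fold pairwise via functools.reduce:
# Sum(a, b, c) == ((a + b) + c), and Mean(a, b, c) == Sum(a, b, c) / 3
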
# **************** Simple Ops ****************
@@ -91,6 +91,14 @@ def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
k = to_python_const(k) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
return x.triu(k) if upper else x.tril(k)

def Slice(data: Tensor, starts:Tensor, ends:Tensor, axes:Optional[Tensor]=None, steps:Optional[Tensor]=None):
if axes is None: axes = list(range(data.ndim))
if steps is None: steps = [1] * data.ndim
starts, ends, axes, steps = (to_python_const(x) for x in (starts, ends, axes, steps))
slices = [slice(0,x,1) for x in data.shape]
for i, axis in enumerate(axes): slices[axis] = slice(starts[i], ends[i], steps[i])
return data[tuple(slices)]

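# worked example of the new Slice (values ours): with data == [0, 1, 2, 3, 4],
# starts=[1], ends=[4], steps=[2], the single slice is slice(1, 4, 2), so Slice
# picks indices 1 and 3 and returns [1., 3.]
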
def Squeeze(data: Tensor, axes):
if isinstance(axes, Tensor): axes = to_python_const(axes)
axes = [data._resolve_dim(x) for x in axes]
@@ -389,11 +397,10 @@ def CenterCropPad(t: Tensor, shape: Tensor, axes=None):
return t.shrink(tuple(shrink_arg)).pad(tuple(pad_arg))

def OneHot(indices: Tensor, depth: Tensor, values: Tensor, axis=-1):
depth = to_python_const(depth)
depth = int(to_python_const(depth))
# Scalar or Rank 1 tensor containing exactly one element
depth, indices = depth[0] if isinstance(depth, list) else depth, (indices < 0).where(indices+depth, indices),
if axis < 0: axis += indices.ndim + 1
return (indices[:,None] == Tensor.arange(int(depth)).reshape((int(depth),) + (1,)*(indices.ndim-axis))).where(values[1], values[0])
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

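# sketch of the new one-hot path (values ours): indices=[0, 2], depth=3,
# values=[0., 1.] should produce [[1., 0., 0.], [0., 0., 1.]] along the last axis
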
def Compress(inp: Tensor, condition: Tensor, axis=None):
if axis is None:
@@ -405,7 +412,7 @@ def Compress(inp: Tensor, condition: Tensor, axis=None):
return inp[tuple(con if i == axis else slice(None) for i in range(inp.ndim))]

def EyeLike(x: Tensor, dtype=None, k=0):
ret = Tensor.eye(cast(int, min(x.shape)), dtype=DTYPE_MAP[dtype] if dtype else x.dtype)
ret = Tensor.eye(cast(int, min(x.shape)), dtype=dtype_parse(dtype) if dtype else x.dtype)
return ret if x.size(0) == x.size(1) else ret.pad(tuple(None if d == ret.size(0) else (k, d-ret.size(0)-k) for d in x.shape))

def Upsample(X, scales, mode): return Resize(X=X, scales=scales, mode=mode)
@@ -424,7 +431,7 @@ def DequantizeLinear(x: Tensor, x_scale: Tensor, x_zero_point: Union[Tensor, int
def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"):
try: import PIL.Image
except ImportError as e: raise ImportError("Pillow must be installed to use the reference implementation of the ImageDecoder operator") from e
img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream, True)))
img = PIL.Image.open(io.BytesIO(to_python_const(encoded_stream)))
if pixel_format == "BGR": return Tensor(np.array(img))[:, :, ::-1]
if pixel_format == "RGB": return Tensor(np.array(img))
if pixel_format == "Grayscale": return Tensor(np.array(img.convert("L"))).unsqueeze(-1) # (H, W) to (H, W, 1)
@@ -461,8 +468,7 @@ def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Optional[Tensor]=None
vocab_size, max_position_embeddings, type_vocab_size = word_embedding.shape[0], position_embedding.shape[0], (segment_embedding.shape[0] if compute_seg_emb else None)

def embedding(x:Tensor, vocab_size, weight:Tensor) -> Tensor:
vocab_counter = Tensor.arange(vocab_size, dtype=x.dtype, requires_grad=False).expand(*x.shape, vocab_size)
return (vocab_counter == x.unsqueeze(-1).expand(*x.shape, vocab_size)) @ weight
return x.unsqueeze(-1).expand(*x.shape, vocab_size)._one_hot_along_dim(vocab_size) @ weight

# bert embedding layer
if epsilon is None: epsilon = 1e-12

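# the embedding rewrite above swaps the explicit arange==compare for
# _one_hot_along_dim, but the idea is unchanged: for token id 1 and a 3-row weight
# W, the one-hot row [0, 1, 0] @ W selects W's row 1, so lookup is a matmul
# (example values ours)
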
@@ -5,9 +5,9 @@ from test.external.process_replay.process_replay import _pmap

LOGOPS = os.getenv("LOGOPS", "/tmp/sops")

def extract_ast(*args) -> bool:
def extract_ast(*args) -> None:
open(LOGOPS, "a").write(str(args[0]).replace("\n", "").replace(" ", "")+"\n")
return args[-1]
return None

if __name__ == "__main__":
_pmap("kernel", extract_ast)

@@ -1,5 +1,6 @@
#!/bin/bash
export PAGE_SIZE=1
export PYTHONPATH=.
export LOGOPS=/tmp/ops
export RUN_PROCESS_REPLAY=1
rm $LOGOPS
@@ -24,5 +25,5 @@ JIT=2 BIG=1 MPS=1 python -m pytest test/test_speed_v_torch.py

# extract, sort and uniq
extra/optimization/extract_dataset.py
sort -u /tmp/ops > /tmp/sops
ls -lh /tmp/ops /tmp/sops

2
setup.py
@@ -29,7 +29,7 @@ setup(name='tinygrad',
'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
'linting': [
"pylint",
"mypy==1.11.2",
"mypy==1.13.0",
"typing-extensions",
"pre-commit",
"ruff",

8
test/external/external_test_onnx_backend.py
vendored
@@ -70,11 +70,16 @@ backend_test.exclude('BFLOAT16') # not supported in numpy
# TODO: fix these with true onnx float16
backend_test.exclude('to_FLOAT16')
backend_test.exclude('cast_no_saturate')
backend_test.exclude('test_dequantizelinear_e4m3fn_float16_cpu')
backend_test.exclude('test_max_float16_cpu')
backend_test.exclude('test_min_float16_cpu')

backend_test.exclude('test_pow_types_int*')
backend_test.exclude('test_convinteger_*')
backend_test.exclude('test_matmulinteger_*')

backend_test.exclude('test_dequantizelinear_int4_cpu')
backend_test.exclude('test_dequantizelinear_uint4_cpu')

# we don't support indexes
backend_test.exclude('test_nonzero_*')

@@ -117,7 +122,6 @@ backend_test.exclude('test_affine_grid_3d_expanded_cpu')
backend_test.exclude('test_range_int32_type_negative_delta_expanded_cpu')

# unsupported (strange) ops
backend_test.exclude('test_bitwise_*')
backend_test.exclude('test_blackmanwindow_*')
backend_test.exclude('test_bernoulli_*')
backend_test.exclude('test_det_*')

@@ -54,6 +54,7 @@ def diff(offset:int, name:str, fxn:Callable) -> Union[Tuple[int, int], bool]:
# try recreate
try:
with Context(**{k:v for k,v in args[-2].items() if k in ContextVar._cache and k != "DEBUG"}): good = fxn(*args[:-2])
if good is None: continue
except Exception as e:
logging.warning(f"FAILED TO RECREATE KERNEL {e}")
for x in args[:-1]: logging.info(x)

4
test/external/speed_v_theoretical.py
vendored
@@ -87,7 +87,7 @@ class TestKernelSpeed(unittest.TestCase):

# NOTE: tiny7 was slower than tiny12
# TODO: why are convs so slow?!?
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=36, amd_tflops=24)
def test_conv_3x3_256_32_32_256_256(self): self._test_conv_3x3(256, 32, 32, 256, 256, nv_tflops=27, amd_tflops=24)

# theoretical is nv_tflops=165, amd_tflops=123
def test_gemm_4096(self): self._test_matmul(4096, nv_tflops=120, amd_tflops=80)
@@ -95,7 +95,7 @@ class TestKernelSpeed(unittest.TestCase):

# theoretical is nv_gbs=1008, amd_gbs=960
def test_gemv_16384_4096(self): self._test_matmul(16384, 4096, 1, nv_gbs=840, amd_gbs=750)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=780)
def test_gemv_4096_16384(self): self._test_matmul(4096, 16384, 1, nv_gbs=830, amd_gbs=760)

if __name__ == '__main__':
unittest.main()

@@ -22,7 +22,7 @@ def consec(shape, start=1):
def set_(reference: Tensor, shape, strides, offset):
if reference.lazydata.base.realized is None: reference.realize()
assert reference.lazydata.base.realized, "base has to be realized before setting it to strided's base"
strided = Tensor(reference.lazydata._view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
strided = Tensor(reference.lazydata.view(ShapeTracker((View.create(shape=shape, strides=strides, offset=offset),))))
assert strided.lazydata.st.real_strides() == strides, "real_strides should equal strides for strided"
return strided

@@ -1062,9 +1062,9 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(a[0, one], a[zero, 1])

# indexing by a scalar should slice (not copy)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int32)])
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])

# scalar indexed with scalar
r = Tensor.randn()
@@ -1105,6 +1105,20 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_allclose(9.9, r, rtol=1e-7)
'''

def test_getitem_casted_scalars_folding(self):
Tensor.manual_seed(0)
# cast of const is just another const, don't need extra kernels for this
a = Tensor.randn(2, 3)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int32)]))
self.assertEqual(data_ptr(a[1]), data_ptr(a[one.cast(dtypes.int16)]))

def test_getitem_scalars_simple_folding(self):
a = Tensor.randn(2, 3)
zero = Tensor(0, dtype=dtypes.int64)
one = Tensor(1, dtype=dtypes.int64)
self.assertEqual(data_ptr(a[0, 1]), data_ptr(a[zero, one]))

def test_basic_advanced_combined(self):
# From the NumPy indexing example
x = Tensor.arange(0, 12).reshape(4, 3)
@@ -1555,4 +1569,4 @@ class TestNumpy(unittest.TestCase):
'''

if __name__ == '__main__':
unittest.main()
unittest.main()

@@ -158,6 +158,37 @@ class TestReduceOpsConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum())
np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2)

def test_bool_zero_max(self):
_check_ast_count(0, Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)))
np.testing.assert_equal(Tensor.full((1, 2), True).shrink(((0, 1), (0, 0))).max((1, 0)).numpy(), False)

def test_zero_size_ops(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]: # lambda x:x.max() NOTE: numpy gives "reduction operation maximum which has no identity"
_check_ast_count(0, reduceop(Tensor.empty(1, 0)))
np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty(shape)))

def test_zero_size_ops_view(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
_check_ast_count(0, reduceop(Tensor.empty(1, 0, 4).permute((1, 2, 0)).contiguous()))
np.testing.assert_equal(reduceop(Tensor.empty(shape:=(1, 0))).numpy(), reduceop(np.empty((shape))))

def test_zero_size_ops_realized(self):
for reduceop in [lambda x:x.prod(), lambda x:x.sum()]:
_check_ast_count(0, reduceop((Tensor.randn(0, 1)+1).realize()))
np.testing.assert_equal(reduceop((Tensor.randn(shape:=(0, 1))+1).realize()).numpy(), reduceop(np.empty(shape)))

def test_zero_size_realize_folded(self):
# non contiguous folded output doesn't realize
_check_ast_count(0, Tensor.empty(1, 0).sum())
# contiguous folded const can still schedule
a = Tensor.empty(1, 0).sum().contiguous()
_check_ast_count(2, a+2)
self.assertIsNotNone(a.lazydata.base.realized)
np.testing.assert_equal((Tensor.empty(1, 0).sum().contiguous()+2).numpy(), 2)
# otherwise we just fuse it
_check_ast_count(1, (Tensor.empty(1, 0).sum()+2).contiguous())
np.testing.assert_equal((Tensor.empty(1, 0).sum()+2).numpy(), 2)

def test_const_prod(self):
_check_ast_count(0, Tensor.full((2, 3), fill_value=2).prod())
np.testing.assert_equal(Tensor.full((2, 3), fill_value=2).prod().numpy(), 2**(2*3))

@@ -3,7 +3,7 @@ import unittest
from tinygrad import Tensor, dtypes, Device
import operator
import numpy as np
from hypothesis import given, strategies as strat, settings
from hypothesis import given, strategies as strat, settings, HealthCheck
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.engine.schedule import create_schedule
@@ -86,7 +86,7 @@ def universal_test_unary(a, dtype, op):

def universal_test_cast(a, in_dtype, dtype):
tensor_value = Tensor([a], dtype=in_dtype).cast(dtype)
numpy_value = np.array([a]).astype(_to_np_dtype(dtype))
numpy_value = np.array([a], dtype=_to_np_dtype(in_dtype)).astype(_to_np_dtype(dtype))
np.testing.assert_equal(tensor_value.numpy(), numpy_value)

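# e.g. universal_test_cast(3.7, dtypes.float32, dtypes.int32) now builds the numpy
# reference as np.array([3.7], dtype=np.float32).astype(np.int32), matching the
# tinygrad side's in_dtype before the cast (example values ours)
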
def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
@@ -178,6 +178,28 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)

@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
float_strat = float_strat.filter(lambda x: 0 < x < dtypes.max(unsigned_dtype))
universal_test_cast(a.draw(float_strat), float_dtype, unsigned_dtype)

@settings(suppress_health_check=[HealthCheck.filter_too_much])
@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_overflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
overflow_strat = float_strat.filter(lambda x: x > dtypes.max(unsigned_dtype) and x <= dtypes.max(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)

@given(strat.data(), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned_underflow(self, a, float_dtype, unsigned_dtype):
if not is_dtype_supported(float_dtype, Device.DEFAULT): float_dtype = dtypes.float32
float_strat = {dtypes.float16: ht.float16, dtypes.float32: ht.float32, dtypes.float64: ht.float64}[float_dtype]
overflow_strat = float_strat.filter(lambda x: x < 0 and x >= dtypes.min(dtypes.int32))
universal_test_cast(a.draw(overflow_strat), float_dtype, unsigned_dtype)

@unittest.expectedFailure
def test_unsafe_cast_float_to_int_failure(self):
val = float(dtypes.max(dtypes.int32) - 1)

@@ -398,6 +398,14 @@ class TestJit(unittest.TestCase):
for i in range(5):
np.testing.assert_equal(g(Tensor([i]*3), Tensor.ones(3), Tensor.zeros(3)).numpy(), np.array([i+1]*3))

def test_jitted_clone(self):
def f(a): return a.clone().realize()
jf = TinyJit(f)
for _ in range(5):
a = Tensor.randn(10, 10, device=Device.DEFAULT).realize()
ja = jf(a)
np.testing.assert_allclose(a.numpy(), ja.numpy(), atol=1e-4, rtol=1e-5)

@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL", "NV", "AMD"}, "no GPU CI")
def test_jitted_transfers(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
@@ -483,5 +491,52 @@ class TestJitInsideJit(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "having TinyJit inside another TinyJit is not supported"):
g(Tensor([1])).realize()

class TestCopyInsideJit(unittest.TestCase):
def test_copy_inside_jit(self):
@TinyJit
def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
for _ in range(5):
# create a Tensor in CLANG
a = Tensor.rand(16,16,device="CLANG").realize()
b = Tensor.rand(16,16).realize()
out = add(a,b)
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])

class TestJitPrune(unittest.TestCase):
def test_simple_prune(self):
weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x
w2_noprune = TinyJit(w2)
w2_prune = TinyJit(w2, prune=True)

for _ in range(3):
a = Tensor.rand(16).realize()
out = w2_noprune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
assert len(w2_noprune.captured.jit_cache) == 2

for _ in range(3):
a = Tensor.rand(16).realize()
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
assert len(w2_prune.captured.jit_cache) == 1

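# the counts above are the point of prune: w2_noprune keeps both the
# (weights*2).contiguous() kernel and the add (2 cache entries), while prune=True
# drops the kernel whose inputs never change between calls, leaving 1 entry
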
def test_prune_w_copy_correct(self):
weights = Tensor.rand(16).realize()
def w2(x) -> Tensor: return (weights*2).contiguous() + x.to(Device.DEFAULT)
w2_noprune = TinyJit(w2)
w2_prune = TinyJit(w2, prune=True)

for _ in range(3):
a = Tensor.rand(16, device="CLANG").realize()
out = w2_noprune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])

for _ in range(3):
a = Tensor.rand(16, device="CLANG").realize()
out = w2_prune(a)
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])


if __name__ == '__main__':
unittest.main()

@@ -63,12 +63,12 @@ class TestLazyBuffer(unittest.TestCase):

def test_const_dtype(self):
lb: LazyBuffer = Tensor([1], dtype=dtypes.int).lazydata
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int
assert lb.const_like(1).const_arg == 1
assert type(lb.const_like(1).const_arg) is int

lb: LazyBuffer = Tensor([1], dtype=dtypes.float).lazydata
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
assert lb.const_like(1).const_arg == 1.0
assert type(lb.const_like(1).const_arg) is float

def test_forced_realized_alu(self):
a = Tensor.randn(2, 2).realize()

@@ -108,6 +108,25 @@ class TestLinearizer(unittest.TestCase):
if skip and i in skip: continue
assert ranges[i-1] != u, f"multireduce nested the ranges! {ranges[i-1], {u}}"

@unittest.expectedFailure
def test_const_alu_indexing(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)*UOp.const(dtypes.float, -1)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1*-1], opts=[])

def test_const_alu_indexing_one_const_fine(self):
st = ShapeTracker.from_shape((4,)).to_uop()
load = UOp.load(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), st, dtype=dtypes.float)
op = load+UOp.const(dtypes.float, 1.0)
store = UOp.store(UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), st, op)
Tensor.manual_seed(0)
x = Tensor.randn(4,).realize()
helper_linearizer_ast(store.sink(), [x], wanna_output=[x.numpy()+1], opts=[])

@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")

@@ -670,8 +670,8 @@ class TestMultiTensor(unittest.TestCase):
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
t.shard_(devices, axis=0)
assert all([lb is lb.base and lb.buffer.base.size == 4 * 16 for lb in t.lazydata.lbs])
t.shard_(devices, axis=0).realize()
assert all([lb is lb.base and lb.realized.base.size == 4 * 16 for lb in t.lazydata.lbs])

def test_clone(self):
t = Tensor.rand(16, 16).shard(devices_2, axis=None)

185
test/test_ops.py
@@ -36,7 +36,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
try:
assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}"
assert tinygrad_output.dtype == torch_output.dtype, f"dtype mismatch: tinygrad={tinygrad_output.dtype} | torch={torch_output.dtype}"
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
if np.issubdtype(tinygrad_output.dtype, np.floating):
np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol)
else:
np.testing.assert_equal(tinygrad_output, torch_output)
except Exception as e:
raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}")

@@ -71,6 +74,9 @@ def prepare_test_op(low, high, shps, vals, forward_only=False):
np.random.seed(0)
np_data = [np.random.uniform(low=low, high=high, size=size).astype(_to_np_dtype(dtypes.default_float)) for size in shps]
ts = [torch.tensor(data, requires_grad=(not forward_only)) for data in np_data]
for i in range(len(ts)):
# NOTE: torch default int64 for python ints input
if ts[i].dtype == torch.int64: ts[i] = ts[i].type(torch.int32)
tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts]
return ts, tst

@@ -312,8 +318,7 @@ class TestOps(unittest.TestCase):
def _test_cmp(self, fxn, reverse=True):
# test different dtypes
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
if is_dtype_supported(dtypes.long):
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
# test broadcasting
for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
@@ -500,11 +505,18 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
if is_dtype_supported(dtypes.uint64):
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
np.testing.assert_equal(x.numpy(), 2**64 - 1)
# 1 // 0 is device dependent, but it should not raise
Tensor([1]).idiv(0).realize()
if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
# ... because it might be in a where branch where the output is well defined
t = Tensor([-1, 0, 1, 2])
np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])

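# trunc rounding differs from python floor division on negatives (values ours):
# torch.div(-6, 4, rounding_mode="trunc") == Tensor([-6]).idiv(4) == -1, while -6 // 4 == -2
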
def test_scalar_div(self):
helper_test_op([(45,65)], lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1)
@@ -548,6 +560,7 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x**1.2, low=-30, high=-27)
a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True)
helper_test_op([], lambda: b**1.1, lambda: a**1.1)

def test_pow_const(self):
helper_test_op([(45,65)], lambda x: x**1.0)
helper_test_op([(45,65)], lambda x: x**-1.0)
@@ -561,6 +574,25 @@ class TestOps(unittest.TestCase):
# TODO: fix backward, should be nan
helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True)

def test_pow_int(self):
def _test(base, exponent): helper_test_op(None, lambda x,y: x**y, vals=[base, exponent], forward_only=True)

for base in ([1, 2, 3], [-1, -2, -3]):
for exponent in ([2, 3, 4], [-2, -3, -4]):
_test(base, exponent)
# NOTE: torch 0 ** -1 is 0
_test([0, 0, 0], [0, 1, 2])

np.testing.assert_equal((Tensor(11) ** Tensor(7)).item(), 11 ** 7)
np.testing.assert_equal((Tensor([11]) ** Tensor(7)).item(), 11 ** 7)
# TODO: fix non-precise int pow
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor(11) ** Tensor([7])).item(), 11 ** 7)
with self.assertRaises(AssertionError): np.testing.assert_equal((Tensor([11]) ** Tensor([7])).item(), 11 ** 7)

# pow to a const int
helper_test_op([], lambda: torch.tensor([2], dtype=torch.int) ** torch.tensor(-2, dtype=torch.int),
lambda: Tensor([2]) ** Tensor(-2), forward_only=True)

def test_sqrt(self):
helper_test_op([(45,65)], lambda x: x.sqrt())
helper_test_op([()], lambda x: x.sqrt())
@@ -784,11 +816,6 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish)
helper_test_op([()], torch.nn.functional.mish, Tensor.mish)

def test_multinomial(self):
# NOTE: this is random, so it has a very large atol
helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1).type(torch.int32),
lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.)

def test_small_cumsum(self):
helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0))
def test_simple_cumsum(self):
@@ -832,17 +859,27 @@ class TestOps(unittest.TestCase):

def test_argmax(self):
# check if it returns the first index for multiple occurrences
self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy())
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[2, 2]])
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[1, 2, 2]])
np.testing.assert_equal(Tensor([2,2]).argmax().numpy(), np.array(0))
np.testing.assert_equal(Tensor([1,2,2]).argmax().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(0, False).type(torch.int32), lambda x: x.argmax(0, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, False).type(torch.int32), lambda x: x.argmax(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmax(1, True).type(torch.int32), lambda x: x.argmax(1, True), forward_only=True)
# regression test for bitwise_not then argmax
helper_test_op(None, lambda x: (~x).argmax().type(torch.int32), lambda x: (~x).argmax(), forward_only=True, vals=[[2, 2]])

helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[-2**31, 0]])
# NOTE: torch does not support this on bool
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.int32).argmax().type(torch.int32), lambda x: x.argmax(), forward_only=True, vals=[[True, False]])

def test_argmin(self):
# check if it returns the first index for multiple occurrences
self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy())
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[2, 2]])
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[3, 2, 2]])
np.testing.assert_equal(Tensor([2,2]).argmin().numpy(), np.array(0))
np.testing.assert_equal(Tensor([3,2,2]).argmin().numpy(), np.array(1))
helper_test_op([(10,20)], lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True)
@@ -850,6 +887,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(10,20)], lambda x: x.argmin(1, False).type(torch.int32), lambda x: x.argmin(1, False), forward_only=True)
helper_test_op([(10,20)], lambda x: x.argmin(1, True).type(torch.int32), lambda x: x.argmin(1, True), forward_only=True)

helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[-2**31, 0]])
# NOTE: torch does not support this on bool
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.int32).argmin().type(torch.int32), lambda x: x.argmin(), forward_only=True, vals=[[True, False]])

def test_einsum(self):
# matrix transpose
helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a))
@@ -1085,11 +1128,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
helper_test_op([()], lambda x: x.min())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.min(), forward_only=True, vals=[[True, False]])

def test_max(self):
helper_test_op([(45,3)], lambda x: x.max())
@@ -1098,11 +1140,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
helper_test_op([()], lambda x: x.max())

if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[0, -2**31]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[-2**31, 0]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[False, True]])
helper_test_op(None, lambda x: x.max(), forward_only=True, vals=[[True, False]])

@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_any(self):
@@ -1198,17 +1239,18 @@ class TestOps(unittest.TestCase):
lambda x: Tensor.stack(*x.std_mean(correction=5)))
helper_test_op([(15,25,35)], lambda x: torch.stack(torch.std_mean(x, keepdim=True, correction=0)),
lambda x: Tensor.stack(*x.std_mean(keepdim=True, correction=0)))
helper_test_op([(1,0,3,0,5)], lambda x: torch.stack(torch.std_mean(x, axis=(1,3))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,3))))
helper_test_op([(3,4,5,6)], lambda x: torch.stack(torch.std_mean(x, axis=(1,2))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,2))))

@unittest.skip("TODO: this fails because of loaded nan in mul folding")
def test_std_mean_loaded_nan(self):
helper_test_op([(1,0,3,0,5)], lambda x: torch.stack(torch.std_mean(x, axis=(1,3))),
lambda x: Tensor.stack(*x.std_mean(axis=(1,3))))
def test_softmax(self):
# exceed per kernel buffer limit with backward
forward_only = (Device.DEFAULT == "WEBGPU")
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only)
helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7)
def test_softmax_other_axis(self):
helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
@@ -2033,6 +2075,20 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation),
lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation))

def test_max_pool2d_ceil_mode(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True),
lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True))

def test_max_pool2d_ceil_mode_output_size_reduce_by_one(self):
# sliding window ignored from end region
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
lambda x: Tensor.max_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))

def test_avg_pool2d(self):
shape = (32,2,111,28)
for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]:
@@ -2062,11 +2118,53 @@ class TestOps(unittest.TestCase):
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5)

def test_avg_pool2d_ceil_mode(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=False), rtol=1e-5)

def test_avg_pool2d_ceil_mode_output_size_reduce_by_one(self):
# sliding window ignored from end region
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True))

def test_avg_pool2d_ceil_mode_include_pad(self):
shape = (1,1,6,6)
for ksz in [(3,3), 3, (3,2), 4]:
with self.subTest(kernel_size=ksz):
helper_test_op([shape],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, stride=3, ceil_mode=True, count_include_pad=True), rtol=1e-5)

def test_avg_pool2d_ceil_mode_include_pad_output_size_reduce_by_one(self):
# sliding window ignored from end region
helper_test_op([(1,1,5,5)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True),
lambda x: Tensor.avg_pool2d(x, kernel_size=(3,3), stride=3, padding=1, ceil_mode=True, count_include_pad=True))

def test_global_avg_pool2d(self):
helper_test_op([(32,2,111,28)],
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)),
lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5)

# TODO: linearizer block error
@unittest.expectedFailure
def test_avg_pool3d_failure(self):
with Context(NOOPT=0):
helper_test_op([(1,1,16,16,16)],
lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), rtol=1e-5, forward_only=True)

def test_avg_pool3d_noopt(self):
with Context(NOOPT=1):
helper_test_op([(1,1,16,16,16)],
lambda x: torch.nn.functional.avg_pool3d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False),
lambda x: Tensor.avg_pool2d(x, kernel_size=(8,8,8), stride=5, padding=1, count_include_pad=False), rtol=1e-5, forward_only=True)

def test_interpolate_linear(self):
for in_sz, out_sz in [((52,),(29,)), ((29,),(52,))]:
helper_test_op([(2,3)+in_sz],
@@ -2191,7 +2289,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(1,128), (128,128)], lambda x,y: (x@y).relu())

@unittest.skip("this test is broken #862")
def test_max_inf(self):
def test_max_nan(self):
n = Tensor([1, float("nan")]).max().numpy()
assert math.isnan(n.item()), f"{n.item()} is not nan"

@@ -2429,45 +2527,39 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
lambda x,y: x.cross_entropy(y, label_smoothing=ls))

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss(self):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d(self):
helper_test_op([(32,10,3,3,3), (32,3,3,3)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32)), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_reductions(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long), reduction=r),
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), reduction=r), forward_only=True)
lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), reduction=r), forward_only=True)
self.helper_test_exception([(32,10), (32)],
lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.int32), reduction="typo"), expected=ValueError)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10), (32), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_3d_weight(self):
for r in ("mean", "sum", "none"):
helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
lambda x,y,z: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long),
weight=z, reduction=r),
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.int32), weight=z, reduction=r), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_nll_loss_ignore_index(self):
logits = [[2.0, 0.5, -1.0],
[1.5, 2.5, -0.5],
@@ -2475,7 +2567,7 @@ class TestOps(unittest.TestCase):
targets = [0, 1, 2]
helper_test_op(None, lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1),
torch.clip(y,0).type(torch.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0).cast(dtypes.long), ignore_index=1),
lambda x,y: x.log_softmax().nll_loss(y.clip(0), ignore_index=1),
forward_only=True, vals=[logits, targets])

def test_one_hot(self):
@@ -2497,8 +2589,7 @@ class TestOps(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
def test_cast(self):
helper_test_op([(3, 3)], lambda x: x.float())
if is_dtype_supported(dtypes.long):
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
@@ -2508,7 +2599,6 @@

@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
@unittest.skip('this is broken for negative numbers')
def test_cast(self):
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)

@@ -2533,7 +2623,6 @@ class TestOpsUint8(unittest.TestCase):
lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)

@unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
def test_min(self):
helper_test_op(None,
lambda x: x.type(torch.uint8).min(),

@@ -2,6 +2,7 @@ import unittest, pickle, types
import numpy as np
from tinygrad import Tensor, TinyJit, Variable, dtypes
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import GlobalCounters
from tinygrad.ops import PatternMatcher, UPat, UOp

class TestPickle(unittest.TestCase):
@@ -24,10 +25,29 @@ class TestPickle(unittest.TestCase):
pickle.dumps(sym)

def test_pickle_realized_tensor(self):
print("** init")
t = Tensor.rand(10, 10).realize()
st = pickle.dumps(t)
t_values = t.numpy()
del t # free buffers
print("** post pickle")
init = GlobalCounters.kernel_count
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
np.testing.assert_equal(t_values, t2.numpy())
# expect at most one COPY kernel
self.assertLessEqual(GlobalCounters.kernel_count-init, 1)

def test_pickle_realized_tensor_alt(self):
|
||||
print("** init")
|
||||
t = Tensor.rand(10, 10).to("CLANG").realize()
|
||||
st = pickle.dumps(t)
|
||||
t_values = t.numpy()
|
||||
del t # free buffers
|
||||
print("** post pickle")
|
||||
init = GlobalCounters.kernel_count
|
||||
t2:Tensor = pickle.loads(st)
|
||||
np.testing.assert_equal(t_values, t2.numpy())
|
||||
self.assertEqual(GlobalCounters.kernel_count-init, 0)
|
||||
|
||||
def test_pickle_unrealized_tensor(self):
|
||||
t = Tensor.ones(10, 10)
|
||||
|
||||
@@ -13,12 +13,12 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites, view_supported_devices
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, do_realize
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
@@ -1344,6 +1344,7 @@ class TestSchedule(unittest.TestCase):
|
||||
Tensor.ones(5, 5).contiguous().schedule()
|
||||
self.assertEqual(GlobalCounters.mem_used-base, 0)
|
||||
|
||||
@unittest.skip("TODO: this is consistently creating non reproducible failures")
|
||||
def test_schedule_mem_used_with_inputs(self):
|
||||
base = GlobalCounters.mem_used
|
||||
x = Tensor.ones(256).contiguous().realize()
|
||||
@@ -1924,7 +1925,7 @@ class TestView(unittest.TestCase):
|
||||
a = UOp(Ops.VIEW, dtypes.float, (UOp.new_buffer(Device.DEFAULT, 121, dtypes.float), UOp(Ops.EMPTY, dtypes.float)), st)
|
||||
b = a.pad(pad_arg:=((0, 0), (0, 0), (18, 0)))
|
||||
self.assertEqual(b.st, st.pad(pad_arg))
|
||||
self.assertIs(b, b.const_like(0))
|
||||
self.assertIs(b.base.src[1], UOp.const(dtypes.float, 0))
|
||||
|
||||
def test_partial_mask(self):
|
||||
# partial masked out does not degrade into CONST
|
||||
@@ -1955,8 +1956,8 @@ class TestBigGraph(unittest.TestCase):
|
||||
|
||||
def test_sink_childless_const_alt_expanded(self):
|
||||
# this is a real STORE of CONST (post expand)
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
out = UOp(Ops.VIEW, dtypes.int, (UOp(Ops.BUFFER, dtypes.int.ptr(), (), 0), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
|
||||
y = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), UOp.const(dtypes.int, 0)), ShapeTracker.from_shape(()))
|
||||
out = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 2, dtypes.int), y.reshape((1,)).expand((2,)).contiguous(),), ShapeTracker.from_shape((2,)))
|
||||
big_graph = big_graph_rewrite(out.sink(), realizes:={})
|
||||
self.assertIs(big_graph, out.sink())
|
||||
self.assertEqual(len(realizes), 1)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.engine.lazy import view_supported_devices
|
||||
from tinygrad.ops import view_supported_devices
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
|
||||
class TestSubBuffer(unittest.TestCase):
|
||||
|
||||
@@ -573,7 +573,7 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
a = t.reshape(0)
|
||||
assert a.shape == (0,)
|
||||
np.testing.assert_equal(a.numpy(), np.zeros((0,)))
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
# cannot reshape from size 0 to size 1
|
||||
a = t.reshape(())
|
||||
|
||||
|
||||
@@ -405,6 +405,14 @@ class TestUOpMethod(unittest.TestCase):
|
||||
self.assertEqual(const._device, None)
|
||||
with self.assertRaises(AssertionError): const.device
|
||||
|
||||
def test_const_arg(self):
|
||||
var = UOp.variable("a", 1, 10)
|
||||
with self.assertRaises(AssertionError): UOp.const(dtypes.int, var).const_arg
|
||||
const = UOp.const(dtypes.int, 1)
|
||||
self.assertEqual(const.const_arg, 1)
|
||||
tensor_const = UOp(Ops.VIEW, dtypes.int, (UOp.new_buffer(Device.DEFAULT, 1, dtypes.int), const), ShapeTracker.from_shape(()))
|
||||
self.assertEqual(tensor_const.const_arg, 1)
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
def test_uop_str(self):
|
||||
a = UOp(Ops.CONST, dtypes.float, (), 2.0) + UOp(Ops.CONST, dtypes.float, (), 3.0)
|
||||
|
||||
@@ -13,7 +13,7 @@ def time_tensor_numpy(out:Tensor):
|
||||
|
||||
N = 4096
|
||||
class TestZeroCopy(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "CPU", "METAL"}, "device isn't zero copy")
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "METAL"}, "device isn't zero copy")
|
||||
def test_zero_copy_from_default_to_cpu(self):
|
||||
demo = Tensor.rand(1).realize()
|
||||
t1 = time_tensor_numpy(demo)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
from extra.export_model import export_model, EXPORT_SUPPORTED_DEVICE
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad import dtypes
|
||||
import json
|
||||
|
||||
class MockMultiInputModel:
|
||||
@@ -45,6 +46,24 @@ class TextModelExport(unittest.TestCase):
|
||||
for i, exported_output in enumerate(prg["outputs"]):
|
||||
assert outputs[i].dtype.name == exported_output["dtype"], f"Model and exported output dtype don't match: mdl={outputs[i].dtype.name}, prg={exported_output['dtype']}" # noqa: E501
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "WEBGPU", "Testing WebGPU specific model export behavior")
|
||||
class TextModelExportWebGPU(unittest.TestCase):
|
||||
def test_exported_input_output_dtypes(self):
|
||||
class MyModel:
|
||||
def forward(self, *inputs): return tuple([(inp+2).cast(inp.dtype) for inp in inputs])
|
||||
model = MyModel()
|
||||
# [:-1] because "ulong" and "long" are not supported
inputs = [Tensor.randn(2, dtype=dt) for dt in dtypes.uints[:-1] + dtypes.sints[:-1] + (dtypes.bool, dtypes.float)]
prg, _, _, _ = export_model(model, "webgpu", *inputs)
expected_buffer_types = ["Uint"]*len(dtypes.uints[:-1]) + ["Int"]*len(dtypes.sints[:-1]) + ["Int", "Float"]
for i, expected_buffer_type in enumerate(expected_buffer_types):
dt = inputs[i].dtype
expected_arr_prefix = f"{expected_buffer_type}{dt.itemsize*8}"
# test input buffers
self.assertIn(f"new {expected_arr_prefix}Array(gpuWriteBuffer{i}.getMappedRange()).set(_input{i});", prg)
# test output buffers
self.assertIn(f"const resultBuffer{i} = new {expected_arr_prefix}Array(gpuReadBuffer{i}.size/{dt.itemsize});", prg)
self.assertIn(f"resultBuffer{i}.set(new {expected_arr_prefix}Array(gpuReadBuffer{i}.getMappedRange()));", prg)

if __name__ == '__main__':
unittest.main()

15
test/testextra/test_f16_decompress.py
Normal file
@@ -0,0 +1,15 @@
import unittest
from extra.f16_decompress import u32_to_f16
from tinygrad.tensor import Tensor
from tinygrad.device import Device, is_dtype_supported
from tinygrad import dtypes
import numpy as np

class TestF16Decompression(unittest.TestCase):
def test_u32_to_f16(self):
a = Tensor.randn(50, dtype=dtypes.float16, device=None if is_dtype_supported(dtypes.float16) else "CLANG:0")
f16_as_u32 = a.bitcast(dtypes.uint32) if is_dtype_supported(dtypes.float16) else a.bitcast(dtypes.uint32).to(Device.DEFAULT)
f16 = u32_to_f16(f16_as_u32)
ref = a.numpy()
out = f16.numpy().astype(np.float16)
np.testing.assert_allclose(out, ref)
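The test packs 50 halves into 25 uint32 words via bitcast, so u32_to_f16 has to recover both 16-bit patterns from each word, presumably low half first since the bitcast is little-endian. A NumPy sketch of that unpacking step (illustration only, not the extra/f16_decompress code):

```python
import numpy as np

words = np.array([0x40003C00], dtype=np.uint32)   # low half 0x3C00 = 1.0, high half 0x4000 = 2.0
lo = (words & 0xFFFF).astype(np.uint16)           # low 16 bits -> first half
hi = (words >> 16).astype(np.uint16)              # high 16 bits -> second half
halves = np.stack([lo, hi], axis=-1).view(np.float16).ravel()
print(halves)  # [1. 2.]
```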

@@ -134,15 +134,23 @@ class TestSafetensors(unittest.TestCase):
for k in f.keys():
np.testing.assert_array_equal(f.get_tensor(k).numpy(), state_dict[k].numpy())

def test_huggingface_enet_safetensors(self):
# test a real file
fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
def _test_huggingface_enet_safetensors(self, fn):
state_dict = safe_load(fn)
assert len(state_dict.keys()) == 244
assert 'blocks.2.2.se.conv_reduce.weight' in state_dict
assert state_dict['blocks.0.0.bn1.num_batches_tracked'].numpy() == 276570
assert state_dict['blocks.2.0.bn2.num_batches_tracked'].numpy() == 276570

def test_huggingface_enet_safetensors(self):
# test a real file
fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
self._test_huggingface_enet_safetensors(fn)

def test_huggingface_enet_safetensors_fromurl(self):
# test tensor input
t = Tensor.from_url("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors")
self._test_huggingface_enet_safetensors(t)

def test_metadata(self):
metadata = {"hello": "world"}
safe_save({}, temp('metadata.safetensors'), metadata)
@@ -353,10 +361,10 @@ class TestPathTensor(unittest.TestCase):
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))

def test_path_tensor_with_device(self):
t = Tensor(self.test_file, device="CPU")
t = Tensor(self.test_file, device="CLANG")
self.assertEqual(t.shape, (100,))
self.assertEqual(t.dtype, dtypes.uint8)
self.assertEqual(t.device, "CPU")
self.assertEqual(t.device, "CLANG")
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))

def test_path_tensor_empty_file(self):
@@ -381,8 +389,8 @@ class TestPathTensor(unittest.TestCase):

def test_path_tensor_copy_to_device(self):
t = Tensor(self.test_file)
t_cpu = t.to("CPU")
self.assertEqual(t_cpu.device, "CPU")
t_cpu = t.to("CLANG")
self.assertEqual(t_cpu.device, "CLANG")
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))

if __name__ == "__main__":

@@ -1,8 +1,66 @@
import unittest, tarfile, io, os, pathlib
import unittest, tarfile, io, os, pathlib, tempfile
import numpy as np
from tinygrad import Tensor
from tinygrad.nn.state import tar_extract

class TestTarExtractFile(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.mkdtemp()
self.test_files = {
'file1.txt': b'Hello, World!',
'file2.bin': b'\x00\x01\x02\x03\x04',
'empty_file.txt': b''
}
self.tar_path = os.path.join(self.test_dir, 'test.tar')
with tarfile.open(self.tar_path, 'w') as tar:
for filename, content in self.test_files.items():
file_path = os.path.join(self.test_dir, filename)
with open(file_path, 'wb') as f:
f.write(content)
tar.add(file_path, arcname=filename)

# Create invalid tar file
self.invalid_tar_path = os.path.join(self.test_dir, 'invalid.tar')
with open(self.invalid_tar_path, 'wb') as f:
f.write(b'This is not a valid tar file')

def tearDown(self):
for filename in self.test_files:
os.remove(os.path.join(self.test_dir, filename))
os.remove(self.tar_path)
os.remove(self.invalid_tar_path)
os.rmdir(self.test_dir)

def test_tar_extract_returns_dict(self):
result = tar_extract(self.tar_path)
self.assertIsInstance(result, dict)

def test_tar_extract_correct_keys(self):
result = tar_extract(self.tar_path)
self.assertEqual(set(result.keys()), set(self.test_files.keys()))

def test_tar_extract_content_size(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
self.assertEqual(len(result[filename]), len(content))

def test_tar_extract_content_values(self):
result = tar_extract(self.tar_path)
for filename, content in self.test_files.items():
np.testing.assert_array_equal(result[filename].numpy(), np.frombuffer(content, dtype=np.uint8))

def test_tar_extract_empty_file(self):
result = tar_extract(self.tar_path)
self.assertEqual(len(result['empty_file.txt']), 0)

def test_tar_extract_non_existent_file(self):
with self.assertRaises(FileNotFoundError):
tar_extract('non_existent_file.tar')

def test_tar_extract_invalid_file(self):
with self.assertRaises(tarfile.ReadError):
tar_extract(self.invalid_tar_path)

class TestTarExtractPAX(unittest.TestCase):
tar_format = tarfile.PAX_FORMAT
max_link_len = 1000_000

@@ -444,6 +444,8 @@ class TestSymbolic(unittest.TestCase):

def test_div_mod_recombine_with_gcd(self):
b = Variable("b", 0, 100)
exp = (16 * b + 2) % 18 + ((16 * b + 2) // 18) * 18
self.helper_test_variable(exp, 2, 1602, "((b*16)+2)")
with self.assertRaises(AssertionError):
self.helper_test_variable((30 * b + 1) % 18 + ((30 * b + 1) // 18) * 18, 1, 3001, "((b*30)+1)")
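The identity being exercised: for integer a and n > 0, a % n + (a // n) * n == a, so (16*b+2) % 18 + ((16*b+2) // 18) * 18 folds back to (b*16)+2. The second expression is numerically equal too, but the simplifier only recombines when the gcd structure lets it prove the fold symbolically. A quick numeric check of both cases:

```python
for b in range(101):
    a = 16 * b + 2
    assert a % 18 + (a // 18) * 18 == a   # recombined to (b*16)+2 by the simplifier
    a = 30 * b + 1
    assert a % 18 + (a // 18) * 18 == a   # equal numerically, but not simplified symbolically
```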

@@ -525,6 +527,15 @@ class TestSymbolic(unittest.TestCase):
# not combining # TODO: can combine if one is identity element const
self.helper_test_variable(aa+ab, 0, 6, "((a if (x<2) else b)+(a if (x<2) else 0))")

def test_symbolic_div(self):
# from symbolic arange
a = Variable("a", 1, 10)
denominator = ((a*-2)+1)
numerator = (((((a*2)+-1)*2)+1)*a)
self.helper_test_variable(denominator, -19, -1, "((a*-2)+1)")
self.helper_test_variable(numerator, 3, 390, "(a*((a*4)+-1))")
self.helper_test_variable((numerator//denominator)<=0, 1, 1, "True")

class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
MIN, MAX = 0, 10

@@ -32,6 +32,11 @@ class TestVminVmaxProperties(unittest.TestCase):
self.assertEqual(uop.vmin, -6)
self.assertEqual(uop.vmax, 8)

def test_vmin_vmax_variable_inside_special(self):
uop = UOp(Ops.SPECIAL, dtypes.int, arg=('gidx0', UOp(Ops.DEFINE_VAR, dtypes.int, arg=('i', 1, 10))))
self.assertEqual(uop.vmin, 0)
self.assertEqual(uop.vmax, 10)

def test_vmin_vmax_multiplication_0_inf(self):
# vmin and vmax for multiplication with a variable
x = UOp.const(dtypes.float, 0.0)

@@ -88,5 +88,11 @@ class TestVerifyAST(unittest.TestCase):
st = UOp.store(buf, ShapeTracker.from_shape((32, 1)).to_uop(), r.view(r.st.expand((32, 1)))+a)
with self.assertRaisesRegex(InvalidASTException, "swizzle"): helper_test_verify_ast(st)

def test_flat_const_always_valid(self):
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
a = UOp.const(dtypes.int, 0).cast(dtypes.float)
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a)
helper_test_verify_ast(st)

if __name__ == '__main__':
unittest.main()

@@ -47,7 +47,7 @@ async function runTest() {
console.log(`error from page ${message}`),
);

const res = await page.goto("http://localhost:8000/examples/index.html");
const res = await page.goto("http://localhost:8000/examples/webgpu/efficientnet/index.html");
if (res.status() !== 200) throw new Error("Failed to load page");

const textSelector = await page.waitForSelector("#result");

@@ -63,9 +63,7 @@ class Kernel:
print(self.ast)
raise e

@functools.lru_cache(None)
def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is Ops.REDUCE_AXIS])
self.reduceops = [x for x in self.ast.toposort if x.op is Ops.REDUCE_AXIS]

self.vars: List[Variable] = self.ast.variables()
# NOTE: this requires a specific order with the [::-1], this is likely a bug
@@ -735,7 +733,8 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
st = uop.arg
# everything else inherits shape
else:
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
if len(src_sts:=[sts[x] for x in uop.src if x in sts]) == 0: return None
st = src_sts[0]
if not all_same(shapes:=[x.shape for x in src_sts]):
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
raise AssertionError(f"found implicit expand {sizes} {shapes}")

@@ -124,17 +124,15 @@ powers_of_two = {2**i:i for i in range(64)}
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
# rewrite MOD to AND (which should always be supported, but not for generic in tests)
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))),
lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
# rewrite MUL/IDIV to SHL+SHR: x*(2**y) -> shl(x,y) and x//(2**y) -> shr(x,y)
if Ops.SHL in ops and Ops.SHR in ops:
pat += [
(UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const:
mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y)
(UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const:
div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y)
(UPat.var("x", dtypes.ints)*UPat.cvar("c"), lambda c,x: x << powers_of_two[c.arg] if c.arg in powers_of_two else None),
(UPat.var("x", dtypes.ints)//UPat.cvar("c"), lambda x,c: x >> powers_of_two[c.arg] if c.arg in powers_of_two else None)
]
if Ops.NEG in ops:
pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))]
if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))]
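These strength-reduction rewrites rely on standard power-of-two identities: for a constant c = 2**y, x % c == x & (c-1), x * c == x << y, and x // c == x >> y (for non-negative x). A quick check of all three:

```python
powers_of_two = {2**i: i for i in range(64)}
for x in range(0, 1000, 37):
    for c in (2, 8, 64):
        y = powers_of_two[c]
        assert x % c == x & (c - 1)   # MOD -> AND
        assert x * c == x << y        # MUL -> SHL
        assert x // c == x >> y       # IDIV -> SHR
```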

@@ -241,8 +239,6 @@ arange_m = ((arange_augrng<UPat.cvar("compval"))!=UPat(Ops.CONST, name="ne", arg
sym = symbolic_flat+PatternMatcher([
# self ASSIGN is just self
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
# ASSIGN to global is just self
(UPat(Ops.ASSIGN, src=(UPat(Ops.DEFINE_GLOBAL), UPat.var("x"))), lambda x: x),
# VECTORIZE/CONST, VECTORIZE/GEP
(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST), name="vec"), lambda vec: UOp.const(vec.dtype, tuple(x.arg for x in vec.src))),
(UPat(Ops.VECTORIZE, src=UPat(Ops.GEP, src=(UPat(name="x"),)), name="vec"), lambda vec,x: x.gep(tuple(y.arg[0] for y in vec.src))),
@@ -288,9 +284,8 @@ sym = symbolic_flat+PatternMatcher([
# indexing, with cast or where
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).cast()*index_load+acc_pat), index_collapse),
(acc_pat.assign(UPat.var("idx").eq(UPat(Ops.RANGE, name="rng")).where(index_load, UPat.const(None, 0.0))+acc_pat), index_collapse),
# parentless reduce
(acc_pat.assign(UPat(Ops.ADD, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
(acc_pat.assign(UPat(Ops.MAX, src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# parentless reduce # TODO: add MUL
(acc_pat.assign(UPat((Ops.ADD, Ops.MAX), src=[acc_pat, UPat.var("ret")], name="alu")), reduce_collapse),
# ** self folding **
(UPat(Ops.DEFINE_ACC, src=(UPat.var("x"),)), lambda x: x), # a DEFINE_ACC without ranges is a CONST
(UPat(Ops.ASSIGN, src=(UPat.cvar(),UPat.var("x"))), lambda x: x), # an ASSIGN to a const is a NOOP

@@ -43,6 +43,7 @@ Device = _Device()

# **************** Buffer + Allocators ****************

@dataclass(frozen=True, eq=True)
class BufferSpec:
# TODO: move device, size, dtype here?

@@ -52,7 +52,7 @@ class PtrDType(DType):
def vec(self, sz:int) -> DType:
assert self.v == 1, f"can't vectorize ptr {self} with size {sz}"
if sz == 1: return self # sz=1 is a scalar
return type(self)(*tuple(sz if f.name == 'v' else (self if f.name == '_scalar' else getattr(self, f.name)) for f in fields(self)))
return type(self)(self.priority, self.itemsize, self.name, self.fmt, self.count, self, self._base, self.local, sz)
def ptr(self, local=False): raise RuntimeError("can't make a pointer from a pointer")
@property
def vcount(self): return self.v
@@ -99,7 +99,7 @@ class dtypes:
@staticmethod
@functools.lru_cache(None)
def max(dtype:DType):
if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
if dtypes.is_int(dtype): return 2**(dtype.itemsize*8)-1+dtypes.min(dtype)
return float("inf") if dtypes.is_float(dtype) else True
@staticmethod
def finfo(dtype:DType) -> Tuple[int, int]:
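Both formulas for the integer max agree: with b = itemsize*8 bits, an unsigned type has min 0 and max 2**b - 1, and a signed type has min -(2**(b-1)) and max 2**(b-1) - 1, which equals 2**b - 1 + min. The rewrite just expresses max in terms of min. A quick check:

```python
for bits, mn in ((8, 0), (8, -128), (16, 0), (16, -32768)):
    old = (2**(bits - (0 if mn == 0 else 1))) - 1   # the previous formula
    new = 2**bits - 1 + mn                          # the rewritten one
    assert old == new  # uint8 -> 255, int8 -> 127, uint16 -> 65535, int16 -> 32767
```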

@@ -8,7 +8,7 @@ from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferCopy, BufferXfer, CompiledRunner, Runner
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
@@ -268,6 +268,8 @@ class TinyJit(Generic[ReturnType]):
if any(b in depends for b in ei.bufs):
if isinstance(ei.prg, CompiledRunner):
depends.update(cast(Buffer, ei.bufs[out]) for out in ei.prg.p.outs)
if isinstance(ei.prg, (BufferCopy, BufferXfer)):
depends.add(cast(Buffer, ei.bufs[0]))
pruned, onetime = partition(jit_cache,
lambda ei: not isinstance(ei.prg, CompiledRunner) or any(ei.bufs[out] in depends for out in ei.prg.p.outs))
if DEBUG >= 1: print(f"pruned from {len(jit_cache)} -> {len(pruned)} kernels")

@@ -3,7 +3,7 @@ from typing import Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE
from tinygrad.ops import exec_alu, python_alu
from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops
from tinygrad.ops import MathTrait, resolve, UOp, sint, GroupOp, Ops, view_supported_devices
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from weakref import ref, ReferenceType, WeakValueDictionary
@@ -21,7 +21,6 @@ def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]
if enable_cache: lazycache[cache_key] = ret
return ret

view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
class LazyBuffer(MathTrait):
def __init__(self, device:str, st:ShapeTracker, dtype:DType,
op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
@@ -111,16 +110,20 @@ class LazyBuffer(MathTrait):
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# TODO: applying this makes gpt2 slower
return self.base.cast(dtype, bitcast)._view(self.st)
return self.base.cast(dtype, bitcast).view(self.st)
cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, None, (self,))

def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
@property
def const_arg(self) -> ConstType:
assert self.base.op is Ops.CONST and isinstance(self.base.arg, get_args(ConstType)), f"const_arg called on {self}"
return self.base.arg

def _copy(self, device:str) -> LazyBuffer:
assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}"
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False)
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, srcs=(self,), enable_cache=False)

def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer:
# no COPY
@@ -132,13 +135,13 @@ class LazyBuffer(MathTrait):

# const doesn't have to be copied (issues with disk tensor)
if self.is_unrealized_const():
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg).view(self.st)

# if it's a shrink, do the shrink before the copy with CONTIGUOUS
if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device)

# copy the base and apply the shapetracker on the new device
return self.base._copy(device)._view(self.st)
return self.base._copy(device).view(self.st)

def clone(self) -> LazyBuffer: return self.copy_to_device(self.device, clone=True)

@@ -146,7 +149,7 @@ class LazyBuffer(MathTrait):
srcs: List[LazyBuffer] = []
for s in (self,)+in_srcs:
if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
srcs.append(root._view(s.base.contiguous_child[1]))
srcs.append(root.view(s.base.contiguous_child[1]))
else:
srcs.append(s)
if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]):
@@ -181,15 +184,6 @@ class LazyBuffer(MathTrait):

def r(self, op:Ops, axis:Tuple[int, ...]) -> LazyBuffer:
new_shape = self.st.reduce(axis)
# TODO: this logic should move to the scheduler
if 0 in self.shape and 0 not in new_shape: return self.const_with_shape(identity_element(op, self.dtype), new_shape)

# const folding
# TODO: fold this for symbolic?
if self.is_unrealized_unmasked_const() and all_int(self.shape):
if op is Ops.ADD: return self.const_with_shape(self.base.arg * prod(self.shape[i] for i in axis), new_shape)
if op is Ops.MUL: return self.const_with_shape(self.base.arg ** prod(self.shape[i] for i in axis), new_shape)
if op is Ops.MAX: return self.const_with_shape(self.base.arg, new_shape)

# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
@@ -213,15 +207,15 @@ class LazyBuffer(MathTrait):

# *** movement ops ***

def _view(self, new_st:ShapeTracker) -> LazyBuffer:
def view(self, new_st:ShapeTracker) -> LazyBuffer:
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)):
return self.const_with_shape(0, new_st.shape)
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)

def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
def reshape(self, arg:Tuple[sint, ...]): return self.view(self.st.reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.pad(arg))
def expand(self, arg:Tuple[sint, ...]): return self.view(self.st.expand(arg))
def permute(self, arg:Tuple[int, ...]): return self.view(self.st.permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(self.st.shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self.view(self.st.stride(arg))

@@ -186,12 +186,12 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
if si.ast.op is Ops.SINK:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x] for x in runner.p.globals], si.metadata)
out, arg = si.outputs[0], si.ast.arg
out = si.outputs[0]
if si.ast.op is Ops.COPY:
kernel_type = BufferCopy
if hasattr(Device[out.device].allocator, '_transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
kernel_type = BufferXfer
return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
return ExecItem(kernel_type(out.nbytes, out.device, si.inputs[0].device), list(si.bufs))
if si.ast.op is Ops.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
if si.ast.op is Ops.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
raise RuntimeError(f"don't know how to lower {si.ast}")

@@ -3,6 +3,7 @@ from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import FrozenSet, Set, Tuple, List, Dict, Optional, DefaultDict
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, view_left, merge_views
from tinygrad.ops import identity_element
from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG
from tinygrad.dtype import ConstType, ImageDType, dtypes
@@ -51,6 +52,7 @@ def is_scheduled(u:UOp) -> bool: return u.op is Ops.VIEW and len(u.src) == 2

def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache:Dict[LazyBuffer, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# view is passthrough
if buf is not buf.base:
cache[buf] = ret = to_uop(buf.base, ctx, buffers, cache).view(buf.st)
return ret
@@ -64,25 +66,24 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, buffers:Dict[UOp, Buffer], cache
# hack the underlying buffer too
buf.buffer.dtype = dtype
buf.buffer.options = None
if buf.is_realized:
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = None
elif buf.op is Ops.ASSIGN:
target, new_val = [to_uop(x, ctx, buffers, cache) for x in buf.srcs]
ctx.assigns.add(ubuf:=target.base.buf_uop)
op = UOp(Ops.ASSIGN, dtype.base, (ubuf, new_val), buf.arg)
# base is a VIEW of (BUFFER, (optional) op)
# TODO: this is the same underlying Buffer in all schedules, delete_lazy fixes this
if buf.is_realized: ret = UOp.new_buffer(buf.device, buf.size, dtype).view(buf.st)
# ASSIGN uses the target buffer, otherwise we create a new buffer
else:
ubuf = UOp.new_buffer(buf.device, buf.size, dtype)
buffers[ubuf] = buf.buffer
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs), buf.arg)
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st)
if op is not None:
src = tuple(to_uop(x, ctx, buffers, cache) for x in buf.srcs)
buf_uop = src[0].base.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
op = UOp(buf.op, dtype if buf.op in GroupOp.Meta else dtype.base, src, buf.arg)
ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op.alu(Ops.CONTIGUOUS) if buf.forced_realize else op), buf.st)
# keep track of scheduled ops
buf.buffer.ref(1)
ctx.lazybufs[ubuf] = buf
ctx.allbufs[ubuf] = ret
ctx.lazybufs[buf_uop] = buf
ctx.allbufs[buf_uop] = ret
if op.op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[ubuf] = None
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
cache[buf] = ret
buffers[ret.buf_uop] = buf.buffer
return ret

# **** AST graph rewrite
@@ -106,23 +107,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))

def push_swizzle_down_through_reduce(r:UOp, v:UOp, src:UOp) -> UOp:
swizzle_st, src_st = unwrap(v.st), unwrap(src.st)
assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
output_shape = swizzle_st.reduce(r.axis_arg)
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
return src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))

def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
if not (swizzles := [x for x in root.src if x.base is not x]): return None
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
new_shape, new_input_shape = swizzle_shapes[0]
new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
ret = root.replace(src=new_src)
assert all_same([(x.shape, prod(x.src[0].shape)) for x in swizzles]), f"swizzles must have the same size {swizzles}"
new_input_st = ShapeTracker.from_shape(swizzles[0].src[0].shape)
ret = root.replace(src=tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, new_input_st) for x in root.src))
# update the ASSIGN offset to match the new shape
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+ShapeTracker.from_shape(new_input_shape),)
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(swizzles[0].shape))

def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
@@ -178,13 +174,16 @@ check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()),
to_si = PatternMatcher([
(UPat(Ops.VIEW, name="x"), _append_st_vars),
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda ctx,x: x),
])

# ** fusion

lazy = PatternMatcher([
# gather the metadata for this kernel
(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.metadata.add(m) if (m:=ctx.ops_metadata.get(x)) is not None else None),
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x),
])

multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
@@ -338,9 +337,29 @@ def _as_const(u:UOp, val:ConstType) -> UOp:
st = (base:=ShapeTracker.from_shape(())).reshape((1,)*len(u.shape)).expand(u.shape)
return UOp(Ops.VIEW, u.dtype, (u.buf_uop, UOp.const(u.dtype, val)), base).view(st)

def simplify_reduceop(ctx, reduce:UOp, x:UOp) -> Optional[UOp]:
# remove reduce on unmasked const
if all_int(x.shape) and x.is_unrealized_unmasked_const():
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return UOp.const(reduce.dtype, ret)
return None

ops_folding = PatternMatcher([
# op with size 0 is zero
(UPatScheduled(), lambda ctx,b,to_store,base: _as_const(base, 0) if base.size == 0 else None),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda ctx,reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_reduceop),
# CONST doesn't need COPY
(UPat(Ops.COPY, src=(UPat.var("x"),)), lambda ctx,x:x if x.is_unrealized_const() else None),
])
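simplify_reduceop uses the usual const-folding identities: summing a constant c over reduced axes with prshape total elements gives c*prshape, a product gives c**prshape, and a max of a constant is the constant itself. A NumPy illustration of the three cases:

```python
import numpy as np

x = np.full((4, 5), 2)
assert x.sum() == 2 * 20     # ADD folds to c * prshape
assert x.prod() == 2 ** 20   # MUL folds to c ** prshape
assert x.max() == 2          # MAX is passthrough
```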

# ** this decides which ops get realized

@@ -366,7 +385,7 @@ def fold_img_cast(ctx:Dict[UOp, UOp], xb:UOp, view:UOp, b:UOp, to_cast:UOp, **kw
return to_cast.view(unwrap(view.st))

def init_big_graph(ctx:ScheduleContext, sink:UOp) -> Optional[UOp]:
new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and uval(x.base).op is not Ops.CONST)
new_src = tuple(x.base for x in sink.src if is_scheduled(x.base) and x.base.src[1].op is not Ops.CONST)
return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)

do_realize = PatternMatcher([
@@ -380,6 +399,8 @@ do_realize = PatternMatcher([
(UPatScheduled(Ops.CAST, src=(UPat(Ops.VIEW, src=(UPat.var("xb"), UPat()), name="to_cast"),), dtype=dtypes.float).view(name="view"), fold_img_cast),
# realize before COPY or BUFFER_VIEW
(UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatScheduled(), UPatScheduled().view()),)), realize),
# ASSIGN only needs the buffer
(UPat(Ops.ASSIGN, src=(UPat(Ops.VIEW, name="dest"), UPat.var("src")), name="x"), lambda ctx,dest,src,x: x.replace(src=(dest.base.buf_uop, src))),
])

# ** this breaks down realized ops into STOREs and rewrites the ops to LOADs
@@ -402,7 +423,7 @@ break_sched = PatternMatcher([
# everything else is a VIEW of BUFFER that either realizes or fuses
(UPatScheduled(), lambda ctx,b,to_store,base: append_realize(ctx, b, to_store, base) if b in ctx.realizes else append_op(ctx, b, to_store)),
# just load realized buffers
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, base.dtype, (b, base.st.to_uop()))),
(UPatRealized(), lambda ctx,b,base: UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, base.st.to_uop()))),
])

@track_rewrites(named=True)

@@ -1,5 +1,5 @@
import os, json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable
import json, pathlib, zipfile, pickle, tarfile, struct, functools, io
from typing import Dict, Union, List, Optional, Any, Tuple, Callable, BinaryIO, Iterable, TypeVar
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod, argsort, DEBUG, Timing, CI, unwrap, GlobalCounters, tqdm
@@ -35,16 +35,21 @@ safe_dtypes = {"BOOL":dtypes.bool, "I8":dtypes.int8, "U8":dtypes.uint8, "I16":dt
"I64":dtypes.int64, "U64":dtypes.uint64, "F16":dtypes.float16, "BF16":dtypes.bfloat16, "F32":dtypes.float32, "F64":dtypes.float64}
inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()}

def safe_load_metadata(fn:Union[Tensor,str]) -> Tuple[Tensor, int, Any]:
R = TypeVar('R')
def accept_filename(func: Callable[[Tensor], R]) -> Callable[[Union[Tensor, str, pathlib.Path]], R]:
@functools.wraps(func)
def wrapper(fn: Union[Tensor, str, pathlib.Path]) -> R: return func(Tensor(pathlib.Path(fn)) if not isinstance(fn, Tensor) else fn)
return wrapper
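accept_filename normalizes the argument: anything that isn't already a Tensor gets wrapped in Tensor(pathlib.Path(...)), which backs the tensor with the file on disk, so the decorated function only ever sees a Tensor. A usage sketch (first_byte and weights.bin are hypothetical):

```python
import pathlib
from tinygrad import Tensor

@accept_filename  # assumes the decorator defined above is in scope
def first_byte(t: Tensor) -> int:
    return t[0].item()

# all three call styles reach the same Tensor-only body
first_byte("weights.bin")
first_byte(pathlib.Path("weights.bin"))
first_byte(Tensor(pathlib.Path("weights.bin")))
```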

@accept_filename
def safe_load_metadata(t:Tensor) -> Tuple[Tensor, int, Dict[str, Any]]:
"""
Loads a .safetensor file from disk, returning the data, the offset where the tensor data starts, and the metadata.
"""
t = fn if isinstance(fn, Tensor) else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
json_len = t[0:8].bitcast(dtypes.int64).item()
assert isinstance(json_len, int)
return t, json_len, json.loads(t[8:8+json_len].data().tobytes())
data_start = int.from_bytes(t[0:8].data(), "little") + 8
return t, data_start, json.loads(t[8:data_start].data().tobytes())
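The layout being parsed: a safetensors file starts with an 8-byte little-endian unsigned integer giving the JSON header size N, followed by N bytes of JSON, followed by the raw tensor data; each entry's data_offsets are relative to that data section. A minimal stdlib-only parse of the same format, for reference:

```python
import json, struct

def read_safetensors_header(path):
    with open(path, "rb") as f:
        (json_len,) = struct.unpack("<Q", f.read(8))  # little-endian u64 header size
        header = json.loads(f.read(json_len))
        data_start = 8 + json_len                     # tensor bytes begin here
    return header, data_start
```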

def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
def safe_load(fn:Union[Tensor, str, pathlib.Path]) -> Dict[str, Tensor]:
"""
Loads a .safetensor file from disk, returning the state_dict.

@@ -52,14 +57,10 @@ def safe_load(fn:Union[Tensor,str]) -> Dict[str, Tensor]:
state_dict = nn.state.safe_load("test.safetensor")
```
"""
t, json_len, metadata = safe_load_metadata(fn)
ret = {}
for k,v in metadata.items():
if k == "__metadata__": continue
dtype = safe_dtypes[v['dtype']]
sz = (v['data_offsets'][1]-v['data_offsets'][0])
ret[k] = t[8+json_len+v['data_offsets'][0]:8+json_len+v['data_offsets'][0]+sz].bitcast(dtype).reshape(v['shape'])
return ret
t, data_start, metadata = safe_load_metadata(fn)
data = t[data_start:]
return { k: data[v['data_offsets'][0]:v['data_offsets'][1]].bitcast(safe_dtypes[v['dtype']]).reshape(v['shape'])
for k, v in metadata.items() if k != "__metadata__" }

def safe_save(tensors:Dict[str, Tensor], fn:str, metadata:Optional[Dict[str, Any]]=None):
"""
@@ -157,6 +158,7 @@ def load_state_dict(model, state_dict:Dict[str, Tensor], strict=True, verbose=Tr
else: v.replace(state_dict[k].to(v.device)).realize()
if consume: del state_dict[k]

@accept_filename
def tar_extract(t: Tensor) -> Dict[str, Tensor]:
"""
Extracts files from a tar archive and returns them as a dictionary of names (keys) and tensors (values).
@@ -170,7 +172,8 @@ def tar_extract(t: Tensor) -> Dict[str, Tensor]:

# torch support!

def torch_load(fn:str) -> Dict[str, Tensor]:
@accept_filename
def torch_load(t:Tensor) -> Dict[str, Tensor]:
"""
Loads a torch .pth file from disk.

@@ -178,8 +181,6 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
state_dict = nn.state.torch_load("test.pth")
```
"""
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

offsets: Dict[Union[str, int], int] = {}
lens: Dict[Union[str, int], int] = {}
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad=None, backward_hooks=None, metadata=None):
@@ -220,8 +221,11 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
return intercept[name] if module_root == "torch" else super().find_class(module, name)
def persistent_load(self, pid): return deserialized_objects.get(pid, pid)

if zipfile.is_zipfile(fn):
myzip = zipfile.ZipFile(fn, 'r')
fobj = io.BufferedReader(TensorIO(t))
def passthrough_reset(v: bool): return fobj.seek(0, 0) or v

if passthrough_reset(zipfile.is_zipfile(fobj)): # NOTE: passthrough_reset required to support python < 3.14
myzip = zipfile.ZipFile(fobj, 'r')
base_name = myzip.namelist()[0].split('/', 1)[0]
for n in myzip.namelist():
if n.startswith(f'{base_name}/data/'):
@@ -229,8 +233,8 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
offsets[n.split("/")[-1]] = myfile._orig_compress_start # type: ignore
with myzip.open(f'{base_name}/data.pkl') as myfile:
return TorchPickle(myfile).load()
elif tarfile.is_tarfile(fn):
with tarfile.open(fn, "r") as tar:
elif passthrough_reset(tarfile.is_tarfile(fobj)): # NOTE: passthrough_reset required to support python < 3.11
with tarfile.open(fileobj=fobj, mode="r") as tar:
storages_offset = tar.getmember('storages').offset_data
f = unwrap(tar.extractfile('storages'))
for i in range(TorchPickle(f).load()): # num_storages
@@ -245,14 +249,13 @@ def torch_load(fn:str) -> Dict[str, Tensor]:
deserialized_objects[str(key)] = _rebuild_tensor_v2((None, storage_type, storage_id, None, -1), storage_offset, size, stride)
return {k:v.tensor if isinstance(v, Parameter) else v for k,v in TorchPickle(unwrap(tar.extractfile('pickle'))).load().items()}
else:
with open(fn, "rb") as f:
pkl = TorchPickle(f)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), f.tell(), pkl.load(), pkl.load(), f.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
f.seek(rwd)
return TorchPickle(f).load()
pkl = TorchPickle(fobj)
_, _, _, rwd, _, ids, base_offset = pkl.load(), pkl.load(), pkl.load(), fobj.tell(), pkl.load(), pkl.load(), fobj.tell()
for i in ids:
offsets[i] = base_offset + 8
base_offset += 8 + lens[i]
fobj.seek(rwd)
return TorchPickle(fobj).load()
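The passthrough_reset trick: zipfile.is_zipfile and tarfile.is_tarfile read from the file object to sniff the format, and on older Pythons they don't restore the position, so the wrapper rewinds before handing back the probe result. Since fobj.seek(0, 0) returns 0, which is falsy, `seek(...) or v` evaluates to v. The same idiom in isolation:

```python
import io, zipfile

fobj = io.BytesIO(b"not a zip")
def probe_reset(v: bool) -> bool: return fobj.seek(0) or v  # rewind, then pass the probe result through

is_zip = probe_reset(zipfile.is_zipfile(fobj))
assert fobj.tell() == 0 and is_zip is False
```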

def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
"""
@@ -287,6 +290,7 @@ def ggml_data_to_tensor(t: Tensor, n: int, ggml_type: int) -> Tensor:
return d * (xl.bitwise_or(xh).bitcast(dtypes.int8) - 32).flatten(-2) * scales
raise ValueError(f"GGML type '{ggml_type}' is not supported!")

@accept_filename
def gguf_load(tensor: Tensor) -> Tuple[Dict, Dict[str, Tensor]]:
"""
Loads a gguf file from a tensor.

@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, DefaultDict, Literal, get_args
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle, pathlib, inspect, weakref
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
@@ -25,8 +25,7 @@ class SimpleMathTrait:
def _binop(self, op, x, reverse): return self.ufix(x).alu(op, self) if reverse else self.alu(op, self.ufix(x))
def logical_not(self): return self.ne(True)
def neg(self):
dtype: Optional[DType] = getattr(self, 'dtype', None)
assert dtype is not None, "MathTraits __neg__ requires a dtype"
if (dtype:=getattr(self, 'dtype')) is None: raise TypeError(f"MathTraits __neg__ requires a dtype, {self=}")
return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1)
def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse)
def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse)
@@ -162,6 +161,9 @@ class GroupOp:
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}

# some BUFFER ops can be processed with only a view
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}

# https://en.wikipedia.org/wiki/Identity_element
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)

@@ -259,6 +261,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
@property
def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
@property
def nbytes(self) -> int: return self.size*self.dtype.itemsize

# *** uop evaluation ***

@@ -288,6 +292,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
assert ret.op is Ops.VIEW, f"st_arg trying to return {ret}"
return ret.arg
@property
def const_arg(self) -> ConstType:
match self.base.op:
case Ops.CONST: ret = self.base.arg
case Ops.VIEW: ret = self.base.src[1].const_arg
case op: raise AssertionError(f"const_arg called on {op}")
assert isinstance(ret, get_args(ConstType)), f"const_arg trying to return {ret}"
return ret
@property
def axis_arg(self) -> Tuple[int, ...]:
assert self.op in {Ops.REDUCE_AXIS, Ops.WMMA}, f"axis_arg called on {self.op}"
ret = self.arg[1] if self.op is Ops.REDUCE_AXIS else self.arg[7]
@@ -323,11 +335,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(Ops.VCONST if isinstance(b, tuple) else Ops.CONST, dtype, arg=dtypes.as_const(b, dtype))
@staticmethod
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(Ops.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def r(self, op:Ops, axis:Tuple[int, ...]): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
def range(dtype:DType, start:sint, end:sint, idx:int): return UOp(Ops.RANGE, dtype=dtype, src=(sint_to_uop(start), sint_to_uop(end)), arg=idx)
def r(self, op:Ops, axis:Tuple[int, ...]):
axis = tuple(sorted([x for x in axis if resolve(self.shape[x] != 1)]))
return self if len(axis) == 0 else UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x), None if self.st is None or self.st.contiguous else self.st)
def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,))

# *** from LazyBuffer ***
@@ -336,17 +348,19 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def const_with_shape(dtype:DType, val:ConstLike, shape:Tuple[sint,...]) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
return UOp(Ops.VALID, dtypes.bool, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),)).where(UOp.const(dtype, val), 0)
def is_unrealized_const(self): return (s:=self.base).op is Ops.VIEW and len(s.src) == 2 and s.src[1].op is Ops.CONST
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in unwrap(self.st).views)

# *** uop movement ops ***

@property
def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
assert self.st is not None and self.base.st is not None, f"must have shape {self}"
if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
# instant folding rules
if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)
if new_st.contiguous and self.base.st.shape == new_st.shape: return self.base
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self.view(unwrap(self.st).pad(arg))
@@ -428,9 +442,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax
if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals)
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
if self.op is Ops.IDIV:
if s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
# don't know exact bounds, but know the sign
if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dtypes.max(self.dtype)
if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dtypes.min(self.dtype), 0
if self.op is Ops.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax)
if self.op is Ops.CMPLT: return (s0_vmax<s1_vmin, s0_vmin<s1_vmax)
if self.op is Ops.CMPNE: return ((s0_vmax < s1_vmin) or (s1_vmax < s0_vmin), not (s0_vmin == s0_vmax == s1_vmin == s1_vmax))
|
||||
@@ -445,7 +463,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
if self.op is Ops.BIND: return self.src[0]._min_max # ignore the bound value
|
||||
if self.op in {Ops.EXPAND, Ops.VECTORIZE}: return min(x.vmin for x in self.src), max(x.vmax for x in self.src)
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
|
||||
if self.op is Ops.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else self.arg[1].vmax
|
||||
if self.op is Ops.CONST: return self.arg, self.arg
|
||||
if self.op is Ops.VCONST: return (min(self.arg), max(self.arg))
|
||||
return dtypes.min(self.dtype), dtypes.max(self.dtype)
|
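
# --- note: a standalone sketch (plain Python, not tinygrad code) of the IDIV bound rules
# added in the hunk above: with a constant divisor the quotient bounds follow directly;
# when the divisor varies but the signs of both ranges are known, only the result's sign is known.
def idiv_bounds(s0_vmin, s0_vmax, s1_vmin, s1_vmax, dt_min, dt_max):
  if s1_vmin == s1_vmax: # constant divisor
    if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
    if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin)
  if (s0_vmax <= 0 and s1_vmin < 0) or (s0_vmin >= 0 and s1_vmin > 0): return 0, dt_max
  if (s0_vmax <= 0 and s1_vmin > 0) or (s0_vmin >= 0 and s1_vmin < 0): return dt_min, 0
  return dt_min, dt_max
assert idiv_bounds(0, 10, 3, 3, -128, 127) == (0, 3)   # constant divisor: exact bounds
assert idiv_bounds(0, 10, 1, 4, -128, 127) == (0, 127) # varying divisor: only the sign
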
@@ -483,7 +501,7 @@ python_alu: Dict[Ops, Callable] = {
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,
Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z}

def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True):
@@ -1033,7 +1051,7 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp):
if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2
if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1

def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x

symbolic_simple = PatternMatcher([
# ** self folding **
@@ -1046,6 +1064,8 @@ symbolic_simple = PatternMatcher([
((UPat.var("x") * UPat.var("x2")) / UPat.var("x2"), lambda x,x2: x), # (x*x2)/x2 -> x
((UPat.var() % UPat.var("y")).named("base") % UPat.var("y"), lambda base,y: base), # (x%y)%y = -> x%y (rewritten with base for speed)
(UPat.var("x")%UPat.cvar("c")+(UPat.var("x")//UPat.cvar("c"))*UPat.cvar("c"), lambda x,c: x), # (x%c)+(x//c)*c = x
((UPat.var("x")//UPat.cvar("c1"))*UPat.cvar("c3")+UPat.var("x")%UPat.cvar("c1")*UPat.cvar("c2"),
lambda x,c1,c2,c3: x*c2 if c1.arg*c2.arg==c3.arg else None), # (x%c1)*c2+(x//c1)*c3 = x*c2 if c1*c2==c3
(UPat.var("x", dtype=dtypes.bool) & UPat.cvar("c", vec=False), lambda x,c: x if c.arg else c),
(UPat.var("x", dtype=dtypes.bool) | UPat.cvar("c", vec=False), lambda x,c: c if c.arg else x),
(UPat.var("x").maximum(UPat.var("x")), lambda x: x),

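# --- note: quick numeric check (plain Python) of the div/mod folds registered above:
# (x%c)+(x//c)*c == x, and (x//c1)*c3 + (x%c1)*c2 == x*c2 whenever c1*c2 == c3.
for x in range(-20, 20):
  for c1 in range(1, 8):
    assert x % c1 + (x // c1) * c1 == x
    c2 = 3
    assert (x // c1) * (c1 * c2) + (x % c1) * c2 == x * c2
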
@@ -67,8 +67,9 @@ class WGSLRenderer(CStyleLanguage):
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
(UPat(Ops.BITCAST, dtype=(dtypes.short, dtypes.ushort), name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]}&0xFFFF)"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"bitcast<{type_map[x.dtype]}>({ctx[x.src[0]]})"),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), lambda ctx,b,v,g: f"select({ctx[v]}, {ctx[b]}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx[b]),
(UPat(Ops.LOAD, src=(UPat.var("b"), UPat.var('v'), UPat.var("g"))), \
lambda ctx,b,v,g: f"select({ctx[v]}, {ctx.render_load(ctx[b], b.src[0].dtype)}, {ctx[g]})"),
(UPat(Ops.LOAD, src=(UPat.var('b'),), allow_any_len=True), lambda ctx, b: ctx.render_load(ctx[b], b.src[0].dtype)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))),
lambda ctx,buf,idx: f"{ctx[buf]}[{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]}]"),
(UPat(Ops.STORE, src=(UPat.var('b'), UPat.var("v"))),lambda ctx,b,v:\
@@ -81,7 +82,8 @@ class WGSLRenderer(CStyleLanguage):

def render_cast(self, dt:DType, val: str) -> str: return f"{self.type_map[dt]}({val})"
def render_dtype(self, dt:DType, mutable=True) -> str: return "var"
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if rw and dt.itemsize < 4 else buffer_map[dt.base]}"
def render_load(self, x:str, dt:DType) -> str: return f"atomicLoad(&{x})" if dt.itemsize < 4 else x
def render_buf_dt(self, dt:DType, rw=True) -> str: return f"{f'atomic<{buffer_map[dt]}>' if dt.itemsize < 4 else buffer_map[dt.base]}"
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
local_size = [num for _, num in sorted([u.arg for u in uops if u.op is Ops.SPECIAL and u.arg[0][0] == 'l'], key=lambda x: x[0])]
if not local_size: local_size = [1]

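# --- note: the WGSL change above routes every load through render_load because buffers
# holding sub-32-bit types are declared as atomic<i32>/atomic<u32> (see render_buf_dt),
# and WGSL only allows reading those via atomicLoad. Rough illustration of the emitted
# strings (assumed buffer/index names, not real renderer state):
def render_load_sketch(x: str, itemsize: int) -> str:
  return f"atomicLoad(&{x})" if itemsize < 4 else x
assert render_load_sketch("data0[idx]", 1) == "atomicLoad(&data0[idx])"  # char buffer
assert render_load_sketch("data0[idx]", 4) == "data0[idx]"               # float buffer
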
@@ -1,10 +1,10 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWQueue, HCQArgsState, BumpAllocator
from tinygrad.device import Buffer, BufferSpec, Compiled, Device
from tinygrad import Variable, dtypes
from tinygrad.ops import sint, Variable as VariableT
from tinygrad.ops import Variable as VariableT
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner

@@ -13,6 +13,14 @@ class HCQGraph(MultiGraphRunner):
super().__init__(jit_cache, input_rawbuffers, var_vals)
self.devices = list(set(cast(HCQCompiled, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs]))

# Replace input buffers with variables.
self.hcq_bufs = [[cast(Buffer, x)._buf for x in ji.bufs] for ji in jit_cache]
self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}

for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), Variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable

# Allocate kernel args.
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in jit_cache:
@@ -23,11 +31,11 @@ class HCQGraph(MultiGraphRunner):
# Fill initial arguments.
self.ji_args: Dict[int, HCQArgsState] = {}

kargs_ptrs: Dict[Compiled, int] = {dev:buf.va_addr for dev,buf in self.kernargs_bufs.items()}
kargs_alloc: Dict[Compiled, BumpAllocator] = {dev:BumpAllocator(buf.size, start=cast(int, buf.va_addr)) for dev,buf in self.kernargs_bufs.items()}
for j,ji in enumerate(jit_cache):
if not isinstance(ji.prg, CompiledRunner): continue
kargs_ptrs[ji.prg.dev] = (kargs_ptr:=kargs_ptrs[ji.prg.dev]) + round_up(ji.prg._prg.kernargs_alloc_size, 16)
self.ji_args[j] = ji.prg._prg.fill_kernargs([cast(Buffer, b)._buf for b in ji.bufs], [var_vals[v] for v in ji.prg.p.vars], kargs_ptr)
self.ji_args[j] = ji.prg._prg.fill_kernargs(self.hcq_bufs[j], ji.prg.p.vars, kargs_alloc[ji.prg.dev].alloc(ji.prg._prg.kernargs_alloc_size, 16))

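# --- note: the real BumpAllocator lives in tinygrad.runtime.support.hcq; this is a minimal
# sketch of the semantics the code above relies on (sequential aligned carving, replacing
# the old kargs_ptrs dict plus round_up(..., 16) bookkeeping):
class BumpAllocatorSketch:
  def __init__(self, size:int, start:int=0, wrap:bool=False):
    self.size, self.start, self.ptr, self.wrap = size, start, 0, wrap
  def alloc(self, size:int, alignment:int=1) -> int:
    if self.ptr % alignment != 0: self.ptr += alignment - self.ptr % alignment
    if self.ptr + size > self.size:
      if not self.wrap: raise RuntimeError("bump allocator out of space")
      self.ptr = 0  # ring-style reuse when wrap=True
    self.ptr += size
    return self.start + self.ptr - size
a = BumpAllocatorSketch(0x1000, start=0x10000)
assert a.alloc(24, 16) == 0x10000 and a.alloc(8, 16) == 0x10020
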
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
@@ -53,8 +61,14 @@ class HCQGraph(MultiGraphRunner):
for dev, queue in self.comp_queues.items(): dev_access[queue].add(dev)

for j,ji in enumerate(jit_cache):
enqueue_dev = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if is_exec_prg else self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())
enqueue_dev: HCQCompiled = ji.prg.dev if (is_exec_prg:=isinstance(ji.prg, CompiledRunner)) else Device[ji.bufs[1].device] #type:ignore

if is_exec_prg:
enqueue_queue = self.comp_queues[enqueue_dev]
else:
assert (enqueue_dev.hw_copy_queue_t is not None), "device must implement a copy queue"
enqueue_queue = self.copy_queues.setdefault(enqueue_dev, enqueue_dev.hw_copy_queue_t())

out_signal = self.signals.setdefault(enqueue_queue, enqueue_dev.signal_t(value=0))

# Get dependencies based on input and output buffers.
@@ -101,7 +115,6 @@ class HCQGraph(MultiGraphRunner):
last_j[enqueue_queue] = j

# Build hardware queues.
self.input_replace_to_var: Dict[Tuple[int, int], VariableT] = {}
self.copy_to_devs: Dict[HCQCompiled, Set[HCQCompiled]] = {dev: set() for dev in self.devices}

# Create variable timeline signals for each device.
@@ -128,8 +141,7 @@ class HCQGraph(MultiGraphRunner):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
cast(HCQAllocator, Device[src.device].allocator).map(dest._buf)

# TODO: For now sints is only for copies, should refactor to support exec as well.
enqueue_queue.copy(self._buf_addr_as_sint(j, 0, dest._buf), self._buf_addr_as_sint(j, 1, src._buf), dest.nbytes)
enqueue_queue.copy(self.hcq_bufs[j][0].va_addr, self.hcq_bufs[j][1].va_addr, dest.nbytes)
self.copy_to_devs[cast(HCQCompiled, Device[dest.device])].add(cast(HCQCompiled, Device[src.device]))

# Encode finish profile timestamp (if needed).
@@ -161,12 +173,7 @@ class HCQGraph(MultiGraphRunner):
**{sig.base_addr: dev.timeline_signal.base_addr for dev, sig in self.virt_timeline_signals.items()}}

# Update rawbuffers
for (j,i),input_idx in self.input_replace.items():
if (var:=self.input_replace_to_var.get((j,i))) is not None: hcq_var_vals[var] = input_rawbuffers[input_idx]._buf.va_addr
else: self.ji_args[j].update_buffer(i, input_rawbuffers[input_idx]._buf)

# Update var_vals
for j, i, v in self.updated_vars(var_vals): self.ji_args[j].update_var(i, v)
for (j,i),input_idx in self.input_replace.items(): hcq_var_vals[self.input_replace_to_var.get((j,i))] = input_rawbuffers[input_idx]._buf.va_addr

for dev in self.devices:
self.comp_queues[dev].submit(dev, hcq_var_vals)
@@ -191,10 +198,6 @@ class HCQGraph(MultiGraphRunner):
(b_st,_), (b_en,_), b_dev, _, b_is_cp, _, _ = self.prof_records[x]
dev.dep_prof_records += [(timestamps[b_st], timestamps[b_en], b_dev, b_is_cp, timestamps[st], timestamps[en], dev, is_cp)]

def _buf_addr_as_sint(self, j:int, i:int, buf:HCQBuffer) -> sint:
if (j, i) not in self.input_replace: return buf.va_addr
return self.input_replace_to_var.setdefault((j, i), Variable(f"input_{j}_{i}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))

def __del__(self):
for dev in self.devices: self.last_timeline[dev][0].wait(self.last_timeline[dev][1])

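# --- note: pure-Python sketch of the input replacement above: each (job, arg-index) input
# slot gets a symbolic variable once at capture time, and a launch just rebinds addresses
# instead of rewriting kernel args (hypothetical mini-example, not the real API):
input_vars = {(0, 1): "input_0"}  # (j, i) -> variable name
def launch_bindings(input_addrs: dict) -> dict:
  return {input_vars[k]: addr for k, addr in input_addrs.items()}
assert launch_bindings({(0, 1): 0xdead0000}) == {"input_0": 0xdead0000}
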
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Tuple, List, Any, Optional
from typing import Tuple, List, Any, Optional, cast
import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, array, contextlib, sys
assert sys.platform != 'win32'
from dataclasses import dataclass
@@ -80,6 +80,8 @@ class AMDComputeQueue(HWQueue):
return self

def exec(self, prg:AMDProgram, args_state:CLikeArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.bind_args_state(args_state)

self.acquire_mem(gli=0, gl2=0)

if prg.enable_private_segment_sgpr:
@@ -262,15 +264,15 @@ class AMDAllocator(HCQAllocator['AMDDevice']):
def __init__(self, dev:AMDDevice): super().__init__(dev, batch_size=SDMA_MAX_COPY_SIZE)

def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
return self.dev._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
if options.host: return self.dev._gpu_alloc(size, host=True)
if options.cpu_access and options.uncached: return self.dev._gpu_alloc(size, uncached=True)
return self.dev._gpu_alloc(size, cpu_access=options.cpu_access)

def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)

def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if hasattr(buf, '_base') else buf)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if buf._base is not None else buf)

MAP_FIXED, MAP_NORESERVE = 0x10, 0x400

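# --- note: why the map() guard changed above: HCQBuffer now always carries a _base
# attribute (None unless the buffer is a sub-view of another allocation), so an identity
# check replaces the old hasattr() probe. Hypothetical mini-structure for illustration:
from dataclasses import dataclass
from typing import Optional

@dataclass
class BufSketch:
  va_addr: int
  base: "Optional[BufSketch]" = None  # named _base in the real HCQBuffer

def root(b: BufSketch) -> BufSketch: return b.base if b.base is not None else b
whole = BufSketch(0x1000); view = BufSketch(0x1040, base=whole)
assert root(view) is whole and root(whole) is whole
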
@@ -289,45 +291,48 @@ class AMDDevice(HCQCompiled):
signals_pool:List[int] = []
gpus:List[pathlib.Path] = []

def _gpu_map(self, mem):
if self.gpu_id in getattr(mem, "mapped_gpu_ids", []): return
mem.__setattr__("mapped_gpu_ids", getattr(mem, "mapped_gpu_ids", []) + [self.gpu_id])
c_gpus = (ctypes.c_int32 * len(mem.mapped_gpu_ids))(*mem.mapped_gpu_ids)
stm = kfd.AMDKFD_IOC_MAP_MEMORY_TO_GPU(self.kfd, handle=mem.handle, device_ids_array_ptr=ctypes.addressof(c_gpus),
n_devices=len(mem.mapped_gpu_ids))
assert stm.n_success == len(mem.mapped_gpu_ids)
def _gpu_map(self, mem:HCQBuffer):
if self.gpu_id in getattr(mem.meta, "mapped_gpu_ids", []): return
mem.meta.__setattr__("mapped_gpu_ids", getattr(mem.meta, "mapped_gpu_ids", []) + [self.gpu_id])
c_gpus = (ctypes.c_int32 * len(mem.meta.mapped_gpu_ids))(*mem.meta.mapped_gpu_ids)
stm = kfd.AMDKFD_IOC_MAP_MEMORY_TO_GPU(self.kfd, handle=mem.meta.handle, device_ids_array_ptr=ctypes.addressof(c_gpus),
n_devices=len(mem.meta.mapped_gpu_ids))
assert stm.n_success == len(mem.meta.mapped_gpu_ids)

def _gpu_alloc(self, size:int, flags:int, uncached=False, public=False, map_to_gpu=True):
flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE
if uncached: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_COHERENT | kfd.KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED
if public: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_PUBLIC
if flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR:
buf = addr = libc.mmap(0, size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|mmap.MAP_ANONYMOUS, -1, 0)
else:
buf, addr = 0, libc.mmap(0, size, 0, mmap.MAP_PRIVATE|mmap.MAP_ANONYMOUS|MAP_NORESERVE, -1, 0)
def _gpu_alloc(self, size:int, host=False, uncached=False, cpu_access=False) -> HCQBuffer:
flags = kfd.KFD_IOC_ALLOC_MEM_FLAGS_WRITABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_EXECUTABLE | kfd.KFD_IOC_ALLOC_MEM_FLAGS_NO_SUBSTITUTE

if uncached: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_COHERENT | kfd.KFD_IOC_ALLOC_MEM_FLAGS_UNCACHED | kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT
else: flags |= (kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR if host else kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)

if cpu_access or host: flags |= kfd.KFD_IOC_ALLOC_MEM_FLAGS_PUBLIC

if host: buf = addr = libc.mmap(0, size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | mmap.MAP_ANONYMOUS, -1, 0)
else: buf, addr = 0, libc.mmap(0, size, 0, mmap.MAP_PRIVATE | mmap.MAP_ANONYMOUS | MAP_NORESERVE, -1, 0)
assert addr != 0xffffffffffffffff

try: mem = kfd.AMDKFD_IOC_ALLOC_MEMORY_OF_GPU(self.kfd, va_addr=addr, size=size, base=addr, length=size, gpu_id=self.gpu_id,
flags=flags, mmap_offset=buf)
flags=flags, mmap_offset=buf, cpu_addr=addr)
except OSError as e:
if e.errno == errno.EINVAL and (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) and public:
if e.errno == errno.EINVAL and (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) and cpu_access:
raise MemoryError("Cannot allocate host-visible VRAM. Ensure the resizable BAR option is enabled on your system.") from e
if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory: no memory is available.") from e
raise

if not (flags & kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR):
buf = libc.mmap(mem.va_addr, mem.size, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|MAP_FIXED, self.drm_fd, mem.mmap_offset)
if not host:
buf = libc.mmap(mem.va_addr, mem.size, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED | MAP_FIXED, self.drm_fd, mem.mmap_offset)
assert addr == buf == mem.va_addr
if map_to_gpu: self._gpu_map(mem)
return mem

def _gpu_free(self, mem):
if len(gpus:=getattr(mem, "mapped_gpu_ids", [])):
self._gpu_map(hcqbuf:=HCQBuffer(mem.va_addr, mem.size, meta=mem))
return hcqbuf

def _gpu_free(self, mem:HCQBuffer):
if len(gpus:=getattr(mem.meta, "mapped_gpu_ids", [])):
c_gpus = (ctypes.c_int32 * len(gpus))(*gpus)
stm = kfd.AMDKFD_IOC_UNMAP_MEMORY_FROM_GPU(self.kfd, handle=mem.handle, device_ids_array_ptr=ctypes.addressof(c_gpus), n_devices=len(gpus))
stm = kfd.AMDKFD_IOC_UNMAP_MEMORY_FROM_GPU(self.kfd, handle=mem.meta.handle, device_ids_array_ptr=ctypes.addressof(c_gpus), n_devices=len(gpus))
assert stm.n_success == len(gpus)
libc.munmap(mem.va_addr, mem.size)
kfd.AMDKFD_IOC_FREE_MEMORY_OF_GPU(self.kfd, handle=mem.handle)
kfd.AMDKFD_IOC_FREE_MEMORY_OF_GPU(self.kfd, handle=mem.meta.handle)

def __init__(self, device:str=""):
if AMDDevice.kfd == -1:
@@ -350,10 +355,10 @@ class AMDDevice(HCQCompiled):
kfd.AMDKFD_IOC_ACQUIRE_VM(AMDDevice.kfd, drm_fd=self.drm_fd, gpu_id=self.gpu_id)

if AMDDevice.event_page is None:
AMDDevice.signals_page = self._gpu_alloc(16 * 65536, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
AMDDevice.event_page = self._gpu_alloc(0x8000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
AMDDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
AMDDevice.event_page = self._gpu_alloc(0x8000, uncached=True)
AMDDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, AMDDevice.signals_page.size, 16)]
kfd.AMDKFD_IOC_CREATE_EVENT(AMDDevice.kfd, event_page_offset=AMDDevice.event_page.handle)
kfd.AMDKFD_IOC_CREATE_EVENT(AMDDevice.kfd, event_page_offset=AMDDevice.event_page.meta.handle)
else:
self._gpu_map(AMDDevice.signals_page)
self._gpu_map(AMDDevice.event_page)
@@ -365,7 +370,7 @@ class AMDDevice(HCQCompiled):
# <gfx103 requires alignment of 1024, >=gfx11 requires 256
wave_scratch_len = round_up(((max_wave_id + 1) * self.max_private_segment_size), 256 if self.target >= 110000 else 1024)
self.scratch_len = (max_cu_id + 1) * self.properties['max_slots_scratch_cu'] * wave_scratch_len
self.scratch = self._gpu_alloc(self.scratch_len, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM)
self.scratch = self._gpu_alloc(self.scratch_len)
self.has_scratch_base_registers = self.target >= 110000
engines = self.properties['array_count'] // self.properties['simd_arrays_per_engine']
waves = wave_scratch_len // (256 if self.target >= 110000 else 1024)
@@ -396,11 +401,10 @@ class AMDDevice(HCQCompiled):
AMDSignal, AMDComputeQueue, AMDCopyQueue)

def _alloc_queue(self, queue_type, ring_size, ctx_save_restore_size=None, eop_buffer_size=None, ctl_stack_size=0) -> AMDQueueDesc:
gart = self._gpu_alloc(0x1000, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
ring = self._gpu_alloc(ring_size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_GTT, uncached=True)
cwsr_ctx = self._gpu_alloc(round_up(ctx_save_restore_size + self.debug_memory_size, mmap.PAGESIZE),
kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) if ctx_save_restore_size else None
eop_buffer = self._gpu_alloc(eop_buffer_size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM) if eop_buffer_size else None
gart = self._gpu_alloc(0x1000, uncached=True)
ring = self._gpu_alloc(ring_size, uncached=True)
cwsr_ctx = self._gpu_alloc(round_up(ctx_save_restore_size + self.debug_memory_size, mmap.PAGESIZE)) if ctx_save_restore_size else None
eop_buffer = self._gpu_alloc(eop_buffer_size) if eop_buffer_size else None
queue = kfd.AMDKFD_IOC_CREATE_QUEUE(AMDDevice.kfd, ring_base_address=ring.va_addr, ring_size=ring.size, gpu_id=self.gpu_id,
queue_type=queue_type, queue_percentage=kfd.KFD_MAX_QUEUE_PERCENTAGE, queue_priority=kfd.KFD_MAX_QUEUE_PRIORITY,
eop_buffer_address=eop_buffer.va_addr if eop_buffer else 0, eop_buffer_size=eop_buffer.size if eop_buffer else 0, ctl_stack_size=ctl_stack_size,
@@ -411,7 +415,7 @@ class AMDDevice(HCQCompiled):
self.doorbells_base = queue.doorbell_offset & (~0x1fff) # doorbell is two pages
self.doorbells = libc.mmap(0, 0x2000, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED, AMDDevice.kfd, self.doorbells_base)

return AMDQueueDesc(ring=to_mv(ring.va_addr, ring_size).cast("I"),
return AMDQueueDesc(ring=to_mv(cast(int, ring.va_addr), ring_size).cast("I"),
read_ptr=to_mv(queue.read_pointer_address, 8).cast("Q"), write_ptr=to_mv(queue.write_pointer_address, 8).cast("Q"),
doorbell=to_mv(self.doorbells + queue.doorbell_offset - self.doorbells_base, 8).cast("Q"))

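# --- note: condensed view (derived from the _gpu_alloc hunk above) of how the new intent
# kwargs pick KFD memory kinds; constants are from tinygrad.runtime.autogen.kfd:
#   uncached=True      -> GTT | COHERENT | UNCACHED (system memory)
#   host=True          -> USERPTR (anonymous mmap'd pages)
#   otherwise          -> VRAM
#   cpu_access or host -> additionally PUBLIC (host-visible; VRAM needs resizable BAR)
def pick_kind(host=False, uncached=False, cpu_access=False) -> str:
  kind = "GTT" if uncached else ("USERPTR" if host else "VRAM")
  return kind + ("|PUBLIC" if cpu_access or host else "")
assert pick_kind(uncached=True) == "GTT" and pick_kind(cpu_access=True) == "VRAM|PUBLIC"
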
@@ -1,7 +1,7 @@
from __future__ import annotations
import os, pathlib, struct, ctypes, tempfile, functools
from typing import List, Any, Union, Tuple, cast
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T
from tinygrad.helpers import prod, to_mv, getenv, round_up, _cache_dir, T, init_c_struct_t
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer

@@ -45,10 +45,7 @@ def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id)

def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)

def to_struct(*t: int, _type: type = ctypes.c_ulong):
class Struct(ctypes.Structure): pass
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
return Struct(*t)
def to_struct(*t: int, _type: type = ctypes.c_ulong): return init_c_struct_t(tuple([(f"field{i}", _type) for i in range(len(t))]))(*t)

def wait_check(cbuf: Any):
msg(cbuf, "waitUntilCompleted")
@@ -112,9 +109,8 @@ class MetalProgram:
if lib[:4] == b"MTLB":
# binary metal library
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
error_library_creation = objc_instance()
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
error_check(error_library_creation)
self.library = msg(self.dev.sysdevice, "newLibraryWithData:error:", data, ctypes.byref(error_lib:=objc_instance()), restype=objc_instance)
error_check(error_lib)
else:
# metal source. rely on OS caching
try: self.library = metal_src_to_library(self.dev, lib.decode())
@@ -137,7 +133,7 @@ class MetalProgram:
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "setComputePipelineState:", self.pipeline_state)
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
for i,a in enumerate(vals, start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(encoder, "endEncoding")
msg(command_buffer, "commit")
@@ -178,9 +174,7 @@ class MetalAllocator(LRUAllocator):
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
def _as_buffer(self, src:MetalBuffer) -> memoryview:
self.dev.synchronize()
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
return memoryview(array).cast("B")[src.offset:]
return to_mv(cast(int, msg(src.buf, "contents", restype=objc_id).value), src.size + src.offset)[src.offset:]
def _copyin(self, dest:MetalBuffer, src:memoryview): self._as_buffer(dest)[:] = src
def _copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self._as_buffer(src)
def _offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)

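# --- note: the one-liner above leans on tinygrad.helpers.init_c_struct_t; it builds the
# same ctypes struct the removed class did. Standalone equivalent for comparison:
import ctypes
def to_struct_sketch(*t: int, _type: type = ctypes.c_ulong):
  return type("Struct", (ctypes.Structure,), {"_fields_": [(f"field{i}", _type) for i in range(len(t))]})(*t)
s = to_struct_sketch(8, 8, 1)
assert (s.field0, s.field1, s.field2) == (8, 8, 1)
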
@@ -3,7 +3,7 @@ import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, sys
assert sys.platform != 'win32'
from typing import Tuple, List, Any, cast, Union, Dict, Type, Optional
from dataclasses import dataclass
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWQueue, CLikeArgsState, HCQProgram, HCQSignal, BumpAllocator
from tinygrad.ops import sint
from tinygrad.device import BufferSpec
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
@@ -71,8 +71,6 @@ def make_qmd_struct_type():
qmd_struct_t = make_qmd_struct_type()
assert ctypes.sizeof(qmd_struct_t) == 0x40 * 4

def nvmethod(subc, mthd, size, typ=2): return (typ << 28) | (size << 16) | (subc << 13) | (mthd >> 2)

class NVSignal(HCQSignal):
def __init__(self, base_addr:Optional[int]=None, **kwargs):
super().__init__(NVDevice.signals_pool.pop() if base_addr is None else base_addr, **kwargs, timestamp_divider=1000, value_off=0, timestamp_off=8)
@@ -88,18 +86,19 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def __del__(self):
if self.binded_device is not None: self.binded_device.allocator.free(self.hw_page, self.hw_page.size, BufferSpec(cpu_access=True, nolru=True))

def nvm(self, subchannel, mthd, *args, typ=2): self.q((typ << 28) | (len(args) << 16) | (subchannel << 13) | (mthd >> 2), *args)

def setup(self, compute_class=None, copy_class=None, local_mem_window=None, shared_mem_window=None, local_mem=None, local_mem_tpc_bytes=None):
if compute_class: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_OBJECT, 1), compute_class)
if copy_class: self.q(nvmethod(4, nv_gpu.NVC6C0_SET_OBJECT, 1), copy_class)
if local_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_WINDOW_A, 2), *data64(local_mem_window))
if shared_mem_window: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, 2), *data64(shared_mem_window))
if local_mem: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, 2), *data64(local_mem))
if local_mem_tpc_bytes: self.q(nvmethod(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, 3), *data64(local_mem_tpc_bytes), 0xff)
if compute_class: self.nvm(1, nv_gpu.NVC6C0_SET_OBJECT, compute_class)
if copy_class: self.nvm(4, nv_gpu.NVC6C0_SET_OBJECT, copy_class)
if local_mem_window: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_WINDOW_A, *data64(local_mem_window))
if shared_mem_window: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_SHARED_MEMORY_WINDOW_A, *data64(shared_mem_window))
if local_mem: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_A, *data64(local_mem))
if local_mem_tpc_bytes: self.nvm(1, nv_gpu.NVC6C0_SET_SHADER_LOCAL_MEMORY_NON_THROTTLED_A, *data64(local_mem_tpc_bytes), 0xff)
return self

def wait(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value), (3 << 0) | (1 << 24)) # ACQUIRE | PAYLOAD_SIZE_64BIT
self.active_qmd = None
return self

@@ -117,14 +116,9 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):
def _submit_to_gpfifo(self, dev:NVDevice, gpfifo:GPFifo):
if dev == self.binded_device: cmdq_addr = self.hw_page.va_addr
else:
if dev.cmdq_wptr + len(self._q) * 4 > dev.cmdq_page.size:
assert (gpfifo.ring[gpfifo.controls.GPGet] & 0xFFFFFFFFFC) >= dev.cmdq_page.va_addr + len(self._q) * 4 or \
gpfifo.controls.GPGet == gpfifo.controls.GPPut, "cmdq overrun"
dev.cmdq_wptr = 0

dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self._q)] = array.array('I', self._q)
cmdq_addr = dev.cmdq_page.va_addr+dev.cmdq_wptr
dev.cmdq_wptr += len(self._q) * 4
cmdq_addr = dev.cmdq_allocator.alloc(len(self._q) * 4)
cmdq_wptr = (cmdq_addr - dev.cmdq_page.va_addr) // 4
dev.cmdq[cmdq_wptr : cmdq_wptr + len(self._q)] = array.array('I', self._q)

gpfifo.ring[gpfifo.put_value % gpfifo.entries_count] = (cmdq_addr//4 << 2) | (len(self._q) << 42) | (1 << 41)
gpfifo.controls.GPPut = (gpfifo.put_value + 1) % gpfifo.entries_count
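
# --- note: sketch of what wrap=True buys the GPFIFO path above: the cmdq page is treated
# as a ring, so a push that would run past the end restarts at the base (the old code did
# this by hand with dev.cmdq_wptr; the overrun check moves into the allocator's domain).
# Reusing the BumpAllocatorSketch from the HCQGraph note:
page_base, page_size = 0x200000, 0x200000
ring = BumpAllocatorSketch(page_size, start=page_base, wrap=True)
assert ring.alloc(page_size - 16) == page_base  # nearly fill the page
assert ring.alloc(64) == page_base              # tail too small -> wraps to the base
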
@@ -133,11 +127,13 @@ class NVCommandQueue(HWQueue[NVSignal, 'NVDevice', 'NVProgram', 'NVArgsState']):

class NVComputeQueue(NVCommandQueue):
def memory_barrier(self):
self.q(nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0))
self.nvm(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, (1 << 12) | (1 << 4) | (1 << 0))
self.active_qmd = None
return self

def exec(self, prg:NVProgram, args_state:NVArgsState, global_size:Tuple[sint, ...], local_size:Tuple[sint, ...]):
self.bind_args_state(args_state)

ctypes.memmove(qmd_addr:=(args_state.ptr + round_up(prg.constbufs[0][1], 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
assert qmd_addr < (1 << 40), f"large qmd addr {qmd_addr:x}"

@@ -148,8 +144,8 @@ class NVComputeQueue(NVCommandQueue):
qmd.constant_buffer_addr_upper_0, qmd.constant_buffer_addr_lower_0 = data64(args_state.ptr)

if self.active_qmd is None:
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_PCAS_A, 0x1), qmd_addr >> 8)
self.q(nvmethod(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 0x1), 9)
self.nvm(1, nv_gpu.NVC6C0_SEND_PCAS_A, qmd_addr >> 8)
self.nvm(1, nv_gpu.NVC6C0_SEND_SIGNALING_PCAS2_B, 9)
else:
self.active_qmd.dependent_qmd0_pointer = qmd_addr >> 8
self.active_qmd.dependent_qmd0_action = 1
@@ -168,9 +164,9 @@ class NVComputeQueue(NVCommandQueue):
self.bind_sints(value, struct=self.active_qmd, start_field=f'release{i}_payload', fmt='Q')
return self

self.q(nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.q(nvmethod(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 1), 0x0)
self.nvm(0, nv_gpu.NVC56F_SEM_ADDR_LO, *data64_le(signal.value_addr), *data64_le(value),
(1 << 0) | (1 << 20) | (1 << 24) | (1 << 25)) # RELEASE | RELEASE_WFI | PAYLOAD_SIZE_64BIT | RELEASE_TIMESTAMP
self.nvm(0, nv_gpu.NVC56F_NON_STALL_INTERRUPT, 0x0)
self.active_qmd = None
return self

@@ -178,14 +174,14 @@ class NVComputeQueue(NVCommandQueue):

class NVCopyQueue(NVCommandQueue):
def copy(self, dest:sint, src:sint, copy_size:int):
self.q(nvmethod(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, 4), *data64(src), *data64(dest))
self.q(nvmethod(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, 1), copy_size)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
self.nvm(4, nv_gpu.NVC6B5_OFFSET_IN_UPPER, *data64(src), *data64(dest))
self.nvm(4, nv_gpu.NVC6B5_LINE_LENGTH_IN, copy_size)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x182) # TRANSFER_TYPE_NON_PIPELINED | DST_MEMORY_LAYOUT_PITCH | SRC_MEMORY_LAYOUT_PITCH
return self

def signal(self, signal:NVSignal, value:sint=0):
self.q(nvmethod(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, 3), *data64(signal.value_addr), value)
self.q(nvmethod(4, nv_gpu.NVC6B5_LAUNCH_DMA, 1), 0x14)
self.nvm(4, nv_gpu.NVC6B5_SET_SEMAPHORE_A, *data64(signal.value_addr), value)
self.nvm(4, nv_gpu.NVC6B5_LAUNCH_DMA, 0x14)
return self

def _submit(self, dev:NVDevice): self._submit_to_gpfifo(dev, dev.dma_gpfifo)
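
# --- note: nvm() above just folds the old nvmethod(...) + q(...) pair together, taking
# the method count from len(args) instead of a hand-passed size. Standalone encoding
# check, using the nvmethod layout shown earlier in this file:
def nvmethod_sketch(subc, mthd, size, typ=2): return (typ << 28) | (size << 16) | (subc << 13) | (mthd >> 2)
def nvm_header_sketch(subchannel, mthd, *args, typ=2): return (typ << 28) | (len(args) << 16) | (subchannel << 13) | (mthd >> 2)
assert nvm_header_sketch(4, 0x400, 1, 2, 3) == nvmethod_sketch(4, 0x400, 3)  # 3 args -> size 3
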
@@ -267,14 +263,14 @@ class NVProgram(HCQProgram):

class NVAllocator(HCQAllocator['NVDevice']):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
if options.host: return self.dev._gpu_host_alloc(size, tag="user host memory")
return self.dev._gpu_alloc(size, map_to_cpu=options.cpu_access, huge_page=(size > (16 << 20)), tag=f"user memory ({options})")
if options.host: return self.dev._gpu_alloc(size, host=True, tag="user host memory")
return self.dev._gpu_alloc(size, cpu_access=options.cpu_access, tag=f"user memory ({options})")

def _free(self, opaque, options:BufferSpec):
def _free(self, opaque:HCQBuffer, options:BufferSpec):
self.dev.synchronize()
self.dev._gpu_free(opaque)

def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if hasattr(buf, '_base') else buf)
def map(self, buf:HCQBuffer): self.dev._gpu_map(buf._base if buf._base is not None else buf)

@dataclass
class GPFifo:
@@ -292,8 +288,12 @@ class NVDevice(HCQCompiled[NVSignal]):
gpus_info: Union[List, ctypes.Array] = []
signals_page: Any = None
signals_pool: List[int] = []
low_uvm_vaddr: int = 0x1000000000 # 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
uvm_vaddr: int = 0x2000000000 # 0x2000000000+

# TODO: Need a proper allocator for va addresses
# 0x1000000000 - 0x2000000000, reserved for system/cpu mappings
# VA space is 48bits.
low_uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=0x1000000000, start=0x1000000000, wrap=False)
uvm_vaddr_allocator: BumpAllocator = BumpAllocator(size=(1 << 48) - 1, start=0x2000000000, wrap=False)
host_object_enumerator: int = 0x1000

def _new_gpu_fd(self):
@@ -311,79 +311,70 @@ class NVDevice(HCQCompiled[NVSignal]):
os.close(fd_dev)
return res

def _gpu_alloc(self, size:int, contig=False, huge_page=False, va_addr=None, map_to_cpu=False, map_flags=0, tag=""):
size = round_up(size, align:=((2 << 20) if huge_page else (4 << 10)))
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, alignment=align, offset=0, limit=size-1, format=6, size=size,
attr=(((nv_gpu.NVOS32_ATTR_PAGE_SIZE_HUGE << 23) if huge_page else 0) |
((nv_gpu.NVOS32_ATTR_PHYSICALITY_CONTIGUOUS if contig else nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS) << 27)),
attr2=((nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC << 0) | (nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_YES << 2) |
((nv_gpu.NVOS32_ATTR2_PAGE_SIZE_HUGE_2MB << 20) if huge_page else 0)),
flags=(nv_gpu.NVOS32_ALLOC_FLAGS_ALIGNMENT_FORCE | nv_gpu.NVOS32_ALLOC_FLAGS_PERSISTENT_VIDMEM | nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED |
nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED))
mem_handle = rm_alloc(self.fd_ctl, nv_gpu.NV1_MEMORY_USER, self.root, self.nvdevice, alloc_params).hObjectNew
def _gpu_alloc(self, size:int, host=False, uncached=False, cpu_access=False, contiguous=False, map_flags=0, tag="") -> HCQBuffer:
# Uncached memory is "system". Use huge pages only for gpu memory.
page_size = (4 << 10) if uncached or host else ((2 << 20) if size >= (8 << 20) else (4 << 10))
size = round_up(size, page_size)
va_addr = self._alloc_gpu_vaddr(size, alignment=page_size, force_low=cpu_access)

if va_addr is None: va_addr = self._alloc_gpu_vaddr(size, alignment=align, force_low=map_to_cpu)
if map_to_cpu: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags)
return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=map_to_cpu, tag=tag)
if host:
va_addr = libc.mmap(va_addr, size, mmap.PROT_READ | mmap.PROT_WRITE, MAP_FIXED | mmap.MAP_SHARED | mmap.MAP_ANONYMOUS, -1, 0)

def _gpu_system_alloc(self, size:int, va_addr=None, map_to_cpu=False, map_flags=0, tag=""):
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, type=13,
attr=(nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS << 27) | (nv_gpu.NVOS32_ATTR_LOCATION_PCI << 25),
attr2=(nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC << 0) | (nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_NO << 2),
flags=(nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED |
nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED), format=6, size=size, alignment=(4<<10), offset=0, limit=size-1)
mem_handle = rm_alloc(self.fd_ctl, nv_gpu.NV1_MEMORY_SYSTEM, self.root, self.nvdevice, alloc_params).hObjectNew
flags = (nv_gpu.NVOS02_FLAGS_PHYSICALITY_NONCONTIGUOUS << 4) | (nv_gpu.NVOS02_FLAGS_COHERENCY_CACHED << 12) \
| (nv_gpu.NVOS02_FLAGS_MAPPING_NO_MAP << 30)

if va_addr is None: va_addr = self._alloc_gpu_vaddr(size, force_low=True)
if map_to_cpu: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags, system=True)
NVDevice.host_object_enumerator += 1
made = nv_gpu.nv_ioctl_nvos02_parameters_with_fd(params=nv_gpu.NVOS02_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, flags=flags,
hObjectNew=NVDevice.host_object_enumerator, hClass=nv_gpu.NV01_MEMORY_SYSTEM_OS_DESCRIPTOR, pMemory=va_addr, limit=size-1), fd=-1)
nv_iowr(self.fd_dev, nv_gpu.NV_ESC_RM_ALLOC_MEMORY, made)

return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=map_to_cpu, tag=tag)
if made.params.status != 0: raise RuntimeError(f"host alloc returned {get_error_str(made.params.status)}")
mem_handle = made.params.hObjectNew
else:
attr = ((nv_gpu.NVOS32_ATTR_PHYSICALITY_CONTIGUOUS if contiguous else nv_gpu.NVOS32_ATTR_PHYSICALITY_ALLOW_NONCONTIGUOUS) << 27) \
| (nv_gpu.NVOS32_ATTR_PAGE_SIZE_HUGE if page_size > 0x1000 else 0) << 23 | ((nv_gpu.NVOS32_ATTR_LOCATION_PCI if uncached else 0) << 25)

def _gpu_host_alloc(self, size, tag=""):
va_base = self._alloc_gpu_vaddr(aligned_sz:=round_up(size, 4 << 10))
mapped_addr = libc.mmap(va_base, aligned_sz, mmap.PROT_READ|mmap.PROT_WRITE, MAP_FIXED|mmap.MAP_SHARED|mmap.MAP_ANONYMOUS, -1, 0)
assert mapped_addr == va_base, f"Not mmaped at correct address {va_base=} != {mapped_addr=}"
attr2 = ((nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_NO if uncached else nv_gpu.NVOS32_ATTR2_GPU_CACHEABLE_YES) << 2) \
| ((nv_gpu.NVOS32_ATTR2_PAGE_SIZE_HUGE_2MB if page_size > 0x1000 else 0) << 20) | nv_gpu.NVOS32_ATTR2_ZBC_PREFER_NO_ZBC

NVDevice.host_object_enumerator += 1
flags = ((nv_gpu.NVOS02_FLAGS_PHYSICALITY_NONCONTIGUOUS << 4) | (nv_gpu.NVOS02_FLAGS_COHERENCY_CACHED << 12) |
(nv_gpu.NVOS02_FLAGS_MAPPING_NO_MAP << 30))
made = nv_gpu.nv_ioctl_nvos02_parameters_with_fd(params=nv_gpu.NVOS02_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, flags=flags,
hObjectNew=NVDevice.host_object_enumerator, hClass=nv_gpu.NV01_MEMORY_SYSTEM_OS_DESCRIPTOR, pMemory=va_base, limit=aligned_sz-1), fd=-1)
nv_iowr(self.fd_dev, nv_gpu.NV_ESC_RM_ALLOC_MEMORY, made)
fl = nv_gpu.NVOS32_ALLOC_FLAGS_MAP_NOT_REQUIRED | nv_gpu.NVOS32_ALLOC_FLAGS_MEMORY_HANDLE_PROVIDED | nv_gpu.NVOS32_ALLOC_FLAGS_ALIGNMENT_FORCE \
| nv_gpu.NVOS32_ALLOC_FLAGS_IGNORE_BANK_PLACEMENT | (nv_gpu.NVOS32_ALLOC_FLAGS_PERSISTENT_VIDMEM if not uncached else 0)

if made.params.status != 0: raise RuntimeError(f"_map_to_gpu returned {get_error_str(made.params.status)}")
return self._gpu_uvm_map(va_base, aligned_sz, made.params.hObjectNew, has_cpu_mapping=True, tag=tag)
alloc_func = nv_gpu.NV1_MEMORY_SYSTEM if uncached else nv_gpu.NV1_MEMORY_USER
alloc_params = nv_gpu.NV_MEMORY_ALLOCATION_PARAMS(owner=self.root, alignment=page_size, offset=0, limit=size-1, format=6, size=size,
type=nv_gpu.NVOS32_TYPE_NOTIFIER if uncached else nv_gpu.NVOS32_TYPE_IMAGE, attr=attr, attr2=attr2, flags=fl)
mem_handle = rm_alloc(self.fd_ctl, alloc_func, self.root, self.nvdevice, alloc_params).hObjectNew

def _gpu_free(self, mem):
if mem.hMemory > NVDevice.host_object_enumerator: # not a host object, clear phys mem.
made = nv_gpu.NVOS00_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, hObjectOld=mem.hMemory)
if cpu_access: va_addr = self._gpu_map_to_cpu(mem_handle, size, target=va_addr, flags=map_flags, system=uncached)

return self._gpu_uvm_map(va_addr, size, mem_handle, has_cpu_mapping=cpu_access or host, tag=tag)

def _gpu_free(self, mem:HCQBuffer):
if mem.meta.hMemory > NVDevice.host_object_enumerator: # not a host object, clear phys mem.
made = nv_gpu.NVOS00_PARAMETERS(hRoot=self.root, hObjectParent=self.nvdevice, hObjectOld=mem.meta.hMemory)
nv_iowr(self.fd_ctl, nv_gpu.NV_ESC_RM_FREE, made)
if made.status != 0: raise RuntimeError(f"_gpu_free returned {get_error_str(made.status)}")

self._debug_mappings.pop((mem.va_addr, mem.size))
uvm.free(self.fd_uvm, base=mem.va_addr, length=mem.size)
if mem.has_cpu_mapping: libc.munmap(mem.va_addr, mem.size)
self._debug_mappings.pop((cast(int, mem.va_addr), mem.size))
uvm.free(self.fd_uvm, base=cast(int, mem.va_addr), length=mem.size)
if mem.meta.has_cpu_mapping: libc.munmap(cast(int, mem.va_addr), mem.size)

def _gpu_uvm_map(self, va_base, size, mem_handle, create_range=True, has_cpu_mapping=False, tag="") -> nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS:
def _gpu_uvm_map(self, va_base, size, mem_handle, create_range=True, has_cpu_mapping=False, tag="") -> HCQBuffer:
if create_range: uvm.create_external_range(self.fd_uvm, base=va_base, length=size)
attrs = (nv_gpu.struct_c__SA_UvmGpuMappingAttributes*256)(nv_gpu.struct_c__SA_UvmGpuMappingAttributes(gpuUuid=self.gpu_uuid, gpuMappingType=1))

# NOTE: va_addr is set to make rawbufs compatible with HCQBuffer protocol.
self._debug_mappings[(va_base, size)] = tag
return uvm.map_external_allocation(self.fd_uvm, base=va_base, length=size, rmCtrlFd=self.fd_ctl, hClient=self.root, hMemory=mem_handle,
gpuAttributesCount=1, perGpuAttributes=attrs, va_addr=va_base, size=size, mapped_gpu_ids=[self.gpu_uuid], has_cpu_mapping=has_cpu_mapping)
return HCQBuffer(va_base, size, meta=uvm.map_external_allocation(self.fd_uvm, base=va_base, length=size, rmCtrlFd=self.fd_ctl, hClient=self.root,
hMemory=mem_handle, gpuAttributesCount=1, perGpuAttributes=attrs, mapped_gpu_ids=[self.gpu_uuid], has_cpu_mapping=has_cpu_mapping))

def _gpu_map(self, mem):
if self.gpu_uuid in mem.mapped_gpu_ids: return
mem.mapped_gpu_ids.append(self.gpu_uuid)
self._gpu_uvm_map(mem.va_addr, mem.size, mem.hMemory, create_range=False, tag="p2p mem")
def _gpu_map(self, mem:HCQBuffer):
if self.gpu_uuid in mem.meta.mapped_gpu_ids: return
mem.meta.mapped_gpu_ids.append(self.gpu_uuid)
self._gpu_uvm_map(mem.va_addr, mem.size, mem.meta.hMemory, create_range=False, tag="p2p mem")

def _alloc_gpu_vaddr(self, size, alignment=(4 << 10), force_low=False):
if force_low:
NVDevice.low_uvm_vaddr = (res_va:=round_up(NVDevice.low_uvm_vaddr, alignment)) + size
assert NVDevice.low_uvm_vaddr < 0x2000000000, "Exceed low vm addresses"
else: NVDevice.uvm_vaddr = (res_va:=round_up(NVDevice.uvm_vaddr, alignment)) + size
return res_va
return NVDevice.low_uvm_vaddr_allocator.alloc(size, alignment) if force_low else NVDevice.uvm_vaddr_allocator.alloc(size, alignment)

def _setup_nvclasses(self):
classlist = memoryview(bytearray(100 * 4)).cast('I')
@@ -441,14 +432,14 @@ class NVDevice(HCQCompiled[NVSignal]):
except RuntimeError as e: raise RuntimeError(str(e) + f". Make sure GPUs #{self.gpu_minor} & #{dev.gpu_minor} have P2P enabled between.") from e

if NVDevice.signals_page is None:
NVDevice.signals_page = self._gpu_system_alloc(16 * 65536, map_to_cpu=True)
NVDevice.signals_page = self._gpu_alloc(16 * 65536, cpu_access=True, uncached=True)
NVDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, NVDevice.signals_page.size, 16)]
else: self._gpu_map(NVDevice.signals_page)

channel_params = nv_gpu.NV_CHANNEL_GROUP_ALLOCATION_PARAMETERS(engineType=nv_gpu.NV2080_ENGINE_TYPE_GRAPHICS)
channel_group = rm_alloc(self.fd_ctl, nv_gpu.KEPLER_CHANNEL_GROUP_A, self.root, self.nvdevice, channel_params).hObjectNew

gpfifo_area = self._gpu_alloc(0x200000, contig=True, huge_page=True, map_to_cpu=True, map_flags=0x10d0000, tag="gpfifo")
gpfifo_area = self._gpu_alloc(0x200000, contiguous=True, cpu_access=True, map_flags=0x10d0000, tag="gpfifo")

ctxshare_params = nv_gpu.NV_CTXSHARE_ALLOCATION_PARAMETERS(hVASpace=vaspace, flags=nv_gpu.NV_CTXSHARE_ALLOCATION_FLAGS_SUBCONTEXT_ASYNC)
ctxshare = rm_alloc(self.fd_ctl, nv_gpu.FERMI_CONTEXT_SHARE_A, self.root, channel_group, ctxshare_params).hObjectNew
@@ -458,9 +449,9 @@ class NVDevice(HCQCompiled[NVSignal]):

rmctrl.gpfifo_schedule(self.fd_ctl, self.root, channel_group, bEnable=1)

self.cmdq_page: nv_gpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS = self._gpu_alloc(0x200000, map_to_cpu=True, huge_page=True, tag="cmdq")
self.cmdq: memoryview = to_mv(self.cmdq_page.va_addr, 0x200000).cast("I")
self.cmdq_wptr: int = 0 # in bytes
self.cmdq_page:HCQBuffer = self._gpu_alloc(0x200000, cpu_access=True, tag="cmdq")
self.cmdq_allocator = BumpAllocator(size=self.cmdq_page.size, start=cast(int, self.cmdq_page.va_addr), wrap=True)
self.cmdq: memoryview = to_mv(cast(int, self.cmdq_page.va_addr), 0x200000).cast("I")

self.num_gpcs, self.num_tpc_per_gpc, self.num_sm_per_tpc, self.max_warps_per_sm, self.sm_version = self._query_gpu_info('num_gpcs',
'num_tpc_per_gpc', 'num_sm_per_tpc', 'max_warps_per_sm', 'sm_version')
@@ -473,10 +464,10 @@ class NVDevice(HCQCompiled[NVSignal]):
self._setup_gpfifos()

def _new_gpu_fifo(self, gpfifo_area, ctxshare, channel_group, offset=0, entries=0x400, enable_debug=False) -> GPFifo:
notifier = self._gpu_system_alloc(48 << 20)
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(hObjectError=notifier.hMemory, hObjectBuffer=gpfifo_area.hMemory,
notifier = self._gpu_alloc(48 << 20, uncached=True)
params = nv_gpu.NV_CHANNELGPFIFO_ALLOCATION_PARAMETERS(hObjectError=notifier.meta.hMemory, hObjectBuffer=gpfifo_area.meta.hMemory,
gpFifoOffset=gpfifo_area.va_addr+offset, gpFifoEntries=entries, hContextShare=ctxshare,
hUserdMemory=(ctypes.c_uint32*8)(gpfifo_area.hMemory), userdOffset=(ctypes.c_uint64*8)(entries*8+offset))
hUserdMemory=(ctypes.c_uint32*8)(gpfifo_area.meta.hMemory), userdOffset=(ctypes.c_uint64*8)(entries*8+offset))
gpfifo = rm_alloc(self.fd_ctl, nv_gpu.AMPERE_CHANNEL_GPFIFO_A, self.root, channel_group, params).hObjectNew
comp = rm_alloc(self.fd_ctl, self.compute_class, self.root, gpfifo, None).hObjectNew
rm_alloc(self.fd_ctl, nv_gpu.AMPERE_DMA_COPY_B, self.root, gpfifo, None)

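# --- note: the two class-level BumpAllocators above replace the hand-rolled low_uvm_vaddr /
# uvm_vaddr counters. With wrap=False an address is never handed out twice and exhaustion
# raises instead of silently wrapping. Reusing the BumpAllocatorSketch from the HCQGraph note:
low = BumpAllocatorSketch(0x1000000000, start=0x1000000000, wrap=False)
assert low.alloc(0x1000, 4 << 10) == 0x1000000000  # CPU-visible low range
assert low.alloc(0x1000, 4 << 10) == 0x1000001000  # bumps forward, never reuses
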
@@ -4,11 +4,11 @@ assert sys.platform != 'win32'
from types import SimpleNamespace
from typing import Tuple, List, Any, cast, Optional
from tinygrad.device import BufferSpec
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
from tinygrad.runtime.autogen import kgsl, adreno, libc
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.helpers import getenv, from_mv, mv_address, to_mv, round_up, data64_le, prod, fromimport
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import

BUFTYPE_BUF, BUFTYPE_TEX, BUFTYPE_IBO = 0, 1, 2
@@ -86,7 +86,7 @@ class QCOMComputeQueue(HWQueue):
return self

def _build_gpu_command(self, dev:QCOMDevice, hw_addr=None):
to_mv((hw_page_addr:=hw_addr or dev._alloc_cmd_buf(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
to_mv((hw_page_addr:=hw_addr or dev.cmd_buf_allocator.alloc(len(self._q) * 4)), len(self._q) * 4).cast('I')[:] = array.array('I', self._q)
obj = kgsl.struct_kgsl_command_object(gpuaddr=hw_page_addr, size=len(self._q) * 4, flags=kgsl.KGSL_CMDLIST_IB)
submit_req = kgsl.struct_kgsl_gpu_command(cmdlist=ctypes.addressof(obj), numcmds=1, context_id=dev.ctx,
cmdsize=ctypes.sizeof(kgsl.struct_kgsl_command_object))
@@ -105,6 +105,8 @@ class QCOMComputeQueue(HWQueue):
dev.last_cmd = kgsl.IOCTL_KGSL_GPU_COMMAND(dev.fd, __payload=submit_req).timestamp

def exec(self, prg:QCOMProgram, args_state:QCOMArgsState, global_size, local_size):
self.bind_args_state(args_state)

def cast_int(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x
global_size_mp = [cast_int(g*l) for g,l in zip(global_size, local_size)]

@@ -147,7 +149,7 @@ class QCOMComputeQueue(HWQueue):
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.samp_cnt),
*data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_CS_TEX_SAMP, *data64_le(args_state.ptr + args_state.prg.samp_off))
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev._border_color_base()))
self.reg(adreno.REG_A6XX_SP_PS_TP_BORDER_COLOR_BASE_ADDR, *data64_le(prg.dev.border_color_buf.va_addr))

if args_state.prg.tex_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
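
# --- note: sketch of the cast_int helper added above: the grid is programmed as
# global_size * local_size, and symbolic shapes can surface floats there, so values
# appear to be truncated (or ceil'd on request) before being packed into registers:
import math
def cast_int_sketch(x, ceil=False): return (math.ceil(x) if ceil else int(x)) if isinstance(x, float) else x
assert [cast_int_sketch(g*l) for g, l in zip((4, 2.5), (16, 2))] == [64, 5]
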
||||
@@ -179,17 +181,13 @@ class QCOMArgsState(HCQArgsState):
|
||||
for cnst_val, cnst_off, cnst_sz in prg.consts_info: to_mv(self.ptr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
|
||||
|
||||
if prg.samp_cnt > 0: to_mv(self.ptr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
|
||||
for i, b in enumerate(cast(List[QCOMBuffer], bufs)):
|
||||
if prg.buf_info[i].type is BUFTYPE_TEX: to_mv(self.ptr + prg.buf_info[i].offset, len(b.desc) * 4).cast('I')[:] = array.array('I', b.desc)
|
||||
elif prg.buf_info[i].type is BUFTYPE_IBO: to_mv(self.ptr + prg.buf_info[i].offset, len(b.ibo) * 4).cast('I')[:] = array.array('I', b.ibo)
|
||||
else: self.update_buffer(i, b)
|
||||
for i, v in enumerate(vals): self.update_var(i, v)
|
||||
for i, b in enumerate(bufs):
|
||||
if prg.buf_info[i].type in {BUFTYPE_TEX, BUFTYPE_IBO}:
|
||||
obj = b.texture_info.desc if prg.buf_info[i].type is BUFTYPE_TEX else b.texture_info.ibo
|
||||
to_mv(self.ptr + prg.buf_info[i].offset, len(obj) * 4).cast('I')[:] = array.array('I', obj)
|
||||
self.bind_sints_to_ptr(b.va_addr, ptr=self.ptr + self.buf_info[i].offset + (0 if self.buf_info[i].type is BUFTYPE_BUF else 16), fmt='Q')
|
||||
|
||||
def update_buffer(self, index:int, buf:HCQBuffer):
|
||||
if self.buf_info[index].type is not BUFTYPE_BUF: self.args_view[self.buf_info[index].offset//8 + 2] = buf.va_addr
|
||||
else: self.args_view[self.buf_info[index].offset//8] = buf.va_addr
|
||||
|
||||
def update_var(self, index:int, val:int): self.args_view[self.args_info[index].offset//8] = val
|
||||
for i, v in enumerate(vals): self.bind_sints_to_ptr(v, ptr=self.ptr + self.args_info[i].offset, fmt='I')
|
||||
|
||||
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -198,7 +196,7 @@ class QCOMProgram(HCQProgram):
self._parse_lib()

self.lib_gpu: HCQBuffer = self.dev.allocator.alloc(self.image_size, options=BufferSpec(cpu_access=True, nolru=True))
to_mv(self.lib_gpu.va_addr, self.image_size)[:] = self.image
to_mv(cast(int, self.lib_gpu.va_addr), self.image_size)[:] = self.image

self.pvtmem_size_per_item: int = round_up(self.pvtmem, 512) >> 9
self.pvtmem_size_total: int = self.pvtmem_size_per_item * 128 * 2
@@ -269,15 +267,13 @@ class QCOMProgram(HCQProgram):
def __del__(self):
if hasattr(self, 'lib_gpu'): self.dev.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferSpec(cpu_access=True, nolru=True))

class QCOMBuffer(HCQBuffer):
def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
self.va_addr, self.size, self.info, self.mapped = va_addr, size, info, mapped
class QCOMTextureInfo:
def __init__(self, pitch:int, real_stride:int, desc:List[int], ibo:List[int]):
self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo

# Texture specific definitions
self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride
class QCOMAllocator(HCQAllocator):
class QCOMAllocator(HCQAllocatorBase):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
# Recalculate real size for texture
if options.image is not None:
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
@@ -286,22 +282,18 @@ class QCOMAllocator(HCQAllocator):
granularity = 128 if options.image.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
size = pitch * imgh

if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
else: texture = self.dev._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE)

texture.pitch, texture.real_stride = pitch, real_stride
buf = HCQBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)

if options.image is not None:
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
*data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]

return texture

return QCOMBuffer(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
return buf
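The pitch computation above rounds every texture row up to a hardware alignment boundary before sizing the allocation. A worked sketch of that arithmetic for a hypothetical 100x50 half-float image (4 channels, itemsize 2), with `round_up` re-stated locally and `pitch_add` ignored for simplicity:

```python
import math

def round_up(x: int, a: int) -> int: return (x + a - 1) // a * a  # same contract as tinygrad.helpers.round_up

imgw, imgh, itemsize = 100, 50, 2  # hypothetical image, not taken from the diff
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6  # -> 6, i.e. 64-byte row alignment
real_stride = imgw * 4 * itemsize                                  # 800 bytes of payload per row
pitch = round_up(real_stride, 1 << pitchalign)                     # 832: row padded to the next 64-byte multiple
print(pitchalign, real_stride, pitch)                              # 6 800 832
```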
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, dest_off=0, src_off=0):
while src_off < src_size:
@@ -309,17 +301,18 @@ class QCOMAllocator(HCQAllocator):
src_off, dest_off = src_off+src_stride, dest_off+dest_stride

def _copyin(self, dest:HCQBuffer, src:memoryview):
if (qd:=cast(QCOMBuffer, dest)).pitch is not None: self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)
stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
self._do_copy(mv_address(src), dest.va_addr, src.nbytes, stride, stride, pitch)

def _copyout(self, dest:memoryview, src:HCQBuffer):
self.dev.synchronize()
if (qs:=cast(QCOMBuffer, src)).pitch is not None: self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)

stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
self._do_copy(src.va_addr, mv_address(dest), src.size, stride, pitch, stride)
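`_copyin`/`_copyout` now always go through `_do_copy`, with stride and pitch collapsing to the full byte count when there is no `texture_info`, so the linear-buffer case degenerates to one full-size row. A minimal pure-Python sketch of the row loop, with made-up sizes, showing how the destination pitch skips the padding bytes between rows:

```python
def do_copy(src: bytes, dest: bytearray, src_size: int, real_size: int, src_stride: int, dest_stride: int) -> None:
  # copy real_size payload bytes per row; dest_stride - real_size bytes of row padding are skipped
  src_off = dest_off = 0
  while src_off < src_size:
    dest[dest_off:dest_off + real_size] = src[src_off:src_off + real_size]
    src_off, dest_off = src_off + src_stride, dest_off + dest_stride

rows, real_stride, pitch = 4, 8, 16          # hypothetical tightly-packed source, pitched destination
src, dest = bytes(range(rows * real_stride)), bytearray(rows * pitch)
do_copy(src, dest, len(src), real_stride, real_stride, pitch)
assert dest[pitch:pitch + real_stride] == src[real_stride:2 * real_stride]  # row 1 lands one pitch in, not one stride
```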
def _as_buffer(self, src:HCQBuffer) -> memoryview:
self.dev.synchronize()
return to_mv(src.va_addr, src.size)
return to_mv(cast(int, src.va_addr), src.size)

def _free(self, opaque, options:BufferSpec):
self.dev.synchronize()
@@ -333,32 +326,35 @@ class QCOMDevice(HCQCompiled):

def __init__(self, device:str=""):
self.fd = os.open('/dev/kgsl-3d0', os.O_RDWR)
QCOMDevice.dummy_addr = self._gpu_alloc(0x1000).va_addr
QCOMDevice.dummy_addr = cast(int, self._gpu_alloc(0x1000).va_addr)
QCOMDevice.signals_page = self._gpu_alloc(16 * 65536, uncached=True)
QCOMDevice.signals_pool = [self.signals_page.va_addr + off for off in range(0, self.signals_page.size, 16)]
info, self.ctx, self.cmd_buf, self.cmd_buf_ptr, self.last_cmd = self._info(), self._ctx_create(), self._gpu_alloc(16 << 20), 0,0

flags = kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT | kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC \
| kgsl.KGSL_CONTEXT_PRIORITY(8) | kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)
self.ctx = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=flags).drawctxt_id

self.cmd_buf = self._gpu_alloc(16 << 20)
self.cmd_buf_allocator = BumpAllocator(size=self.cmd_buf.size, start=cast(int, self.cmd_buf.va_addr), wrap=True)

self.border_color_buf = self._gpu_alloc(0x1000, fill_zeroes=True)

self.last_cmd:int = 0

# Set max power
struct.pack_into('IIQQ', pwr:=memoryview(bytearray(0x18)), 0, 1, self.ctx, mv_address(_:=memoryview(array.array('I', [1]))), 4)
kgsl.IOCTL_KGSL_SETPROPERTY(self.fd, type=kgsl.KGSL_PROP_PWR_CONSTRAINT, value=mv_address(pwr), sizebytes=pwr.nbytes)

# Load info about qcom device
info = kgsl.struct_kgsl_devinfo()
kgsl.IOCTL_KGSL_DEVICE_GETPROPERTY(self.fd, type=kgsl.KGSL_PROP_DEVICE_INFO, value=ctypes.addressof(info), sizebytes=ctypes.sizeof(info))
QCOMDevice.gpu_id = ((info.chip_id >> 24) & 0xFF) * 100 + ((info.chip_id >> 16) & 0xFF) * 10 + ((info.chip_id >> 8) & 0xFF)
if QCOMDevice.gpu_id >= 700: raise RuntimeError(f"Unsupported GPU: {QCOMDevice.gpu_id}")

super().__init__(device, QCOMAllocator(self), QCOMRenderer(), QCOMCompiler(device), functools.partial(QCOMProgram, self),
QCOMSignal, QCOMComputeQueue, None)
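The `gpu_id` expression decodes the three chip-id bytes into a decimal Adreno model number. A quick worked example with a made-up `chip_id`:

```python
chip_id = 0x06030000  # hypothetical KGSL chip id, not read from real hardware
gpu_id = ((chip_id >> 24) & 0xFF) * 100 + ((chip_id >> 16) & 0xFF) * 10 + ((chip_id >> 8) & 0xFF)
print(gpu_id)  # 630, i.e. an Adreno 630; anything >= 700 is rejected above
```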
def _ctx_create(self):
cr = kgsl.IOCTL_KGSL_DRAWCTXT_CREATE(self.fd, flags=(kgsl.KGSL_CONTEXT_PREAMBLE | kgsl.KGSL_CONTEXT_PWR_CONSTRAINT |
kgsl.KGSL_CONTEXT_NO_FAULT_TOLERANCE | kgsl.KGSL_CONTEXT_NO_GMEM_ALLOC | kgsl.KGSL_CONTEXT_PRIORITY(8) |
kgsl.KGSL_CONTEXT_PREEMPT_STYLE(kgsl.KGSL_CONTEXT_PREEMPT_STYLE_FINEGRAIN)))

# Set power to maximum.
struct.pack_into('IIQQ', pwr:=memoryview(bytearray(0x18)), 0, 1, cr.drawctxt_id, mv_address(_:=memoryview(array.array('I', [1]))), 4)
kgsl.IOCTL_KGSL_SETPROPERTY(self.fd, type=kgsl.KGSL_PROP_PWR_CONSTRAINT, value=mv_address(pwr), sizebytes=pwr.nbytes)
return cr.drawctxt_id

def _info(self):
info = kgsl.struct_kgsl_devinfo()
kgsl.IOCTL_KGSL_DEVICE_GETPROPERTY(self.fd, type=kgsl.KGSL_PROP_DEVICE_INFO, value=ctypes.addressof(info), sizebytes=ctypes.sizeof(info))
return info

def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False):
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
flags |= kgsl.KGSL_MEMALIGN(alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
if uncached: flags |= kgsl.KGSL_CACHEMODE(kgsl.KGSL_CACHEMODE_UNCACHED)

@@ -366,19 +362,11 @@ class QCOMDevice(HCQCompiled):
va_addr = libc.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, self.fd, alloc.id * 0x1000)

if fill_zeroes: ctypes.memset(va_addr, 0, size)
return QCOMBuffer(va_addr=va_addr, size=size, info=alloc)
return HCQBuffer(va_addr=va_addr, size=size, meta=alloc)

def _gpu_free(self, mem):
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.info.id)
libc.munmap(mem.va_addr, mem.info.mmapsize)

def _alloc_cmd_buf(self, sz: int):
self.cmd_buf_ptr = (cur_ptr:=self.cmd_buf_ptr if self.cmd_buf_ptr + sz < self.cmd_buf.size else 0) + sz
return self.cmd_buf.va_addr + cur_ptr

def _border_color_base(self):
if not hasattr(self, '_border_color_gpu'): self._border_color_gpu = self._gpu_alloc(0x1000, fill_zeroes=True)
return self._border_color_gpu.va_addr
def _gpu_free(self, mem:HCQBuffer):
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.meta.id)
libc.munmap(mem.va_addr, mem.meta.mmapsize)

def _ensure_stack_size(self, sz):
if not hasattr(self, '_stack'): self._stack = self._gpu_alloc(sz)
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, cast, Protocol, Type, Union, TypeVar, Generic, Any
from typing import List, Optional, Dict, Tuple, cast, Type, Union, TypeVar, Generic, Any
import contextlib, decimal, statistics, random, json, atexit, time, ctypes, array
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv, getenv, to_mv, round_up
from tinygrad.renderer import Renderer
from tinygrad.device import BufferSpec, Compiler, Compiled, LRUAllocator
from tinygrad.ops import sym_infer, sint, Variable
@@ -14,6 +14,15 @@ ProgramType = TypeVar('ProgramType', bound='HCQProgram')
ArgsStateType = TypeVar('ArgsStateType', bound='HCQArgsState')
QueueType = TypeVar('QueueType', bound='HWQueue')

class BumpAllocator:
def __init__(self, size:int, start:int=0, wrap:bool=True): self.size, self.ptr, self.start_off, self.wrap = size, 0, start, wrap
def alloc(self, size:int, alignment:int=1) -> int:
if round_up(self.ptr, alignment) + size > self.size:
if not self.wrap: raise RuntimeError("Out of memory")
self.ptr = 0
self.ptr = (res:=round_up(self.ptr, alignment)) + size
return res + self.start_off
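The new `BumpAllocator` centralizes the wrap-around pointer arithmetic that the command-buffer and kernarg paths previously each hand-rolled. A small usage sketch with illustrative sizes (note that on wrap it simply reuses the region from the start; callers are responsible for the GPU being done with the old contents):

```python
# assumes the BumpAllocator above plus round_up(x, a) = (x + a - 1) // a * a
alloc = BumpAllocator(size=64, start=0x1000, wrap=True)
a = alloc.alloc(10)                 # 0x1000, internal ptr now 10
b = alloc.alloc(10, alignment=16)   # 0x1010: ptr first rounded up from 10 to 16
c = alloc.alloc(48)                 # 26 + 48 > 64, so the pointer wraps back to the start
print(hex(a), hex(b), hex(c))       # 0x1000 0x1010 0x1000
```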
class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
@@ -121,6 +130,9 @@ class HWQueue(Generic[SignalType, DeviceType, ProgramType, ArgsStateType]):
Implementing this method is optional but recommended for performance gains.
"""

def bind_args_state(self, args_state:ArgsStateType):
for vals, ptr, fmt in args_state.bind_data: self.bind_sints_to_ptr(*vals, ptr=ptr, fmt=fmt)

def bind_sints(self, *vals:sint, struct:ctypes.Structure, start_field:str, fmt, mask:Optional[int]=None):
self.bind_sints_to_ptr(*vals, ptr=ctypes.addressof(struct) + getattr(type(struct), start_field).offset, fmt=fmt, mask=mask)

@@ -224,24 +236,20 @@ def hcq_profile(dev:HCQCompiled, enabled, desc, queue_type:Optional[Type[HWQueue
if enabled and PROFILE: dev.sig_prof_records.append((cast(HCQSignal, st), cast(HCQSignal, en), desc, queue_type is dev.hw_copy_queue_t))

class HCQArgsState(Generic[ProgramType]):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[sint, ...]=()):
self.ptr, self.prg = ptr, prg
self.bind_data:List[Tuple[Tuple[sint, ...], int, str]] = []

def bind_sints_to_ptr(self, *vals:sint, ptr:int, fmt): self.bind_data.append((vals, ptr, fmt))
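Instead of every backend overriding `update_buffer`/`update_var` with its own pointer pokes, an args state now just records `(vals, ptr, fmt)` triples that `HWQueue.bind_args_state` can replay. A standalone sketch of that deferred-write pattern, using plain ints and a bytearray in place of symbolic `sint`s and GPU memory:

```python
import struct

class ArgsState:
  def __init__(self): self.bind_data = []  # deferred (vals, offset, fmt) records
  def bind_sints_to_ptr(self, *vals, ptr, fmt): self.bind_data.append((vals, ptr, fmt))

state = ArgsState()
state.bind_sints_to_ptr(0xdeadbeef, ptr=0, fmt='Q')  # one 64-bit buffer address slot
state.bind_sints_to_ptr(42, 7, ptr=8, fmt='I')       # two 32-bit values

kernargs = bytearray(16)
for vals, off, fmt in state.bind_data:               # what a queue's bind step boils down to
  struct.pack_into('<' + fmt * len(vals), kernargs, off, *vals)
assert struct.unpack_from('<QII', kernargs) == (0xdeadbeef, 42, 7)
```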
class CLikeArgsState(HCQArgsState[ProgramType]):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), prefix:Optional[List[int]]=None):
def __init__(self, ptr:int, prg:ProgramType, bufs:Tuple[HCQBuffer, ...], vals:Tuple[sint, ...]=(), prefix:Optional[List[int]]=None):
super().__init__(ptr, prg, bufs, vals=vals)

if prefix is not None: to_mv(self.ptr, len(prefix) * 4).cast('I')[:] = array.array('I', prefix)

self.bufs = to_mv(self.ptr + len(prefix or []) * 4, len(bufs) * 8).cast('Q')
self.vals = to_mv(self.ptr + len(prefix or []) * 4 + len(bufs) * 8, len(vals) * 4).cast('I')

self.bufs[:] = array.array('Q', [b.va_addr for b in bufs])
self.vals[:] = array.array('I', vals)

def update_buffer(self, index:int, buf:HCQBuffer): self.bufs[index] = buf.va_addr
def update_var(self, index:int, val:int): self.vals[index] = val
self.bind_sints_to_ptr(*[b.va_addr for b in bufs], ptr=self.ptr + len(prefix or []) * 4, fmt='Q')
self.bind_sints_to_ptr(*vals, ptr=self.ptr + len(prefix or []) * 4 + len(bufs) * 8, fmt='I')
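The resulting C-like kernarg layout is `[prefix u32s][one u64 per buffer][one u32 per value]`; the two bind calls above only differ in their starting offsets. Offsets for a hypothetical launch with a 2-word prefix, 3 buffers, and 2 ints:

```python
prefix_len, n_bufs, n_vals = 2, 3, 2  # hypothetical kernel signature, for illustration
buf_block = prefix_len * 4            # buffer pointers start at byte 8
val_block = buf_block + n_bufs * 8    # values start at byte 32
print(buf_block, val_block, val_block + n_vals * 4)  # 8 32 40 bytes of kernargs total
```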
class HCQProgram(Generic[DeviceType]):
def __init__(self, args_state_t:Type[HCQArgsState], dev:DeviceType, name:str, kernargs_alloc_size:int):
@@ -257,7 +265,7 @@ class HCQProgram(Generic[DeviceType]):
Returns:
Arguments state with the given buffers and values set for the program.
"""
return self.args_state_t(kernargs_ptr or self.dev._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
return self.args_state_t(kernargs_ptr or self.dev.kernargs_alloctor.alloc(self.kernargs_alloc_size), self, bufs, vals=vals)

def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
@@ -333,7 +341,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')

def __init__(self, device:str, allocator:HCQAllocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
def __init__(self, device:str, allocator:HCQAllocatorBase, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[SignalType],
comp_queue_t:Type[HWQueue], copy_queue_t:Optional[Type[HWQueue]]):
self.device_id:int = int(device.split(":")[1]) if ":" in device else 0
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
@@ -349,7 +357,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)

self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferSpec(cpu_access=True))
self.kernargs_ptr:int = self.kernargs_page.va_addr
self.kernargs_alloctor:BumpAllocator = BumpAllocator(self.kernargs_page.size, start=cast(int, self.kernargs_page.va_addr), wrap=True)
self.devices.append(self)

def synchronize(self):
@@ -363,14 +371,6 @@ class HCQCompiled(Compiled, Generic[SignalType]):
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
self.sig_prof_records = []

def _alloc_kernargs(self, alloc_size:int) -> int:
"""
Allocates space for arguments passed to the kernel.
"""
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
return res

def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
@@ -445,12 +445,13 @@ class HCQCompiled(Compiled, Generic[SignalType]):
def _wrap_timeline_signal(self):
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
self.timeline_signal.value = 0
cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
cast(HCQAllocatorBase, self.allocator).b_timeline = [0] * len(cast(HCQAllocatorBase, self.allocator).b)

# Protocol for hcq compatible allocators for allocated buffers to contain VA address and it's size.
class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
class HCQBuffer:
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:Optional[HCQBuffer]=None):
self.va_addr, self.size, self.texture_info, self.meta, self._base = va_addr, size, texture_info, meta, _base

class HCQAllocator(LRUAllocator, Generic[DeviceType]):
class HCQAllocatorBase(LRUAllocator, Generic[DeviceType]):
"""
A base allocator class compatible with the HCQ (Hardware Command Queue) API.

@@ -463,8 +464,12 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]):
self.b_timeline, self.b_next = [0] * len(self.b), 0
super().__init__()

def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def map(self, buf:HCQBuffer): pass

def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
return HCQBuffer(va_addr=buf.va_addr + offset, size=size, texture_info=buf.texture_info, meta=buf.meta, _base=buf._base or buf)
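Turning `HCQBuffer` from a `Protocol` into a concrete class is what makes the generic `_offset` possible: a sub-buffer is just a shifted address that shares the parent's metadata and remembers the root allocation in `_base`. A small behavioral sketch with a stand-in class:

```python
class Buf:
  # minimal stand-in for the concrete HCQBuffer above
  def __init__(self, va_addr, size, meta=None, _base=None):
    self.va_addr, self.size, self.meta, self._base = va_addr, size, meta, _base

def offset_view(buf: Buf, size: int, offset: int) -> Buf:
  # same shape as HCQAllocatorBase._offset: shifted address, shared meta, root kept in _base
  return Buf(buf.va_addr + offset, size, meta=buf.meta, _base=buf._base or buf)

root = Buf(0x10000, 4096, meta="kgsl alloc handle")
sub = offset_view(root, 256, 512)
subsub = offset_view(sub, 64, 128)
assert sub.va_addr == 0x10200 and subsub._base is root  # nested views still chain to the root
```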
class HCQAllocator(HCQAllocatorBase, Generic[DeviceType]):
def _copyin(self, dest:HCQBuffer, src:memoryview):
assert self.dev.hw_copy_queue_t is not None
with hcq_profile(self.dev, queue_type=self.dev.hw_copy_queue_t, desc=f"CPU -> {self.dev.device}", enabled=PROFILE):
@@ -525,9 +530,3 @@ class HCQAllocator(LRUAllocator, Generic[DeviceType]):
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1

def map(self, buf:HCQBuffer): pass

def _offset(self, buf, size:int, offset:int) -> HCQBuffer:
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
@@ -1,7 +1,7 @@
from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast
from typing import Tuple, List, Optional, Dict, Set, cast, Sequence
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop
from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv
@@ -95,10 +95,11 @@ class View:
for x in self.shape+self.strides+(self.offset,)+(tuple(flatten(self.mask)) if self.mask is not None else tuple()))
def __lt__(self, o:View): return self.t < o.t
def to_indexed_uops(self:View, _idxs:Optional[List[UOp]|Tuple[UOp, ...]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
def to_indexed_uops(self:View, idxs:Optional[Sequence[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
"""(idx, valid)"""
if idxs is None: idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)]
iexpr = sint_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else itertools.repeat(None)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * (idx >= m[0])
@@ -267,11 +268,9 @@ class View:
@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def expand(self, new_shape: Tuple[sint, ...]) -> View:
if len(new_shape) != len(self.shape): raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
if 0 in self.shape:
assert all((s == x == 0) or (s > 0 and (x % s) == 0) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
return View.create(new_shape)
# TODO: this resolve might be wrong
assert all((not resolve(s != x, False) or s == 1) for s,x in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
# NOTE: does not check multiple of symbolic shape
assert all(resolve(s == ns) or s == 1 for s,ns in zip(self.shape, new_shape)), f"can't expand {self.shape} into {new_shape}"
if 0 in self.shape: return View.create(new_shape)
# NOTE: can the mask ever be (0,0)?
# TODO: this resolve may not be needed, but it's hard because vars need to be sorted
mask = tuple([(((0,0) if m != (0,1) else (0,ns)) if resolve(s != ns, False) else m) \
@@ -299,16 +298,13 @@ class View:
def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
if self.shape == new_shape: return self

# TODO: this resolve shouldn't be needed
assert all(resolve(x >= 0) for x in new_shape), f"shape can't contain negative numbers {new_shape}"
if 0 in self.shape:
assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
return View.create(new_shape)
assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
# check for the same size
if (self_all_int := all_int(self.shape)):
assert all(isinstance(s, (int, UOp)) for s in new_shape), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
if resolve(prod(self.shape) != prod(new_shape), False): raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")

if 0 in self.shape: return View.create(new_shape)
if new_shape == () and self.mask and any(mx==my for (mx,my) in self.mask): return None

# after the asserts, it's okay to check contiguous
@@ -330,14 +326,12 @@ class View:
while resolve(acc <= merged_dim) and resolve(acc != merged_dim) and resolve((new_dim := next(r_new_shape, 0)) > 0):
strides.append(new_stride)
if resolve(new_dim != 1): new_stride *= (new_dim if resolve((acc := acc * new_dim) < real_dim) else 0)
if resolve(acc != merged_dim): break
else:
strides += [0,] * (len(new_shape) - len(strides))
new_mask = _reshape_mask(self.mask, self.shape, new_shape)
if new_mask is not None:
new_strides = canonicalize_strides(tuple(e-b for b,e in new_mask), tuple(reversed(strides)))
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)
if resolve(acc != merged_dim): return None

if (new_mask:=_reshape_mask(self.mask, self.shape, new_shape)) is not None:
new_strides = (0,) * (len(new_shape) - len(strides)) + tuple(strides[::-1])
extra_offset = (sum(m[0] * s for m,s in zip(self.mask, self.strides)) if self.mask else 0) - \
(sum(m[0] * s for m,s in zip(new_mask, new_strides)))
return View.create(new_shape, new_strides, self.offset + extra_offset, new_mask)

return None
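The tightened `expand` assert above keeps the broadcasting rule simple: a dimension may keep its size or grow only from size 1. Both sides of that rule, checked through the public `Tensor.expand` (shapes chosen for illustration):

```python
from tinygrad import Tensor

print(Tensor.ones(3, 1).expand(3, 4).shape)  # (3, 4): a size-1 axis may grow

try: Tensor.ones(3, 2).expand(3, 4)          # 2 -> 4 is not growth from size 1, so this asserts
except Exception as e: print("rejected:", e)
```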
@@ -89,11 +89,12 @@ def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor:
assert isinstance(ret, Tensor), "sum didn't return a Tensor"
return ret

def _pad_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
def _align_left(*shapes:Tuple[sint, ...]) -> Tuple[Tuple[sint, ...], ...]:
# unsqueeze left to make every shape same length
max_dim = max(len(shape) for shape in shapes)
return tuple((1,) * (max_dim - len(shape)) + shape for shape in shapes)
def _broadcast_shape(*shapes:Tuple[sint, ...]) -> Tuple[sint, ...]:
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_pad_left(*shapes)))
return tuple(0 if 0 in nth_dim_sizes else smax(nth_dim_sizes) for nth_dim_sizes in zip(*_align_left(*shapes)))
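`_align_left` plus a per-axis `smax` is the entire broadcast-shape rule. The same logic over plain ints as a sanity check (a standalone re-implementation, not the tinygrad helpers themselves):

```python
def align_left(*shapes):
  max_dim = max(len(s) for s in shapes)
  return tuple((1,) * (max_dim - len(s)) + s for s in shapes)

def broadcast_shape(*shapes):
  # a zero-size axis wins outright; otherwise take the largest size per axis
  return tuple(0 if 0 in sizes else max(sizes) for sizes in zip(*align_left(*shapes)))

print(align_left((3,), (2, 1, 3)))       # ((1, 1, 3), (2, 1, 3))
print(broadcast_shape((3,), (2, 1, 3)))  # (2, 1, 3)
print(broadcast_shape((0, 3), (2, 1)))   # (0, 3)
```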
def _masked_setitem(target:Tensor, values:Tensor, mask:Tensor, axes:Tuple[int, ...]):
# apply mask to values (already broadcasted) and reduce such that if mask contains repeated indices the last one remains
@@ -281,7 +282,7 @@ class Tensor(SimpleMathTrait):
assert self.dtype.base.fmt is not None, f"no fmt dtype for {self.dtype.base}"
assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}"
if TYPE_CHECKING or sys.version_info < (3, 12): assert self.dtype.base.fmt != "e"
return self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape)
return cast(memoryview, self._data().cast(self.dtype.base.fmt) if 0 in self.shape else self._data().cast(self.dtype.base.fmt, self.shape))

def item(self) -> ConstType:
"""
@@ -366,14 +367,13 @@ class Tensor(SimpleMathTrait):
assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer"
devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
if axis is not None:
if axis < 0: axis += len(self.shape)
axis = self._resolve_dim(axis)
if splits is None:
if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
sz = ceildiv(total, len(devices))
splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
boundaries = tuple(itertools.accumulate(splits))
bounds = tuple(zip((0,) + boundaries, boundaries))
bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
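The old `zip((0,) + boundaries, boundaries)` and the new `itertools.pairwise(accumulate(..., initial=0))` produce the same half-open device ranges. Worked through for an axis of size 8 sharded across 3 devices:

```python
import itertools
from math import ceil

total, n_devices = 8, 3
sz = ceil(total / n_devices)  # 3 rows per device, the last device comes up short
splits = tuple(max(0, min(sz, total - sz*i)) for i in range(n_devices))
bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
print(splits, bounds)  # (3, 3, 2) ((0, 3), (3, 6), (6, 8))
```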
def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None, splits:Optional[Tuple[int, ...]]=None):
@@ -947,7 +947,8 @@ class Tensor(SimpleMathTrait):
print(t.expand(4, -1).numpy())
```
"""
return self._broadcast_to(tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_pad_left(self.shape, argfix(shape, *args))))))
new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
return self._broadcast_to(new_shape)

def permute(self, order, *args) -> Tensor:
"""
@@ -1117,7 +1118,7 @@ class Tensor(SimpleMathTrait):
case list() | tuple() | Tensor():
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
index = (index < 0).where(size, 0) + index # treat negative index values
index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
case int() | UOp(): # sint
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
@@ -1162,8 +1163,7 @@ class Tensor(SimpleMathTrait):
for dim, tensor in zip(dims, tensors):
try: i = tensor.reshape(tensor.shape + (1,)*(x.ndim - dims[0])).expand(pre_reduce_shape)
except ValueError as e: raise IndexError(f"cannot broadcast indices: {e}") from e
a = Tensor.arange(x.shape[dim], device=self.device, requires_grad=False).reshape((x.shape[dim],) + (1,)*(x.ndim - dim - 1))
masks.append(i == a)
masks.append(i._one_hot_along_dim(num_classes=x.shape[dim], dim=(dim - x.ndim)))

# reduce masks to 1 mask
mask: Tensor = functools.reduce(lambda x,y: x.mul(y), masks)
@@ -1225,7 +1225,7 @@ class Tensor(SimpleMathTrait):
assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim"
index = index.to(self.device)
x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)
return (x * index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim])).sum(-1, acc_dtype=self.dtype)
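The gather rewrite keeps the same underlying trick, now routed through `_one_hot_along_dim`: compare indices against an arange to build a one-hot mask, multiply, and sum out the extra axis. The mechanics in numpy, gathering one element per row along the last dim (values are illustrative):

```python
import numpy as np

x = np.array([[10., 11., 12.], [20., 21., 22.]])
index = np.array([[2], [0]])                          # one index per row

one_hot = index[..., None] == np.arange(x.shape[-1])  # (2, 1, 3) boolean mask
out = (x[:, None, :] * one_hot).sum(-1)               # (2, 1) -> [[12.], [20.]]
print(out)
```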
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""
@@ -1289,7 +1289,7 @@ class Tensor(SimpleMathTrait):
```
"""
repeats = argfix(repeats, *args)
base_shape = _pad_left(self.shape, repeats)[0]
base_shape = _align_left(self.shape, repeats)[0]
unsqueezed_shape = flatten([[1, s] for s in base_shape])
expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)])
final_shape = [r*s for r,s in zip(repeats, base_shape)]
@@ -1885,7 +1885,7 @@ class Tensor(SimpleMathTrait):
print(t.argmin(axis=1).numpy()) # Returns the indices of the minimum values along axis 1.
```
"""
return (-self).argmax(axis=axis, keepdim=keepdim)
return (-self if self.is_floating_point() else ~self).argmax(axis=axis, keepdim=keepdim)
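The argmin fix matters for unsigned integer dtypes, where negation wraps instead of reversing order. Bitwise not is a safe order reversal for integers since `~x == -x - 1` is strictly decreasing, so `argmin(x) == argmax(~x)`:

```python
xs = [5, 1, 7, 3]
assert min(range(4), key=lambda i: xs[i]) == max(range(4), key=lambda i: ~xs[i])  # both pick index 1
# with uint8, -1 would wrap to 255 and break the ordering that ~x preserves
```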
def rearrange(self, formula: str, **sizes) -> Tensor:
"""
@@ -2002,11 +2002,27 @@ class Tensor(SimpleMathTrait):
def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]:
return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1])

def _ceil_mode_padding2d(self,k_:Tuple[sint, ...], s_:Union[Tuple[int, ...], int], d_:Union[Tuple[int, ...], int],
p_:Union[Tuple[int, ...], int]) -> Sequence[int]:
(d_,s_,p_), i_ = (make_tuple(x, len(k_)) for x in (d_,s_,p_)), self.shape[-len(k_):]
# https://arxiv.org/pdf/1603.07285 section 5.1, relationship 15.
o_ = [ceildiv(i+2*p - (d*(k-1)+1), s) + 1 for i,d,k,s,p in zip(i_,d_,k_,s_,p_)]
pads = list(self._padding2d(p_, len(k_)))
# we have to do additional padding before `_pool` so that `o_` in `_pool` is calculated correctly
# `s*(o-1) + (d*(k-1)+1) - (i+2*p)` -> last_sliding_window_start + full_kernel_size - padded_input_shape
# we decrease padding in the case that a sliding window starts in the end padded region, thereby decreasing `o_` in `_pool`
# `smax(s*(o-1) - (p+i-1), 0)` -> last_sliding_window_start - (left_pad + input_size - zero_offset)
for dim,(o,i,s,p,k,d) in enumerate(zip(o_,i_,s_,p_,k_,d_)): pads[-1-dim*2] += s*(o-1) + (d*(k-1)+1) - (i+2*p) - smax(s*(o-1) - (p+i-1), 0)
return pads
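The cited relationship gives the ceil-mode output length `o = ceildiv(i + 2p - (d(k-1)+1), s) + 1`. Worked through for one axis with i=5, k=2, s=2, d=1, p=0, including the extra end padding `_pool` then needs:

```python
from math import ceil

i, k, s, d, p = 5, 2, 2, 1, 0
o_ceil = ceil((i + 2*p - (d*(k-1)+1)) / s) + 1   # ceil(3/2) + 1 = 3 windows
o_floor = (i + 2*p - (d*(k-1)+1)) // s + 1       # 3//2 + 1 = 2 windows
extra = s*(o_ceil-1) + (d*(k-1)+1) - (i + 2*p)   # 1 extra column of end padding
print(o_ceil, o_floor, extra)                    # 3 2 1
```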
# NOTE: these work for more than 2D
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, count_include_pad=True):
def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False, count_include_pad=True):
"""
Applies average pooling over a tensor.

When `ceil_mode` is set to True, output shape will be determined using ceil division.
When `count_include_pad` is set to False, zero padding will not be included in the averaging calculation.

NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.

See: https://paperswithcode.com/method/average-pooling
@@ -2016,17 +2032,30 @@ class Tensor(SimpleMathTrait):
print(t.avg_pool2d().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(ceil_mode=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(padding=1).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.avg_pool2d(padding=1, count_include_pad=False).numpy())
```
"""
padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0))
def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis)
axis = tuple(range(-len(k_ := make_tuple(kernel_size, 2)), 0))
reg_pads, ceil_pads = self._padding2d(padding,len(k_)), self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding)
def pool(x:Tensor, padding_:Sequence[int]) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation)
if not count_include_pad:
pads = ceil_pads if ceil_mode else reg_pads
return pool(self, pads).sum(axis) / pool(self.ones_like(), pads).sum(axis)
if not ceil_mode: return pool(self, reg_pads).mean(axis)
return pool(self, ceil_pads).sum(axis) / pool(self.pad(reg_pads).ones_like(), tuple(cp-rp for cp,rp in zip(ceil_pads, reg_pads))).sum(axis)
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0):
def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0, ceil_mode=False):
"""
Applies max pooling over a tensor.

When `ceil_mode` is set to True, output shape will be determined using ceil division.

NOTE: unlike PyTorch, this implementation is not limited to only 2d pooling and instead works for any number of dimensions.

See: https://paperswithcode.com/method/max-pooling
@@ -2036,11 +2065,15 @@ class Tensor(SimpleMathTrait):
print(t.max_pool2d().numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.max_pool2d(ceil_mode=True).numpy())
```
```python exec="true" source="above" session="tensor" result="python"
print(t.max_pool2d(padding=1).numpy())
```
"""
padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2)))
return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))
k_ = make_tuple(kernel_size, 2)
pads = self._ceil_mode_padding2d(k_, stride if stride is not None else k_, dilation, padding) if ceil_mode else self._padding2d(padding, len(k_))
return self.pad(pads, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0)))

def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0,
acc_dtype:Optional[DTypeLike]=None) -> Tensor:
@@ -2336,10 +2369,14 @@ class Tensor(SimpleMathTrait):
index, dim = index.to(self.device), self._resolve_dim(dim)
src = src.cast(self.dtype) if isinstance(src, Tensor) else Tensor(src, device=self.device, dtype=self.dtype)._broadcast_to(index.shape)
assert index.ndim == self.ndim == src.ndim, f"self.ndim, index.ndim and src.dim must all equal, {self.ndim=} {index.ndim=} {src.ndim=}"
assert all((d == dim or se >= ind) and sr >= ind for d,(se,ind,sr) in enumerate(zip(self.shape, index.shape, src.shape))), \
assert all((d == dim or self_ >= index_) and src_ >= index_ for d,(self_,index_,src_) in enumerate(zip(self.shape, index.shape, src.shape))), \
f"All dimensions of {index.shape=} should be <= to all dimensions of {src.shape=} and all dimensions except dimension {dim} of {self.shape=}"
mask = (index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)).transpose(-1, dim)
src = src.unsqueeze(-1).expand((None,)*src.ndim + (self.shape[dim],)).transpose(-1, dim).shrink(tuple((0,s) for s in mask.shape))
# shrink src to index shape to shrink away the unused values
src = src.shrink(tuple((0,s) for s in index.shape))
# prepare src and mask for reduce with respect to dim
src = src.unsqueeze(-1).expand(*src.shape, self.shape[dim]).transpose(-1, dim)
mask = index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).transpose(-1, dim)
# pad src and mask to self.shape so that reduce can be done with padded values as no-ops
src, mask = (x.pad(tuple((0, self.shape[i] - x.shape[i]) if i != dim else None for i in range(self.ndim)) + (None,)) for x in (src, mask))
if reduce == "add": return mask.where(src, 0).sum(-1, acc_dtype=self.dtype) + self
if reduce == "multiply": return mask.where(src, 1).prod(-1, acc_dtype=self.dtype) * self
@@ -2928,15 +2965,15 @@ class Tensor(SimpleMathTrait):
return self / (1 + self.abs())
# ***** broadcasted elementwise ops *****
def _broadcast_to(self, shape:Tuple[sint, ...]) -> Tensor:
if self.shape == shape: return self
if self.ndim > len(shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {shape=}")
# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
padded, _ = _pad_left(self.shape, shape)
# for each dimension, check either from_ is 1, or it does not change
if any(resolve(from_ != 1, False) and resolve(from_ != to, False) for from_,to in zip(padded, shape)):
raise ValueError(f"cannot broadcast from shape={self.shape} to {shape=}")
return F.Expand.apply(self.reshape(padded), shape=shape)
def _broadcast_to(self, new_shape:Tuple[sint, ...]) -> Tensor:
if self.shape == new_shape: return self
if self.ndim > len(new_shape): raise ValueError(f"cannot broadcast tensor to fewer dimensions. shape={self.shape} to {new_shape=}")
# first unsqueeze left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html
shape, _ = _align_left(self.shape, new_shape)
# for each dimension, check either dim is 1, or it does not change
if not all(resolve(s == ns) or resolve(s == 1) for s,ns in zip(shape, new_shape)):
raise ValueError(f"cannot broadcast {self.shape} to {new_shape=}")
return F.Expand.apply(self.reshape(shape), shape=new_shape)

def _broadcasted(self, y:Union[Tensor, UOp, ConstType], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]:
x: Tensor = self
@@ -2955,11 +2992,10 @@ class Tensor(SimpleMathTrait):
if reverse: x, y = y, x

# broadcast
out_shape = _broadcast_shape(x.shape, y.shape)
return x._broadcast_to(out_shape), y._broadcast_to(out_shape)
return x._broadcast_to(out_shape:=_broadcast_shape(x.shape, y.shape)), y._broadcast_to(out_shape)

def _to_const_val(self, x:Union[Tensor, ConstType]) -> Union[Tensor, ConstType]:
return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
return x.lazydata.const_arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_unmasked_const() \
and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x
def add(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
@@ -3116,7 +3152,7 @@ class Tensor(SimpleMathTrait):
```
"""
if self.dtype != dtypes.bool and not dtypes.is_int(self.dtype): raise RuntimeError(f"{self.dtype} is not supported")
return self.logical_not() if self.dtype == dtypes.bool else self ^ ((1<<8*self.dtype.itemsize)-1)
return self.logical_not() if self.dtype == dtypes.bool else self ^ -1

def lshift(self, x:int):
"""
@@ -3160,8 +3196,9 @@ class Tensor(SimpleMathTrait):
x = self._to_const_val(x)
if not isinstance(x, Tensor) and not reverse:
# simple pow identities
if x < 0: return self.reciprocal().pow(-x)
if x < 0: return self.reciprocal().pow(-x).cast(self.dtype)
if x == 0: return 1 + self * 0
# rewrite pow 0.5 to sqrt
if int(x - 0.5) + 0.5 == x: return self.pow(int(x - 0.5)) * self.sqrt()
if int(x) == x: return self.pow(x // 2).square() * (1 if x % 2 == 0 else self)
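The new integer-power identity is one step of square-and-multiply: `x**n == (x**(n//2))**2` for even n, times one extra `x` for odd n, so recursing costs O(log n) multiplies instead of n-1. The identity checked with plain floats:

```python
def pow_int(x: float, n: int) -> float:
  # same recursion as the identity above: halve the exponent, square, fix up odd n
  if n == 0: return 1.0
  half = pow_int(x, n // 2)
  return half * half * (x if n % 2 else 1.0)

assert pow_int(3.0, 5) == 243.0 and pow_int(2.0, 10) == 1024.0
```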
@@ -3178,7 +3215,8 @@ class Tensor(SimpleMathTrait):
# inject nan for negative base and non-integer exponent
inject_nan = (negative_base * (exponent != exponent.trunc())).detach().where(math.nan, 1)
# apply correct_sign inject_nan, and fix 0 ** 0 = 1
return ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * correct_sign * inject_nan)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret

def maximum(self, x:Union[Tensor, ConstType]) -> Tensor:
"""
@@ -3354,6 +3392,11 @@ class Tensor(SimpleMathTrait):
if not Tensor.training or p == 0: return self
return (Tensor.rand_like(self, requires_grad=False, dtype=dtypes.default_float, contiguous=False) >= p).contiguous().where(self, 0) / (1.0 - p)
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)

def one_hot(self, num_classes:int=-1) -> Tensor:
"""
Converts `self` to a one-hot tensor.
@@ -3366,7 +3409,7 @@ class Tensor(SimpleMathTrait):
```
"""
if num_classes == -1: num_classes = (self.max()+1).item()
return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0)
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
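`_one_hot_along_dim` is pure broadcasting: the arange is reshaped so `num_classes` lands on the requested axis with size-1 axes to its right, and `==` broadcasts it against `self`. The same mechanics in numpy for the default `dim=-1` (the shapes are the point, not the API):

```python
import numpy as np

labels = np.array([2, 0, 1])   # stand-in for a Tensor of class ids
num_classes = 4
one_hot = labels[..., None] == np.arange(num_classes)  # (3, 1) vs (4,) broadcasts to (3, 4)
print(one_hot.astype(int))     # [[0 0 1 0]
                               #  [1 0 0 0]
                               #  [0 1 0 0]]
```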
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None,
dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
@@ -3442,8 +3485,8 @@ class Tensor(SimpleMathTrait):
assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
y_counted = Y.to(self.device).flatten().reshape(-1, 1)._one_hot_along_dim(self.shape[-1])
y = (y_counted * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
unreduced = ((1 - label_smoothing) * (log_probs * y).sum(-1) + smoothing)
# NOTE: because of ignore_index, we can't use Tensor.mean (so can't use `_do_reduction` here)
@@ -3590,8 +3633,15 @@ class Tensor(SimpleMathTrait):
t = t.cast(dtypes.int32)
print(t.dtype, t.numpy())
```
```python exec="true" source="above" session="tensor" result="python"
t = t.cast(dtypes.uint8)
print(t.dtype, t.numpy())
```
"""
return self if self.dtype == (dt:=to_dtype(dtype)) else F.Cast.apply(self, dtype=dt)
if (dt:=to_dtype(dtype)) in {dtypes.uint8, dtypes.uint16} and dtypes.is_float(self.dtype):
# NOTE: values within the int32 range and outside the unsigned dtype range will cause values to wrap around
return F.Cast.apply(F.Cast.apply(self, dtype=dtypes.int32), dtype=dt)
return self if self.dtype == dt else F.Cast.apply(self, dtype=dt)
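Routing float→uint8/uint16 casts through int32 pins down the wrap-around semantics rather than leaving a direct float→unsigned conversion to backend whims. A plain-Python model of the intended behavior (assuming C-style truncation toward zero for the float→int32 step):

```python
def float_to_uint8(x: float) -> int:
  as_i32 = int(x)        # model of the float -> int32 cast: truncation toward zero
  return as_i32 & 0xFF   # model of int32 -> uint8: keep the low byte, i.e. wrap around

assert float_to_uint8(-1.0) == 255   # -1 wraps to 255
assert float_to_uint8(260.7) == 4    # 260 wraps past 255 to 4
```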
def bitcast(self, dtype:DTypeLike) -> Tensor:
"""