mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 06:34:03 -05:00)
Commit: Merge origin/master
.github/workflows/benchmark.yml (vendored, 306 changed lines)
@@ -49,19 +49,19 @@ jobs:
 - name: Print macOS version
 run: sw_vers
 - name: Run Stable Diffusion
-run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
+run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
 - name: Run Stable Diffusion without fp16
-run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
+run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing
 - name: Run Stable Diffusion v2
 # TODO: very slow step time
-run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=4500 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
+run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=4500 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing
 # process replay can't capture this, the graph is too large
 - name: Run SDXL
-run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=5000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=5000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing
 - name: Run model inference benchmark
 run: METAL=1 NOCLANG=1 python3.11 test/external/external_model_benchmark.py
 - name: Test speed vs torch
-run: BIG=2 MPS=1 python3.11 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
+run: BIG=2 MPS=1 python3.11 test/speed/external_test_speed_v_torch.py
 - name: Test tensor cores
 run: METAL=1 python3.11 test/opt/test_tensor_cores.py
 - name: Test AMX tensor cores
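The pattern across this file: each step used to pipe its output through `| tee <name>.txt` so a later upload-artifact step could collect the text files; with per-step `BENCHMARK_LOG` names, those pipes (and, further down, the upload steps themselves) go away. The logger behind `BENCHMARK_LOG` is not part of this diff, so the following Python sketch is only a guess at the shape of the pattern, with `log_benchmark` as a hypothetical helper:

import json, os, time

def log_benchmark(name: str, step_ms: float) -> None:
  # hypothetical helper: BENCHMARK_LOG selects the log; when unset, do nothing
  if (log := os.getenv("BENCHMARK_LOG")) is None: return
  with open(f"{log}.jsonl", "a") as f:
    f.write(json.dumps({"bench": name, "step_ms": step_ms, "ts": time.time()}) + "\n")

# usage inside a timing loop: measure one step, then record it under the env-selected log
start = time.perf_counter()
# ... run one benchmark step ...
log_benchmark("stable_diffusion", (time.perf_counter() - start) * 1e3)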
@@ -71,84 +71,59 @@ jobs:
 DEBUG=2 CPU=1 CPU_LLVM=0 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
 DEBUG=2 CPU=1 CPU_LLVM=1 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
 - name: Run Tensor Core GEMM (float)
-run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
+run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py
 - name: Run Tensor Core GEMM (half)
-run: DEBUG=2 SHOULD_USE_TC=1 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
+run: DEBUG=2 SHOULD_USE_TC=1 HALF=1 python3.11 extra/gemm/simple_matmul.py
 - name: Run Tensor Core GEMM (bfloat16)
-run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
+run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
 - name: Fuzz Padded Tensor Core GEMM
 run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
 - name: Run LLaMA
 run: |
-BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
-BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run LLaMA with BEAM
-run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run quantized LLaMA
 run: |
-BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
-BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
+BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
+BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
 - name: Run quantized LLaMA3
 run: |
-BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8 | tee llama3_int8.txt
-BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4 | tee llama3_nf4.txt
+BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
+BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
 #- name: Run LLaMA 7B on 4 (virtual) GPUs
-# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
+# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run GPT2
 run: |
-BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF
-run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF/BEAM
-run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
 - name: Run OLMoE
 run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
 - name: Train MNIST
-run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py

 # NOTE: this is failing in CI. it is not failing on my machine and I don't really have a way to debug it
 # the error is "RuntimeError: Internal Error (0000000e:Internal Error)"
 #- name: Run 10 CIFAR training steps
-# run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=3000 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
+# run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=3000 STEPS=10 python3.11 examples/hlb_cifar10.py
 #- name: Run 10 CIFAR training steps w HALF
-# run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=3000 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
+# run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=3000 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py

 #- name: Run 10 CIFAR training steps w BF16
-# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py
 # TODO: too slow
 # - name: Run 10 CIFAR training steps w winograd
-# run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+# run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (Mac)
-path: |
-onnx_inference_speed.csv
-torch_speed.txt
-llama_unjitted.txt
-llama_jitted.txt
-llama_beam.txt
-llama_int8.txt
-llama_nf4.txt
-llama3_int8.txt
-llama3_nf4.txt
-llama_four_gpu.txt
-gpt2_unjitted.txt
-gpt2_jitted.txt
-gpt2_half.txt
-gpt2_half_beam.txt
-matmul.txt
-matmul_half.txt
-matmul_bfloat16.txt
-sd.txt
-sd_no_fp16.txt
-sdv2.txt
-sdxl.txt
-beautiful_mnist.txt
-train_cifar.txt
-train_cifar_half.txt
-train_cifar_bf16.txt
-train_cifar_wino.txt
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3.11 process_replay.py
@@ -215,7 +190,7 @@ jobs:
 - name: Run model inference benchmark
 run: NV=1 CAPTURE_PROCESS_REPLAY=0 NOCLANG=1 python3 test/external/external_model_benchmark.py
 - name: Test speed vs torch
-run: NV=1 CAPTURE_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
+run: NV=1 CAPTURE_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py
 - name: Test speed vs theoretical
 run: NV=1 IGNORE_BEAM_CACHE=1 CCACHE=0 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
 - name: Test benchmark allreduce
@@ -226,79 +201,58 @@ jobs:
 NV=1 NV_PTX=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
 - name: Run Tensor Core GEMM (CUDA)
 run: |
-CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
-CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
-CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
-CUDA=1 SHOULD_USE_TC=1 FP8E4M3=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_fp8.txt
+CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
+CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
+CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
+CUDA=1 SHOULD_USE_TC=1 FP8E4M3=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
 - name: Run Tensor Core GEMM (PTX)
-run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
+run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
 - name: Run Tensor Core GEMM (NV)
-run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
+run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
 - name: Test NV=1
 run: DEBUG=2 NV=1 python -m pytest -rA test/test_tiny.py
 - name: Test CUDA=1
 run: DEBUG=2 CUDA=1 python -m pytest -rA test/test_tiny.py
 - name: Run Stable Diffusion
-run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
+run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
 # TODO: too slow
 # - name: Run SDXL
-# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
 - name: Run LLaMA
 run: |
-BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
-BENCHMARK_LOG=llama NV=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=llama NV=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run LLaMA with BEAM
-run: BENCHMARK_LOG=llama_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+run: BENCHMARK_LOG=llama_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 # - name: Run LLaMA 7B on 4 GPUs
-# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
+# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
 # - name: Run LLaMA 7B on 6 GPUs
-# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
+# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run LLaMA-3 8B BEAM
-run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
+run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 - name: Run LLaMA-3 8B on 4 GPUs with BEAM
-run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 - name: Run quantized LLaMA3
-run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 | tee llama3_fp8.txt
+run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
 # - name: Run LLaMA-3 8B on 6 GPUs
-# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 # - name: Run LLaMA-2 70B
-# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run Mixtral 8x7B
-run: time BENCHMARK_LOG=mixtral NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
+run: time BENCHMARK_LOG=mixtral NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
 - name: Run GPT2
 run: |
-BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF
-run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF/BEAM
-run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (NVIDIA)
-path: |
-onnx_inference_speed.csv
-torch_speed.txt
-matmul.txt
-matmul_bfloat16.txt
-matmul_tf32.txt
-matmul_ptx.txt
-matmul_nv.txt
-sd.txt
-sdxl.txt
-llama_unjitted.txt
-llama_jitted.txt
-llama_beam.txt
-llama3_beam.txt
-llama3_four_gpu.txt
-llama3_six_gpu.txt
-llama3_fp8.txt
-llama_2_70B.txt
-mixtral.txt
-gpt2_unjitted.txt
-gpt2_jitted.txt
-gpt2_half.txt
-gpt2_half_beam.txt
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -339,42 +293,28 @@ jobs:
 - name: HEVC Decode Benchmark
 run: VALIDATE=1 MAX_FRAMES=100 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
 - name: Train MNIST
-run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
 - name: Run 10 CIFAR training steps
-run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 python3 examples/hlb_cifar10.py
 - name: Run 10 CIFAR training steps w HALF
-run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=110 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=110 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
 - name: Run 10 CIFAR training steps w BF16
-run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
 # - name: Run 10 CIFAR training steps w winograd
-# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=350 NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=350 NV=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
 - name: Run full CIFAR training w 1 GPU
-run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 - name: Run full CIFAR training steps w 6 GPUS
-run: time BENCHMARK_LOG=cifar_6gpu CAPTURE_PROCESS_REPLAY=0 NV=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
+run: time BENCHMARK_LOG=cifar_6gpu CAPTURE_PROCESS_REPLAY=0 NV=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 - name: Run MLPerf resnet eval on training data
 run: time BENCHMARK_LOG=resnet_eval NV=1 MODEL=resnet python3 examples/mlperf/model_eval.py
 - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
-run: BENCHMARK_LOG=resnet_10steps NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
+run: BENCHMARK_LOG=resnet_10steps NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
-run: BENCHMARK_LOG=resnet_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
+run: BENCHMARK_LOG=resnet_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf Bert training steps (6 gpu)
 # TODO: remove BERT_LAYERS once scheduler is fast
-run: BENCHMARK_LOG=bert_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (NVIDIA Training)
-path: |
-beautiful_mnist.txt
-train_cifar.txt
-train_cifar_half.txt
-train_cifar_bf16.txt
-train_cifar_wino.txt
-train_cifar_one_gpu.txt
-train_cifar_six_gpu.txt
-train_resnet.txt
-train_resnet_one_gpu.txt
-train_bert.txt
+run: BENCHMARK_LOG=bert_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -426,7 +366,7 @@ jobs:
 #- name: Test speed vs torch
 # run: |
 # python3 -c "import torch; print(torch.__version__)"
-# LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
+# LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py
 - name: Test speed vs theoretical
 run: AMD=1 IGNORE_BEAM_CACHE=1 CCACHE=0 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
 - name: Test tensor cores AMD_LLVM=0
@@ -437,7 +377,7 @@ jobs:
 - name: Run Tensor Core GEMM (AMD)
 run: |
 AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
-AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
+AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
 - name: Test AMD=1
 run: DEBUG=2 AMD=1 python -m pytest -rA test/test_tiny.py
 #- name: Test HIP=1
@@ -452,61 +392,39 @@ jobs:
 - name: Test AM warm start time
 run: time AMD=1 python3 test/test_tiny.py TestTiny.test_plus
 - name: Run Stable Diffusion
-run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
+run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
 - name: Run SDXL
-run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing
 - name: Run LLaMA 7B
 run: |
-BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
-BENCHMARK_LOG=llama AMD=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=llama AMD=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run LLaMA 7B with BEAM
-run: BENCHMARK_LOG=llama_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+run: BENCHMARK_LOG=llama_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
 # - name: Run LLaMA 7B on 4 GPUs
-# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
+# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
 # - name: Run LLaMA 7B on 6 GPUs
-# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
+# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run LLaMA-3 8B BEAM
-run: BENCHMARK_LOG=llama3_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
+run: BENCHMARK_LOG=llama3_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 - name: Run LLaMA-3 8B on 4 GPUs with BEAM
-run: BENCHMARK_LOG=llama3_beam_4gpu AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+run: BENCHMARK_LOG=llama3_beam_4gpu AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 # - name: Run LLaMA-3 8B on 6 GPUs
-# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
 #- name: Restore amdgpu
 # run: sudo modprobe amdgpu
 # - name: Run LLaMA-2 70B
-# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run Mixtral 8x7B
-run: time BENCHMARK_LOG=mixtral AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
+run: time BENCHMARK_LOG=mixtral AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing
 - name: Run GPT2
 run: |
-BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
-BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
+BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
+BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF
-run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
+run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
 - name: Run GPT2 w HALF/BEAM
-run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (AMD)
-path: |
-onnx_inference_speed.csv
-torch_speed.txt
-llama_unjitted.txt
-llama_jitted.txt
-llama_beam.txt
-llama3_beam.txt
-llama3_four_gpu.txt
-llama3_six_gpu.txt
-llama_2_70B.txt
-gpt2_unjitted.txt
-gpt2_jitted.txt
-gpt2_half.txt
-gpt2_half_beam.txt
-matmul.txt
-matmul_amd.txt
-sd.txt
-sdxl.txt
-mixtral.txt
+run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -543,31 +461,20 @@ jobs:
 - name: reset process replay
 run: test/external/process_replay/reset.py
 - name: Train MNIST
-run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
+run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
 - name: Run 10 CIFAR training steps
-run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
+run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py
 - name: Run 10 CIFAR training steps w HALF
-run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
+run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
 # - name: Run 10 CIFAR training steps w BF16
-# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
+# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
 # TODO: too slow
 # - name: Run 10 CIFAR training steps w winograd
-# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
 - name: Run full CIFAR training w 1 GPU
-run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
+run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 - name: Run full CIFAR training steps w 6 GPUS
-run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (AMD Training)
-path: |
-beautiful_mnist.txt
-train_cifar.txt
-train_cifar_half.txt
-train_cifar_bf16.txt
-train_cifar_wino.txt
-train_cifar_one_gpu.txt
-train_cifar_six_gpu.txt
+run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -606,19 +513,12 @@ jobs:
 - name: Run MLPerf resnet eval
 run: time BENCHMARK_LOG=resnet_eval AMD=1 MODEL=resnet python3 examples/mlperf/model_eval.py
 - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
-run: BENCHMARK_LOG=resnet_10steps AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
+run: BENCHMARK_LOG=resnet_10steps AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf ResNet50 training steps (6 gpu)
-run: BENCHMARK_LOG=resnet_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
+run: BENCHMARK_LOG=resnet_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf Bert training steps (6 gpu)
 # TODO: remove BERT_LAYERS once scheduler is fast
-run: BENCHMARK_LOG=bert_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (AMD MLPerf)
-path: |
-train_resnet.txt
-train_resnet_one_gpu.txt
-train_bert.txt
+run: BENCHMARK_LOG=bert_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -708,7 +608,7 @@ jobs:
 # AMD=1 AMD_LLVM=1 python3 test/test_linearizer.py test/opt/test_tensor_cores.py
 # AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
 - name: Run Tensor Core GEMM (AMD)
-run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee am_matmul_amd.txt
+run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
 - name: Test AMD=1
 run: DEBUG=2 AMD=1 python -m pytest -rA test/test_tiny.py
 - name: Test DISK copy time
@@ -718,20 +618,12 @@ jobs:
 AMD=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyDefaulttoCPUJit
 AMD=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyCPUtoDefaultJit
 - name: Run full CIFAR training w 1 GPU
-run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee am_train_cifar_one_gpu.txt
+run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 # - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
-# run: BENCHMARK_LOG=resnet_10steps AMD=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee am_train_resnet_one_gpu.txt
+# run: BENCHMARK_LOG=resnet_10steps AMD=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf Bert training steps (1 gpu)
 # TODO: remove BERT_LAYERS once scheduler is fast
-run: BENCHMARK_LOG=bert_10steps AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee am_train_bert_one_gpu.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (AM Driver)
-path: |
-am_matmul_amd.txt
-am_train_cifar_one_gpu.txt
-am_train_resnet_one_gpu.txt
-am_train_bert_one_gpu.txt
+run: BENCHMARK_LOG=bert_10steps AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -778,21 +670,13 @@ jobs:
 NV=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyDefaulttoCPUJit
 NV=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyCPUtoDefaultJit
 - name: Test LLAMA-3
-run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --benchmark --temperature 0 | tee nv_llama3_beam.txt
+run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --benchmark --temperature 0
 - name: Run full CIFAR training w 1 GPU
-run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee nv_train_cifar_one_gpu.txt
+run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
 - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
-run: BENCHMARK_LOG=resnet_10steps NV=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee nv_train_resnet_one_gpu.txt
+run: BENCHMARK_LOG=resnet_10steps NV=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
 - name: Run 10 MLPerf Bert training steps (1 gpu)
 # TODO: remove BERT_LAYERS once scheduler is fast
-run: BENCHMARK_LOG=bert_10steps NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee nv_train_bert_one_gpu.txt
-- uses: actions/upload-artifact@v4
-with:
-name: Speed (NV Driver)
-path: |
-nv_llama3_beam.txt
-nv_train_cifar_one_gpu.txt
-nv_train_resnet_one_gpu.txt
-nv_train_bert_one_gpu.txt
+run: BENCHMARK_LOG=bert_10steps NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
 - name: Run process replay tests
 run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
.github/workflows/test.yml (vendored, 19 changed lines)
@@ -5,6 +5,7 @@ env:
 CAPTURE_PROCESS_REPLAY: 1
 GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 PYTHONPATH: ${{ github.workspace }}
+IGNORE_OOB: 0

 on:
 push:
@@ -36,6 +37,8 @@ jobs:
 name: Docs
 runs-on: ubuntu-22.04
 timeout-minutes: 10
+env:
+IGNORE_OOB: 1
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
@@ -215,7 +218,6 @@ jobs:
 runs-on: ubuntu-latest
 timeout-minutes: 10

-# TODO: run the pre-commit hook to replace a lot of this
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
@@ -230,13 +232,13 @@ jobs:
 - name: Lint with ruff
 run: |
 pip3 install --upgrade --force-reinstall ruff==0.14.10
-python3 -m ruff check .
+pre-commit run ruff --all-files
 python3 -m ruff check examples/mlperf/ --ignore E501
 python3 -m ruff check extra/thunder/tiny/ --ignore E501 --ignore F841 --ignore E722
 python3 -m ruff check extra/torch_backend/backend.py
 - name: Run mypy
 run: |
-python -m mypy --strict-equality --lineprecision-report .
+python -m mypy --lineprecision-report .
 cat lineprecision.txt
 - name: Run TYPED=1
 run: TYPED=1 python -c "import tinygrad"
@@ -307,7 +309,7 @@ jobs:
 deps: testing_unit
 python-version: '3.14'
 - name: Test SPEC=2
-run: IGNORE_OOB=0 SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
+run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}

 fuzzing:
 name: Fuzzing
@@ -470,6 +472,8 @@ jobs:
 name: Test LLM
 runs-on: ubuntu-24.04
 timeout-minutes: 15
+env:
+IGNORE_OOB: 1
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
@@ -584,8 +588,6 @@ jobs:
 name: Linux (WebGPU)
 runs-on: ubuntu-22.04
 timeout-minutes: 20
-env:
-IGNORE_OOB: 0
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
@@ -822,7 +824,6 @@ jobs:
 NV_PTX: 1
 NV: 1
 FORWARD_ONLY: 1
-IGNORE_OOB: 0
 run: |
 python3 -m pytest -n=auto test/device/test_hcq.py test/test_tiny.py --durations=20
 - name: Run process replay tests
@@ -832,8 +833,6 @@ jobs:
 name: MacOS (WebGPU)
 runs-on: macos-14
 timeout-minutes: 10
-env:
-IGNORE_OOB: 0
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
@@ -911,8 +910,6 @@ jobs:
 name: Windows (${{ matrix.backend }})
 runs-on: windows-latest
 timeout-minutes: 15
-env:
-IGNORE_OOB: 0
 steps:
 - name: Checkout Code
 uses: actions/checkout@v4
.pre-commit-config.yaml
@@ -16,7 +16,7 @@ repos:
 pass_filenames: false
 - id: mypy
 name: mypy
-entry: python3 -m mypy tinygrad/ --strict-equality
+entry: python3 -m mypy
 language: system
 always_run: true
 pass_filenames: false
(3 file diffs suppressed because they are too large)
extra/assembly/amd/dsl.py
@@ -1,48 +1,56 @@
 # library for RDNA3 assembly DSL
 # mypy: ignore-errors
 from __future__ import annotations
-import struct, math
+import struct, math, re
 from enum import IntEnum
 from functools import cache, cached_property
 from typing import overload, Annotated, TypeVar, Generic
 from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
 SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)

 # Common masks and bit conversion functions
 MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
-def _f32(i): return struct.unpack("<f", struct.pack("<I", i & MASK32))[0]
+_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
+_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
+_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")
+def _f32(i): return _struct_f.unpack(_struct_I.pack(i & MASK32))[0]
 def _i32(f):
 if isinstance(f, int): f = float(f)
 if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
 if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
-try: return struct.unpack("<I", struct.pack("<f", f))[0]
+try: return _struct_I.unpack(_struct_f.pack(f))[0]
 except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
 def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
-def _f16(i): return struct.unpack("<e", struct.pack("<H", i & 0xffff))[0]
+def _f16(i): return _struct_e.unpack(_struct_H.pack(i & 0xffff))[0]
 def _i16(f):
 if math.isnan(f): return 0x7e00
 if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
-try: return struct.unpack("<H", struct.pack("<e", f))[0]
+try: return _struct_H.unpack(_struct_e.pack(f))[0]
 except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
-def _f64(i): return struct.unpack("<d", struct.pack("<Q", i & MASK64))[0]
+def _f64(i): return _struct_d.unpack(_struct_Q.pack(i & MASK64))[0]
 def _i64(f):
 if math.isnan(f): return 0x7ff8000000000000
 if math.isinf(f): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
-try: return struct.unpack("<Q", struct.pack("<d", f))[0]
+try: return _struct_Q.unpack(_struct_d.pack(f))[0]
 except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000

 # Instruction spec - register counts and dtypes derived from instruction names
-import re
 _REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
 'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
 'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
+_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
+_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
+_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
+_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
+_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
 @cache
 def _suffix(name: str) -> tuple[str | None, str | None]:
 name = name.upper()
-if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
-if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2)
-if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
 # Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc.
-if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
-if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1)
+if m := _CVT_RE.search(name): return m.group(1), m.group(2)
+if m := _MAD_MUL_RE.search(name): return m.group(1), m.group(2)
+if m := _PACK_RE.search(name): return m.group(1), m.group(2)
+if m := _DST_SRC_RE.search(name): return m.group(1), m.group(2)
+if m := _SINGLE_RE.search(name): return m.group(1), m.group(1)
 return None, None
 _SPECIAL_REGS = {
 'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
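Two flavors of precompilation in the hunk above: `struct.Struct` objects hoist format-string handling out of the hot `_f32`/`_i16`/... helpers, and module-level `re.compile` does the same for the suffix regexes. A small self-contained sketch of the idea (timings vary by machine; the gap is modest since CPython also caches format strings internally, but the precompiled path skips the lookup entirely):

import re, struct, timeit

_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")  # compiled once at import
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')

def f32_slow(i): return struct.unpack("<f", struct.pack("<I", i & 0xffffffff))[0]  # looks up "<f"/"<I" per call
def f32_fast(i): return _struct_f.unpack(_struct_I.pack(i & 0xffffffff))[0]        # reuses compiled codecs

assert f32_slow(0x3f800000) == f32_fast(0x3f800000) == 1.0  # 0x3f800000 is 1.0f
assert _CVT_RE.search("V_CVT_F32_U32").groups() == ("F32", "U32")
print(timeit.timeit(lambda: f32_slow(0x40490fdb), number=100_000))
print(timeit.timeit(lambda: f32_fast(0x40490fdb), number=100_000))  # typically faster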
@@ -75,27 +83,33 @@ _SPECIAL_DTYPE = {
 # RDNA4 CVT_PK_F32 instructions: source is 8-bit packed as 16-bit operand
 'V_CVT_PK_F32_BF8': ('F32', 'B16', None, None), 'V_CVT_PK_F32_FP8': ('F32', 'B16', None, None),
 }
 @cache
 def spec_regs(name: str) -> tuple[int, int, int, int]:
-name = name.upper()
-if name in _SPECIAL_REGS: return _SPECIAL_REGS[name]
-if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1
+uname = name.upper()
+if uname in _SPECIAL_REGS: return _SPECIAL_REGS[uname]
+if 'SAD' in uname and 'U8' in uname and 'QSAD' not in uname and 'MQSAD' not in uname: return 1, 1, 1, 1
 dst_suf, src_suf = _suffix(name)
 return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1)
 @cache
 def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
-name = name.upper()
-if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name]
-if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32'
-if '_CMP_' in name or '_CMPX_' in name:
+uname = name.upper()
+if uname in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[uname]
+if 'SAD' in uname and ('U8' in uname or 'U16' in uname) and 'QSAD' not in uname and 'MQSAD' not in uname: return 'U32', 'U32', 'U32', 'U32'
+if '_CMP_' in uname or '_CMPX_' in uname:
 dst_suf, src_suf = _suffix(name)
-return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None
+return 'EXEC' if '_CMPX_' in uname else 'VCC', src_suf, src_suf, None
 dst_suf, src_suf = _suffix(name)
 return dst_suf, src_suf, src_suf, src_suf
+_F16_RE = re.compile(r'_[FIUB]16(?:_|$)')
+_F64_RE = re.compile(r'_[FIUB]64(?:_|$)')
 @cache
 def spec_is_16bit(name: str) -> bool:
-name = name.upper()
-if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False
-if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16
-return bool(re.search(r'_[FIUB]16(?:_|$)', name))
-def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper()))
+uname = name.upper()
+if 'SAD' in uname or 'PACK' in uname or '_PK_' in uname or 'SAT_PK' in uname or 'DOT2' in uname: return False
+if '_F32' in uname or '_I32' in uname or '_U32' in uname or '_B32' in uname: return False
+return bool(_F16_RE.search(uname))
+@cache
+def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper()))
 _3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
 'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
 'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
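The `uname` rename stops rebinding the `name` parameter, so the uppercased copy used for the table and substring checks stays distinct from the original that is passed through to `_suffix` (which uppercases and caches on its own). A usage sketch of the spec helpers defined above, assuming the two example opcodes do not appear in the `_SPECIAL_REGS`/`_SPECIAL_DTYPE` override tables (most of which are elided from this diff):

# hypothetical sanity checks against the regex/table logic shown above
assert spec_regs("V_CVT_F32_U32") == (1, 1, 1, 1)                   # CVT_ pattern: F32 dst, U32 srcs
assert spec_dtype("V_CVT_F32_U32") == ("F32", "U32", "U32", "U32")
assert spec_regs("V_ADD_F64") == (2, 2, 2, 2)                       # _F64 suffix: two 32-bit regs each
assert spec_is_64bit("V_ADD_F64") and not spec_is_16bit("V_ADD_F64")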
@@ -535,21 +549,25 @@ class Inst:
 assert cls_name in self._enum_map, f"no enum map for {cls_name}"
 return self._enum_map[cls_name](val)

-@property
+@cached_property
 def op_name(self) -> str:
 op = self.op
 return op.name if hasattr(op, 'name') else ''

-def dst_regs(self) -> int: return spec_regs(self.op_name)[0]
-def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1]
+@cached_property
+def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
+@cached_property
+def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
+def dst_regs(self) -> int: return self._spec_regs[0]
+def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
 def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
-def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0]
-def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1]
-def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n))
-def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2
+def dst_dtype(self) -> str | None: return self._spec_dtype[0]
+def src_dtype(self, n: int) -> str | None: return self._spec_dtype[n + 1]
+def is_src_16(self, n: int) -> bool: return self._spec_regs[n + 1] == 1 and is_dtype_16(self._spec_dtype[n + 1])
+def is_src_64(self, n: int) -> bool: return self._spec_regs[n + 1] == 2
 def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
 def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
-def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype())
+def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0])

 class Inst32(Inst): pass
 class Inst64(Inst): pass
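`functools.cached_property` computes the value on first access and stashes it in the instance `__dict__`, so the per-instance `_spec_regs`/`_spec_dtype` tuples above are derived once instead of re-running the `spec_*` lookups on every `dst_regs`/`src_dtype` call. A minimal standalone illustration of that memoization:

from functools import cached_property

class Demo:
  calls = 0
  @cached_property
  def expensive(self) -> int:
    Demo.calls += 1  # count how many times the body actually runs
    return 42

d = Demo()
assert d.expensive == 42 and d.expensive == 42
assert Demo.calls == 1          # second access served from d.__dict__, not the property body
assert Demo().expensive == 42   # a fresh instance recomputes (cache is per-instance)
assert Demo.calls == 2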
@@ -1,8 +1,9 @@
 # RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
 # mypy: ignore-errors
 from __future__ import annotations
-import ctypes, struct
-from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
+import ctypes
+from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
+from extra.assembly.amd.pcode import Reg
 from extra.assembly.amd.asm import detect_format
 from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
 from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
@@ -178,24 +179,21 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
|
||||
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
|
||||
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
|
||||
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
|
||||
exec_mask = st.exec_mask
|
||||
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
|
||||
|
||||
# Execute compiled function - pass PC in bytes for instructions that need it
|
||||
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
|
||||
pc_bytes = st.pc * 4
|
||||
vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
|
||||
result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
|
||||
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
|
||||
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
|
||||
|
||||
# Apply results
|
||||
if sdst is not None:
|
||||
(st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0'])
|
||||
if 'scc' in result: st.scc = result['scc']
|
||||
if 'exec' in result: st.exec_mask = result['exec']
|
||||
if 'new_pc' in result:
|
||||
# Apply results - extract values from returned Reg objects
|
||||
if sdst is not None and 'D0' in result:
|
||||
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
|
||||
if 'SCC' in result: st.scc = result['SCC']._val & 1
|
||||
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
|
||||
if 'PC' in result:
|
||||
# Convert absolute byte address to word delta
|
||||
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
|
||||
new_pc_words = result['new_pc'] // 4
|
||||
pc_val = result['PC']._val
|
||||
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
|
||||
new_pc_words = new_pc // 4
|
||||
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
|
||||
return 0
|
||||
|
||||
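For reference, the Reg calling convention used above can be exercised standalone. This sketch assumes only that Reg stores an integer in ._val (the real Reg in extra.assembly.amd.pcode adds typed views and bit slicing), and the op below is a hand-written stand-in, not a generated function:

class Reg:  # minimal stand-in for extra.assembly.amd.pcode.Reg
    def __init__(self, val: int): self._val = val

MASK32 = 0xffffffff

def fake_s_add_u32(S0: Reg, S1: Reg) -> dict:
    # compiled functions return only the registers they actually modified
    total = S0._val + S1._val
    return {'D0': Reg(total & MASK32), 'SCC': Reg(1 if total > MASK32 else 0)}

result = fake_s_add_u32(Reg(0xffffffff), Reg(1))
assert result['D0']._val == 0 and result['SCC']._val == 1  # wraps, carry out in SCC
# a branch target comes back as an unsigned 64-bit PC pattern; two's-complement reinterpret:
pc_val = 0xfffffffffffffffc  # bit pattern of -4
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
assert new_pc == -4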
@@ -260,24 +258,25 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'])
for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))]
for dst, val in results: V[dst] = val
def exec_vopd(vopd_op, s0, s1, d0):
op = _VOPD_TO_VOP[vopd_op]
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
return

# VOP3SD: has extra scalar dest for carry output
if isinstance(inst, VOP3SD):
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
fn = compiled[VOP3SDOp][inst.op]
# Read sources based on register counts from inst properties
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
# Carry-in ops use src2 as carry bitmask instead of VCC
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
V[inst.vdst] = result['d0'] & MASK32
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
d0_val = result['D0']._val
V[inst.vdst] = d0_val & MASK32
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
return

# Get op enum and sources (None means "no source" for that operand)
@@ -317,8 +316,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
V = st.vgpr[lane]
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
@@ -327,15 +325,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
fn = compiled.get(VOP3POp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")

op_cls = type(inst.op)
fn = compiled.get(op_cls, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")

# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
@@ -377,24 +373,27 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)

# Apply results
# Apply results - extract values from returned Reg objects
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'vcc_lane' in result:
if 'VCC' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
if 'exec_lane' in result:
# V_CMPX instructions write to EXEC per-lane
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
if 'EXEC' in result:
# V_CMPX instructions write to EXEC per-lane (not to vdst)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
elif op_cls is VOPCOp:
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
if op_cls is not VOPCOp and 'vgpr_write' not in result:
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
d0_val = result['d0']
d0_val = result['D0']._val
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32

@@ -43,7 +43,10 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
'fp6', 'bf6'] # Malformed pseudocode from PDF

# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
@@ -51,6 +54,7 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',

def compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
@@ -113,7 +117,7 @@ def compile_pseudocode(pseudocode: str) -> str:
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = lhs.strip(), rhs.strip()
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
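The word-boundary \b in the 'pass' rename above matters: it must rewrite the bare keyword without touching identifiers that merely contain it. A quick standalone check:

import re
src = "pass = s0 + 1\nbypass = pass"
out = re.sub(r'\bpass\b', 'pass_', src)
assert out == "pass_ = s0 + 1\nbypass = pass_"  # 'bypass' untouched, both bare uses renamed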
@@ -556,52 +560,57 @@ def _apply_pseudocode_fixes(op, code: str) -> str:

def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
"""Generate a single compiled pseudocode function."""
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = (cls_name in ('VOPCOp', 'VOP3Op')) and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
has_pc = 'PC' in pc
combined = code + pc

fn_name = f"_{cls_name}_{op.name}"
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):"]
for pc_line in pc.split('\n'): lines.append(f" # {pc_line}")
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]

regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'), ('PC', 'Reg(pc)')]
used = {name for name, _ in regs if name in combined}
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'): lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
d0_val, scc_val = ("D0._val" if 'D0' in used else "d0"), ("SCC._val & 1" if 'SCC' in used else "scc & 1")
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst: lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
elif 'VCC' in used: lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx: lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
elif 'EXEC' in used: lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp: lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
if is_64: lines.append(" result['d0_64'] = True")
if has_d1: lines.append(" result['d1'] = D1._val & 1")
if has_pc:
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc")
lines.append(" return result\n")
# Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
used = {name for name, _ in special_regs if name in combined}

# Detect which registers are modified (not just read) - look for assignments
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))

# Build init code for special registers
init_lines = []
if is_div_scale: init_lines.append(" D0 = Reg(S0._val)")
for name, init in special_regs:
if name in used: init_lines.append(f" {name} = {init}")
if 'EXEC_LO' in code: init_lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
code_lines = [line for line in code.split('\n') if line.strip()]
if init_lines:
lines.extend(init_lines)
if code_lines: lines.append(" # --- compiled pseudocode ---")
for line in code_lines:
lines.append(f" {line}")

# Build result dict - only include registers that are modified
result_items = []
if modifies_d0: result_items.append("'D0': D0")
if modifies_scc: result_items.append("'SCC': SCC")
if modifies_vcc: result_items.append("'VCC': VCC")
if modifies_exec: result_items.append("'EXEC': EXEC")
if has_d1: result_items.append("'D1': D1")
if modifies_pc: result_items.append("'PC': PC")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)

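needs_init above skips emitting an initializer when the compiled body already starts the register's life with its own `name = Reg(...)` assignment. The regex can be checked standalone:

import re
def needs_init(name: str, code: str) -> bool:
    return name in code and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)

assert needs_init('tmp', "tmp.u32 = S0._val")                 # read/modified before any init
assert not needs_init('tmp', "tmp = Reg(S1._val)\nD0 = tmp")  # body initializes it itself
assert not needs_init('tmp', "D0 = Reg(0)")                   # never used at all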
# ═══════════════════════════════════════════════════════════════════════════════
@@ -229,17 +229,18 @@ class TestPseudocodeRegressions(unittest.TestCase):
"""Regression tests for pseudocode instruction emulation bugs."""

def test_v_div_scale_f32_vcc_always_returned(self):
"""V_DIV_SCALE_F32 must always return vcc_lane, even when VCC=0 (no scaling needed).
Bug: when VCC._val == vcc (both 0), vcc_lane wasn't returned, so VCC bits weren't written.
"""V_DIV_SCALE_F32 must always return VCC, even when VCC=0 (no scaling needed).
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
This caused division to produce wrong results for multiple lanes."""
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
s0 = 0x3f800000 # 1.0
s1 = 0x40400000 # 3.0
s2 = 0x3f800000 # 1.0 (numerator)
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
# Must always have vcc_lane in result
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
S0 = Reg(0x3f800000) # 1.0
S1 = Reg(0x40400000) # 3.0
S2 = Reg(0x3f800000) # 1.0 (numerator)
D0, SCC, VCC, EXEC = Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
# Must always have VCC in result
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
self.assertEqual(result['VCC']._val & 1, 0, "VCC lane 0 should be 0 when no scaling needed")

def test_v_cmp_class_f32_detects_quiet_nan(self):
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
@@ -248,18 +249,22 @@ class TestPseudocodeRegressions(unittest.TestCase):
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
# Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
S0, S1, S2, D0, SCC, VCC, EXEC = Reg(quiet_nan), Reg(s1_quiet), Reg(0), Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
S0, S1 = Reg(signal_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
S0, S1 = Reg(quiet_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
S0, S1 = Reg(signal_nan), Reg(s1_quiet)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Signaling NaN should not match quiet NaN mask")

def test_isnan_with_typed_view(self):
"""_isnan must work with TypedView objects, not just Python floats.

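The constants in these tests are the IEEE-754 float32 NaN encodings: exponent all ones plus a nonzero mantissa is NaN, and the mantissa MSB (bit 22) distinguishes quiet from signaling. A standalone check with struct:

import math, struct
def f32_from_bits(b: int) -> float: return struct.unpack('<f', struct.pack('<I', b))[0]

quiet_nan, signal_nan = 0x7fc00000, 0x7f800001
assert math.isnan(f32_from_bits(quiet_nan)) and math.isnan(f32_from_bits(signal_nan))
assert (quiet_nan >> 22) & 1 == 1   # quiet NaN: mantissa MSB set
assert (signal_nan >> 22) & 1 == 0  # signaling NaN: mantissa MSB clear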
@@ -1,6 +1,7 @@
// ** global buffers
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
s_load_dwordx2 s[34:35], s[0:1], 0x08 // A
s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
// ** others kernel args
s_load_dword s24, s[0:1], 0x18 // N
s_load_dword s54, s[0:1], 0x1C // num work groups
@@ -52,7 +52,7 @@ def get_asm_prg() -> ProgramSpec:
lib = Device[Device.DEFAULT].compiler.compile(src)
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
prg=CompiledRunner(get_asm_prg())))

with Context(DEBUG=2):
@@ -1,12 +1,12 @@
# unpack the complete kernel descriptor of an amdgpu ELF of for gfx950
# unpack the complete kernel descriptor of an amdgpu ELF
# https://rocm.docs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPUUsage.html#code-object-v3-kernel-descriptor
import struct, pathlib
import struct, pathlib, sys
from tinygrad.runtime.support.elf import elf_loader

def bits(x, lo, hi): return (x >> lo) & ((1 << (hi - lo + 1)) - 1)
def assert_zero(x, lo, hi): assert bits(x, lo, hi) == 0

with open(fp:=pathlib.Path(__file__).parent/"lib", "rb") as f:
with open(sys.argv[1], "rb") as f:
lib = f.read()

image, sections, relocs = elf_loader(lib)
@@ -49,7 +49,7 @@ print("COMPUTE_PGM_RSRC3: 0x%08x" % pgm_rsrc3)
print("COMPUTE_PGM_RSRC1: 0x%08x" % pgm_rsrc1)
print("COMPUTE_PGM_RSRC2: 0x%08x" % pgm_rsrc2)

# rsrc 3
# rsrc 3 (gfx950)

accum_offset_raw = bits(pgm_rsrc3, 0, 5)
assert_zero(pgm_rsrc3, 6, 15)
@@ -169,10 +169,10 @@ assert_zero(desc, 458, 459)
uses_dynamic_stack = bits(desc, 459, 460)
print("DESC.USES_DYNAMIC_STACK:", uses_dynamic_stack)

# gfx950 only
assert_zero(desc, 460, 463)
kernarg_preload_spec_length = bits(desc, 464, 470)
print("DESC.KERNARG_PRELOAD_SPEC_LENGTH:", kernarg_preload_spec_length)

kernarg_preload_spec_offset = bits(desc, 471, 479)
print("DESC.KERNARG_PRELOAD_SPEC_OFFSET:", kernarg_preload_spec_offset)

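bits(x, lo, hi) in the script above extracts the inclusive bit range [lo, hi]; a quick standalone check of the helper:

def bits(x, lo, hi): return (x >> lo) & ((1 << (hi - lo + 1)) - 1)

assert bits(0b110100, 2, 4) == 0b101  # three bits starting at bit 2
assert bits(0xff, 0, 7) == 0xff       # whole byte
assert bits(0xff, 8, 15) == 0         # above the set bits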
@@ -1,9 +1,13 @@
import unittest
import numpy as np

from tinygrad.helpers import BEAM, Timing, CI, Context
from tinygrad import Variable, Tensor
from tinygrad.helpers import BEAM, Timing, CI, prod
from tinygrad import Variable, Device, Tensor
from tinygrad.nn import Conv2d
from tinygrad.uop.ops import AxisType
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.opt.postrange import Scheduler
from tinygrad.codegen.opt.search import get_kernel_actions

def rand(*shape):
return Tensor(np.random.rand(*shape).astype(np.float32))
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
a = (a + a) * a
a.realize()

@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tc_up(self):
tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
size = max(tc.dims[0], tc.dims[1]) * 8
a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
actions = get_kernel_actions(s, include_0=False, max_up=int(up))
upcasted = [s for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]
assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"

def test_max_up(self):
a = Tensor.rand(16, 16)
ast = a.schedule()[-1].ast
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
for max_up in (2, 4):
actions = get_kernel_actions(s, include_0=False, max_up=max_up)
for up_opts in [s.applied_opts for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]:
assert len([opt for opt in up_opts if opt.arg > max_up]) == 0 and len([op for op in up_opts if op.arg <= max_up]) > 0

if __name__ == '__main__':
unittest.main()

@@ -3,18 +3,21 @@
import numpy as np
import unittest
import subprocess, struct, math, textwrap
import subprocess, struct, math, textwrap, functools
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, KernelInfo
from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner

from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt
from test.testextra.test_cfg_viz import template

def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name="test"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))

def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
])
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([out.uop.buffer], {}, wait=True)
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
out.realize()
return out.tolist()

def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]

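f16_to_bits round-trips a Python float through struct's '<e' (little-endian binary16) format; for example 1.0 encodes as 0x3c00 (sign 0, exponent 15, mantissa 0):

import struct
def f16_to_bits(x: float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]

assert f16_to_bits(1.0) == 0x3c00
assert f16_to_bits(-2.0) == 0xc000
assert f16_to_bits(float('inf')) == 0x7c00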
@@ -135,9 +135,14 @@ check_untyped_defs = true
explicit_package_bases = true
warn_unreachable = true
warn_redundant_casts = true
strict_equality = true
# NOTE: had to comment this out to make mypy pass on both CI and OSX
#warn_unused_ignores = true

[[tool.mypy.overrides]]
module = "extra.*"
follow_imports = "skip"

[tool.pytest.ini_options]
norecursedirs = [
"extra",

@@ -1,74 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad.device import Device, BufferSpec
from tinygrad.dtype import dtypes

@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run")
class TestQcom(unittest.TestCase):
def test_image_pitch(self):
dev = Device["QCOM"]

def __validate(imgdt, expected_pitch):
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferSpec(image=imgdt))
pitch = img.texture_info.pitch
assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}"
dev.allocator.free(img, imgdt.shape[0] * imgdt.shape[1] * 16, options)

# Match opencl pitches for perf
__validate(dtypes.imageh((1, 201)), 0x680)
__validate(dtypes.imageh((16, 216)), 0x700)
__validate(dtypes.imageh((16, 9)), 0x80)
__validate(dtypes.imageh((48, 64)), 0x200)
__validate(dtypes.imageh((32, 128)), 0x400)
__validate(dtypes.imageh((96, 128)), 0x400)
__validate(dtypes.imageh((64, 256)), 0x840)
__validate(dtypes.imageh((64, 9)), 0x80)
__validate(dtypes.imageh((192, 256)), 0x840)
__validate(dtypes.imageh((64, 768)), 0x1840)
__validate(dtypes.imageh((256, 49)), 0x1C0)
__validate(dtypes.imageh((128, 9)), 0x80)
__validate(dtypes.imageh((16, 1024)), 0x2080)
__validate(dtypes.imageh((64, 512)), 0x1040)
__validate(dtypes.imageh((16, 512)), 0x1080)
__validate(dtypes.imageh((132, 64)), 0x200)
__validate(dtypes.imageh((4, 512)), 0x1200)
__validate(dtypes.imageh((8, 512)), 0x1100)
__validate(dtypes.imageh((128, 128)), 0x400)
__validate(dtypes.imageh((32, 512)), 0x1040)
__validate(dtypes.imageh((26, 64)), 0x200)
__validate(dtypes.imageh((32, 516)), 0x1040)
__validate(dtypes.imageh((32, 1024)), 0x2040)
__validate(dtypes.imageh((16, 2048)), 0x4080)
__validate(dtypes.imageh((8, 2048)), 0x4100)
__validate(dtypes.imageh((4, 4096)), 0x8200)

__validate(dtypes.imagef((16, 49)), 0x380)
__validate(dtypes.imagef((16, 1024)), 0x4080)
__validate(dtypes.imagef((256, 64)), 0x400)
__validate(dtypes.imagef((64, 512)), 0x2040)
__validate(dtypes.imagef((16, 512)), 0x2080)
__validate(dtypes.imagef((132, 64)), 0x400)
__validate(dtypes.imagef((4, 512)), 0x2200)
__validate(dtypes.imagef((4, 16)), 0x200)
__validate(dtypes.imagef((2, 16)), 0x400)
__validate(dtypes.imagef((8, 512)), 0x2100)
__validate(dtypes.imagef((12, 64)), 0x400)
__validate(dtypes.imagef((3, 32)), 0x400)
__validate(dtypes.imagef((128, 128)), 0x840)
__validate(dtypes.imagef((32, 512)), 0x2040)
__validate(dtypes.imagef((8, 3072)), 0xC100)
__validate(dtypes.imagef((4, 2048)), 0x8200)
__validate(dtypes.imagef((4, 1024)), 0x4200)
__validate(dtypes.imagef((4, 4096)), 0x10200)
__validate(dtypes.imagef((10, 384)), 0x1900)
__validate(dtypes.imagef((24, 64)), 0x400)
__validate(dtypes.imagef((128, 12)), 0xC0)
__validate(dtypes.imagef((10, 24)), 0x200)
__validate(dtypes.imagef((1, 129)), 0x840)
__validate(dtypes.imagef((1, 32)), 0x200)
__validate(dtypes.imagef((1, 64)), 0x400)
__validate(dtypes.imagef((1, 1239)), 0x4D80)
__validate(dtypes.imagef((1, 1)), 0x40)

if __name__ == "__main__":
unittest.main()
@@ -44,6 +44,66 @@ class TestImageCopy(unittest.TestCase):

@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
class TestImageDType(unittest.TestCase):
def test_image_pitch(self):
def __validate(imgdt, expected_pitch):
assert imgdt.pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{imgdt.pitch:X}, expected 0x{expected_pitch:X}"

# Match opencl pitches for perf
__validate(dtypes.imageh((1, 201)), 0x680)
__validate(dtypes.imageh((16, 216)), 0x700)
__validate(dtypes.imageh((16, 9)), 0x80)
__validate(dtypes.imageh((48, 64)), 0x200)
__validate(dtypes.imageh((32, 128)), 0x400)
__validate(dtypes.imageh((96, 128)), 0x400)
__validate(dtypes.imageh((64, 256)), 0x840)
__validate(dtypes.imageh((64, 9)), 0x80)
__validate(dtypes.imageh((192, 256)), 0x840)
__validate(dtypes.imageh((64, 768)), 0x1840)
__validate(dtypes.imageh((256, 49)), 0x1C0)
__validate(dtypes.imageh((128, 9)), 0x80)
__validate(dtypes.imageh((16, 1024)), 0x2080)
__validate(dtypes.imageh((64, 512)), 0x1040)
__validate(dtypes.imageh((16, 512)), 0x1080)
__validate(dtypes.imageh((132, 64)), 0x200)
__validate(dtypes.imageh((4, 512)), 0x1200)
__validate(dtypes.imageh((8, 512)), 0x1100)
__validate(dtypes.imageh((128, 128)), 0x400)
__validate(dtypes.imageh((32, 512)), 0x1040)
__validate(dtypes.imageh((26, 64)), 0x200)
__validate(dtypes.imageh((32, 516)), 0x1040)
__validate(dtypes.imageh((32, 1024)), 0x2040)
__validate(dtypes.imageh((16, 2048)), 0x4080)
__validate(dtypes.imageh((8, 2048)), 0x4100)
__validate(dtypes.imageh((4, 4096)), 0x8200)

__validate(dtypes.imagef((16, 49)), 0x380)
__validate(dtypes.imagef((16, 1024)), 0x4080)
__validate(dtypes.imagef((256, 64)), 0x400)
__validate(dtypes.imagef((64, 512)), 0x2040)
__validate(dtypes.imagef((16, 512)), 0x2080)
__validate(dtypes.imagef((132, 64)), 0x400)
__validate(dtypes.imagef((4, 512)), 0x2200)
__validate(dtypes.imagef((4, 16)), 0x200)
__validate(dtypes.imagef((2, 16)), 0x400)
__validate(dtypes.imagef((8, 512)), 0x2100)
__validate(dtypes.imagef((12, 64)), 0x400)
__validate(dtypes.imagef((3, 32)), 0x400)
__validate(dtypes.imagef((128, 128)), 0x840)
__validate(dtypes.imagef((32, 512)), 0x2040)
__validate(dtypes.imagef((8, 3072)), 0xC100)
__validate(dtypes.imagef((4, 2048)), 0x8200)
__validate(dtypes.imagef((4, 1024)), 0x4200)
__validate(dtypes.imagef((4, 4096)), 0x10200)
__validate(dtypes.imagef((10, 384)), 0x1900)
__validate(dtypes.imagef((24, 64)), 0x400)
__validate(dtypes.imagef((128, 12)), 0xC0)
__validate(dtypes.imagef((10, 24)), 0x200)
__validate(dtypes.imagef((1, 129)), 0x840)
__validate(dtypes.imagef((1, 32)), 0x200)
__validate(dtypes.imagef((1, 64)), 0x400)
__validate(dtypes.imagef((1, 1239)), 0x4D80)
__validate(dtypes.imagef((1, 1)), 0x40)

def test_image_and_back(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()

@@ -288,17 +288,22 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)

def test_conv2d_ceildiv_edge_case(self):
v = Variable('v', 11, 50_000)
val = 39601
x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
weight = Tensor.randn(256, 22, 12)
# tests symbolic ceildiv in conv2d output shape calculation
# val=79 triggers the edge case where old ceildiv simplifies incorrectly: old gives floor=12, correct ceildiv=13
v = Variable('v', 11, 100)
val = 79
x_full = Tensor.randn(1, 8, 100)
weight = Tensor.randn(16, 8, 12)

result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
# symbolic version
result = x_full[:, :, :v.bind(val)].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
var_val = {v.expr: val}
shape = tuple(sym_infer(s, var_val) for s in result.shape)
with self.assertRaises(AssertionError):
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
# TODO: test output is correct
self.assertEqual(shape, (1, 16, 13))

# concrete version for comparison
expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)

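For the record, the expected (1, 16, 13) follows from the 1d conv output-length formula out = ceildiv(L + 2*pad - dilation*(k-1), stride); restating the arithmetic here as a check (the formula itself, not quoted from tinygrad internals):

def ceildiv(num, amt): return (num + amt - 1) // amt  # integer ceiling division

L, pad, dilation, k, stride = 79, 3, 1, 12, 6
out = ceildiv(L + 2*pad - dilation*(k - 1), stride)  # ceildiv(74, 6)
assert out == 13  # the buggy floor-style simplification gave 12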
if __name__ == '__main__':
unittest.main()

@@ -1,117 +0,0 @@
#!/usr/bin/env python
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import run_schedule
from tinygrad.uop.ops import UOp
from tinygrad.helpers import SPLIT_REDUCEOP

class TestTensorUOp(unittest.TestCase):
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
b = Tensor(a).uop
assert b.shape == a.shape
np.testing.assert_equal(a, Tensor(b).numpy())

for ndims in range(1, 4):
a = np.random.randn(*(4,)*ndims).astype(np.float32)
for stride in [-2, 1, 2]:
for start in [0, 1]:
helper(a[(slice(start, None, stride),)*ndims])

def test_shuffle_pad_ops_cmpeq(self):
y = Tensor([1]).cat(Tensor([1]) == 0).numpy()
z = Tensor([1, 0]).numpy()
np.testing.assert_allclose(y, z)

def test_shuffle_pad_ops_div(self):
y = Tensor([1]).cat(Tensor([1]).div(Tensor([2.0]))).numpy()
z = Tensor([1, 0.5]).numpy()
np.testing.assert_allclose(y, z)

def test_shuffle_pad_ops_log(self):
y = Tensor([1]).cat(Tensor([1]).log()).numpy()
z = Tensor([1, 0]).numpy()
np.testing.assert_allclose(y, z)

def test_shuffle_pad_ops_exp(self):
y = Tensor([1]).cat(Tensor([1]).exp()).numpy()
z = Tensor([1, np.e]).numpy()
np.testing.assert_allclose(y, z)

def test_device_0_is_the_same_device(self):
a = Tensor([1, 2, 3], f"{Device.DEFAULT}")
b = Tensor([1, 2, 3], f"{Device.DEFAULT}:0")
assert a.device == b.device

def test_shrink_const_into_zero(self):
# regression test to make sure the shapetracker is preserved
a = Tensor.zeros(4,4,4).shrink((None, (0,0), None))
b = Tensor.zeros(4,1,4)
c = a.cat(b, dim=1)
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))

def test_shrink_const_then_cast(self):
# regression test to make sure the shapetracker is preserved
a = Tensor.zeros(4,4,4).shrink((None, (0,0), None)).cast(dtypes.int32)
b = Tensor.zeros(4,1,4)
c = a.cat(b, dim=1)
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))

def test_const_dtype(self):
lb: UOp = Tensor([1], dtype=dtypes.int).uop
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int

lb: UOp = Tensor([1], dtype=dtypes.float).uop
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float

def test_contiguous_alu(self):
a = Tensor.randn(2, 2).realize()
b = Tensor.randn(2, 2).realize()
add = (a+b).contiguous()
out = add+2
sched = out.schedule()
self.assertEqual(len(sched), 2)
run_schedule(sched)
np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2)

# NOTE: contiguous on a buffer collapses
@unittest.skip("contiguous on a buffer no longer collapses")
def test_contiguous_empty(self):
empty = Tensor.empty(1).contiguous()
sched = empty.schedule()
self.assertEqual(len(sched), 0)

def test_contiguous_folded_alu(self):
a = Tensor.empty(8, 8)
# NOTE: the buffer for mul_0 late folds to just a CONST
mul_0 = a*0
out = mul_0.shrink(((4, 8), (0, 8))).contiguous()
out.realize()
self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())

@unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP")
class TestReduceOp(unittest.TestCase):
def test_no_split_reduce_kernel(self):
a = Tensor.rand(4, 4).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 1

def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2

def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2

if __name__ == "__main__":
unittest.main()
@@ -73,8 +73,6 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(vv.bind(4), 7)
self.assertListEqual(ret[:3].tolist(), [4,5,6])

# TODO: add vmin/vmax pattern for symbolic denominator
@unittest.expectedFailure
def test_symbolic_arange_sym_step(self):
vv = Variable("step", 1, 3)
ret = Tensor.arange(0, 10, vv.bind(2))
@@ -86,6 +84,18 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(begin.bind(4), end.bind(7))
self.assertListEqual(ret[:3].tolist(), [4,5,6])

def test_symbolic_arange_three_vars(self):
begin = Variable("b", 0, 5)
end = Variable("e", 10, 20)
step = Variable("s", 1, 3)
ret = Tensor.arange(begin.bind(2), end.bind(14), step.bind(3))
self.assertListEqual(ret[:4].tolist(), [2,5,8,11])

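The expected list follows from the arange element count ceildiv(end - begin, step); with these bound values that is ceildiv(14 - 2, 3) = 4 elements starting at 2:

def ceildiv(num, amt): return (num + amt - 1) // amt

begin, end, step = 2, 14, 3
n = ceildiv(end - begin, step)
assert n == 4 and [begin + i*step for i in range(n)] == [2, 5, 8, 11]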
def test_symbolic_full(self):
vv = Variable("x", 1, 10).bind(5)
t = Tensor.full((3,), vv)
self.assertListEqual(t.tolist(), [5,5,5])

def test_variable_empty(self):
v = Variable("i", 1, 10)
# TODO: Tensor creation from unbound variable should assert

@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
_test_value(0)
_test_value(0.0000009)

class TestFloat16Log2(unittest.TestCase):
"""Tests for native float16 log2 implementation (no float32 cast)"""
@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
def test_float16_log2_basic(self):
# basic values
test_values = [1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0]
with Context(TRANSCENDENTAL=2):
for val in test_values:
result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
expected = np.log2(np.float16(val))
np.testing.assert_allclose(result, expected, rtol=1e-3, err_msg=f"log2({val})")

@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
def test_float16_log2_special(self):
# special values: inf, -inf, nan, 0, negative
with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
# log2(inf) = inf
assert np.isinf(Tensor([np.inf], dtype=dtypes.float16).log2().numpy()[0])
# log2(0) = -inf
assert Tensor([0.0], dtype=dtypes.float16).log2().numpy()[0] == -np.inf
# log2(negative) = nan
assert np.isnan(Tensor([-1.0], dtype=dtypes.float16).log2().numpy()[0])
# log2(nan) = nan
assert np.isnan(Tensor([np.nan], dtype=dtypes.float16).log2().numpy()[0])

@unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
def test_float16_log2_denormal(self):
# test values near and below float16 min normal (6.1e-5)
# these exercise the denormal handling path with 2^10 scaling
test_values = [1e-4, 6e-5, 1e-5]
with Context(TRANSCENDENTAL=2):
for val in test_values:
result = Tensor([val], dtype=dtypes.float16).log2().numpy()[0]
expected = np.log2(np.float16(val))
# denormals have lower precision due to float16 limitations
np.testing.assert_allclose(result, expected, rtol=5e-2, err_msg=f"log2({val})")

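The 2^10 scaling the comment refers to rests on the identity log2(x) = log2(x * 2^10) - 10: multiplying a float16 denormal by 2^10 is exact and brings it into the normal range. A numpy sketch of the identity (the kernel's actual denormal path may differ in detail):

import numpy as np

x = np.float16(1e-5)                       # below float16 min normal (~6.1e-5)
scaled = np.float16(x * np.float16(1024))  # exact: 2**10 scale makes it a normal number
approx = np.log2(np.float32(scaled)) - 10  # log2(x) = log2(x * 2**10) - 10
assert abs(approx - np.log2(np.float32(x))) < 1e-2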
class TestTranscendentalSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_transcendental_sin_fusion(self):

@@ -478,143 +478,6 @@ class TestUOpGraph(unittest.TestCase):
for u in uops:
self.assertNotEqual(u.dtype, dtypes.long)

def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
to_uops_list([ld1])

ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])

def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
to_uops_list([ld0])

ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])

def test_in_out_of_bounds_access_gated_store(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20)
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
to_uops_list([st0])

st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
with self.assertRaises(RuntimeError): to_uops_list([st1])

@unittest.skip("if not allowed in graph")
def test_in_bounds_access_gated_local(self):
with Context(IGNORE_OOB=0):
# Define buffers
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")

gate = (gidx<400) & (lidx<8)

local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))

barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))

# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))

# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
to_uops_list([global_store])

def test_load_with_float_in_index(self):
with Context(IGNORE_OOB=0):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
i = (ldfloat+3.14).cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))

def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
to_uops_list([ld0])

@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])

def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])

def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
to_uops_list([ld0, ld1])

ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])

def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
to_uops_list([ld0])

ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])

def test_in_out_of_bounds_access_index_load(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
to_uops_list([ld1])

ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])

def test_bounds_with_loaded_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 0)
gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
ld0 = glbl0.index(gidx0, ptr=True).load()
ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
with self.assertRaises(RuntimeError): to_uops_list([ld1])

def test_fold_gated_load(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)

@@ -2,6 +2,7 @@ import ctypes, gzip, unittest, timeit, pickle
from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.helpers import ceildiv
from tinygrad.tensor import Tensor, get_shape
import numpy as np

@@ -120,6 +121,25 @@ class TestRoundUp(unittest.TestCase):
self.assertEqual(round_up(232, 24984), 24984)
self.assertEqual(round_up(24984, 232), 25056)

class TestCeilDiv(unittest.TestCase):
def test_int(self):
self.assertEqual(ceildiv(10, 3), 4)
self.assertEqual(ceildiv(9, 3), 3)
self.assertEqual(ceildiv(0, 5), 0)
self.assertEqual(ceildiv(1, 5), 1)
def test_symbolic(self):
# tests that ceildiv with UOp uses (num + amt - 1) // amt formula for non-negative num
v = Variable('v', 0, 100)
result = ceildiv(v, 6)
self.assertEqual(result.render(), "((v+5)//6)")
def test_symbolic_negative_offset(self):
# tests ceildiv(v-5, 6) which is used in conv2d output shape
# old implementation incorrectly simplified -(x//-y) to ((v+1)//6-1) for v-5
# new implementation uses (v-5+5)//6 = v//6 which is correct
v = Variable('v', 11, 100)
result = ceildiv(v - 5, 6)
self.assertEqual(result.render(), "(v//6)")

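The two formulations named in the comments agree on concrete non-negative ints; it is only under symbolic simplification that the -(x // -y) form went wrong. A standalone check of the equivalence:

def ceildiv_new(num, amt): return (num + amt - 1) // amt  # form the tests assert symbolically
def ceildiv_old(num, amt): return -(num // -amt)          # equivalent on concrete ints

for num in range(50):
    for amt in (1, 3, 6, 7):
        assert ceildiv_new(num, amt) == ceildiv_old(num, amt)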
class TestCount(unittest.TestCase):
def test_count_basic(self):
c = count(3)

@@ -65,6 +65,30 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(Q)
reconstruction_helper([Q,R],a)

def test_qr_zero_column(self):
a = Tensor([[0.0, 1.0], [0.0, 2.0]]).realize()
Q,R = a.qr()
assert not np.isnan(Q.numpy()).any()
assert not np.isnan(R.numpy()).any()
orthogonality_helper(Q)
reconstruction_helper([Q,R], a)

def test_svd_identity(self):
for a in (Tensor.eye(2), Tensor.zeros(2, 2)):
a = a.realize()
U,S,V = a.svd()
assert not np.isnan(U.numpy()).any()
assert not np.isnan(S.numpy()).any()
assert not np.isnan(V.numpy()).any()
s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
reconstruction_helper([U, s_diag, V], a)

def test_svd_rank1(self):
a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize()
U, S, V = a.svd()
np.testing.assert_allclose(S.numpy(), [np.sqrt(10), 0.0], atol=1e-4, rtol=1e-4)
reconstruction_helper([U, S.unsqueeze(-2) * Tensor.eye(2), V], a)

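The expected singular values can be checked by hand: A = [[1,1],[2,2]] has rank one, so its single nonzero singular value equals its Frobenius norm, sqrt(1+1+4+4) = sqrt(10). With numpy as the reference:

import numpy as np

a = np.array([[1.0, 1.0], [2.0, 2.0]])
s = np.linalg.svd(a, compute_uv=False)
np.testing.assert_allclose(s, [np.sqrt(10), 0.0], atol=1e-12)
assert np.isclose(s[0], np.linalg.norm(a))  # rank-1: sigma_1 equals the Frobenius norm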
def test_newton_schulz(self):
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
sizes = [(2,2), (3,2), (2,3), (2,2,2)]

@@ -1,13 +1,12 @@
import unittest
from tinygrad.tensor import Tensor

class TestMaskedShapeTracker(unittest.TestCase):
class TestMaskedTensor(unittest.TestCase):
def test_mul_masked(self):
a = Tensor([1,1,1,1,1])
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]

@@ -16,7 +15,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]

@@ -24,7 +22,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
a = Tensor([1,1]).pad(((0,2),))
b = Tensor([1,1]).pad(((0,2),))
c = a+b
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]

@@ -128,6 +128,26 @@ class TestProgressBar(unittest.TestCase):
      self._compare_bars(tinytqdm_output, tqdm_output)
      if n > 5: break

  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')
  def test_si_boundary(self, mock_terminal_size, mock_stderr):
    """Test SI formatting at boundaries (e.g., 999.5 -> 1.00k, not 1000)"""
    ncols = 80
    mock_terminal_size.return_value = namedtuple(field_names='columns', typename='terminal_size')(ncols)

    # Test rates at the boundary: 999 stays as "999", 999.5+ becomes "1.00k"
    for rate in [999, 999.4, 999.5, 1000, 1001]:
      mock_stderr.truncate(0)
      mock_stderr.seek(0)
      elapsed = 1.0 / rate
      # Need 3 perf_counter calls: init st, init update, final update
      with patch('time.perf_counter', side_effect=[0, 0, elapsed]):
        bar = tinytqdm(desc="Test", total=1, unit_scale=True, rate=10**9)
        bar.update(1, close=True)
      tinytqdm_output = mock_stderr.getvalue().split("\r")[-1].rstrip()
      tqdm_output = tqdm.format_meter(n=1, total=1, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=True)
      self._compare_bars(tinytqdm_output, tqdm_output)

  @unittest.skip("this is flaky")
  @patch('sys.stderr', new_callable=StringIO)
  @patch('shutil.get_terminal_size')

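For reference, a minimal sketch of the clock-patching pattern the new test relies on (a sketch assuming tinygrad.helpers.tqdm; the values are illustrative): perf_counter is read once at construction, once by the implicit update(0), and once by the final update, so three side_effect values pin the measured rate to exactly 1/elapsed it/s.

# drive a deterministic rate through tinytqdm with a fake clock
from io import StringIO
from unittest.mock import patch
from tinygrad.helpers import tqdm as tinytqdm

with patch('sys.stderr', new_callable=StringIO) as err, patch('time.perf_counter', side_effect=[0, 0, 0.001]):
  bar = tinytqdm(desc="Test", total=1, unit_scale=True, rate=10**9)
  bar.update(1, close=True)
print(err.getvalue().split("\r")[-1].rstrip())  # 1000 it/s sits on the SI boundary and renders as "1.00k"
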
179
test/unit/test_validate_oob.py
Normal file
@@ -0,0 +1,179 @@
import unittest
from tinygrad import dtypes, Variable
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops, UOp, AxisType
from test.test_uops import to_uops_list

class TestValidateOOB(unittest.TestCase):
  """Test z3 validation of index bounds for different ALU ops and patterns."""

  # basic index patterns
  def test_const_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(UOp.const(dtypes.int, 0), ptr=True).load(dtype=dtypes.int)]) # valid
      to_uops_list([buf.index(UOp.const(dtypes.int, 15), ptr=True).load(dtype=dtypes.int)]) # valid (last element)
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.const(dtypes.int, 16), ptr=True).load(dtype=dtypes.int)]) # off by one
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.const(dtypes.int, 42), ptr=True).load(dtype=dtypes.int)]) # way out

  def test_variable_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)]) # valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(Variable("i", 0, 20), ptr=True).load(dtype=dtypes.int)]) # oob
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(Variable("i", -5, 10), ptr=True).load(dtype=dtypes.int)]) # negative

  def test_range_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(42, 0, AxisType.GLOBAL)
      to_uops_list([buf.index(r.valid(r < 16), ptr=True).load(dtype=dtypes.int)]) # valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r.valid(r < 17), ptr=True).load(dtype=dtypes.int)]) # oob

  def test_variable_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", -5, 80)
      to_uops_list([buf.index(v.valid((v >= 0) & (v < 16)), ptr=True).load(dtype=dtypes.int)]) # valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(v.valid(v < 20), ptr=True).load(dtype=dtypes.int)]) # negative not masked

  def test_gated_store(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", 0, 20)
      to_uops_list([buf.index(v.valid(v < 16)).store(0)]) # valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(v.valid(v < 20)).store(0)]) # oob

  # ALU ops in index
  def test_idiv(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(UOp.range(32, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.range(34, 0, AxisType.GLOBAL) // 2, ptr=True).load(dtype=dtypes.int)]) # 0..16 oob

  def test_mod(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(100, 0, AxisType.GLOBAL)
      to_uops_list([buf.index(r % 16, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r % 20, ptr=True).load(dtype=dtypes.int)]) # 0..19 oob

  def test_shr(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(UOp.range(64, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(UOp.range(128, 0, AxisType.GLOBAL) >> 2, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob

  def test_shl(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
      r = UOp.range(8, 0, AxisType.GLOBAL)
      to_uops_list([buf.index(r << 2, ptr=True).load(dtype=dtypes.int)]) # 0..28 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r << 4, ptr=True).load(dtype=dtypes.int)]) # 0..112 oob

  def test_and(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(100, 0, AxisType.GLOBAL)
      to_uops_list([buf.index(r & 15, ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r & 31, ptr=True).load(dtype=dtypes.int)]) # 0..31 oob

  def test_max(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(Variable("v", -10, 15).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(Variable("v2", -10, 20).maximum(0), ptr=True).load(dtype=dtypes.int)]) # 0..20 oob

  def test_xor_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(32, 0, AxisType.GLOBAL)
      to_uops_list([buf.index(r.valid((r < 8) ^ ((r >= 8) & (r < 16))), ptr=True).load(dtype=dtypes.int)]) # 0..15 valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf.index(r.valid((r < 10) ^ (r >= 20)), ptr=True).load(dtype=dtypes.int)]) # 0..9,20..31 oob

  # cast patterns
  def test_float_cast_in_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      r = UOp.range(20, 0)
      i = (r.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
      to_uops_list([buf.index(i.valid((i >= 0) & (i < 16)), ptr=True).load(dtype=dtypes.int)])

  def test_bool_cast_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
      r = UOp.range(20, 0)
      to_uops_list([buf.index(r.valid(r.cast(dtypes.bool).logical_not()), ptr=True).load(dtype=dtypes.int)]) # only r=0 valid

  # load result as index/mask
  def test_load_as_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      buf1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1)
      r = UOp.range(42, 0, AxisType.GLOBAL)
      ld0 = buf0.index(r.valid(r < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
      to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 32)), ptr=True).load(dtype=dtypes.int)]) # valid
      with self.assertRaises(RuntimeError):
        to_uops_list([buf1.index((ld0 * 2).valid((ld0 >= 0) & (ld0 < 64)), ptr=True).load(dtype=dtypes.int)]) # oob

  def test_load_bool_as_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf_bool = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      buf_int = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
      gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
      ld_bool = buf_bool.index(gidx, ptr=True).load()
      with self.assertRaises(RuntimeError):
        to_uops_list([buf_int.index(gidx.valid(ld_bool), ptr=True).load()]) # gidx 0..15, buf_int size 8

  # skipped tests (moved from test_uop_graph.py)
  @unittest.skip("if not allowed in graph")
  def test_in_bounds_access_gated_local(self):
    with Context(IGNORE_OOB=0):
      # Define buffers
      gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
      sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

      # Define indices, valids and barrier
      gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
      lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")

      gate = (gidx<400) & (lidx<8)

      local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))

      barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
      if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))

      # Load from local memory (after the IF/barrier)
      local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))

      # Store to global memory
      global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
      to_uops_list([global_store])

  @unittest.skip("Bool load is not supported yet")
  def test_load_mask(self):
    with Context(IGNORE_OOB=0):
      glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      ridx = UOp.range(20, 0)
      ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
      to_uops_list([ld0])

if __name__ == "__main__":
  unittest.main()

@@ -166,8 +166,10 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
  if ast.arg is None: ast = ast.replace(arg=KernelInfo())

  # rewrite to prg
  full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
  prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
  if ast.op is Ops.PROGRAM: prg = ast
  else:
    full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
    prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
  prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")

  # create the ProgramSpec

@@ -45,6 +45,7 @@ class Scheduler:
    ret = Scheduler(self.ast, self.ren)
    ret.dont_use_locals = self.dont_use_locals
    ret.applied_opts = self.applied_opts[:]
    if hasattr(self, 'tensor_core'): ret.tensor_core = self.tensor_core
    return ret

  kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
@@ -307,6 +308,7 @@ class Scheduler:
      reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
      if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
      self.ast = self.ast.substitute({reduceop: tc_uop})
      self.tensor_core = tc
      return axes
    return None

@@ -93,8 +93,8 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
# *** external API ***

# get dictionary of all possible actions
def get_kernel_actions(s:Scheduler, include_0=True) -> dict[int, Scheduler]:
  acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
  acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256) if max_up is None else max_up, getenv("BEAM_LOCAL_MAX", 1024)
  kernel_actions = actions.copy()

  for i,a in enumerate(kernel_actions):

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Final, ClassVar, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
from tinygrad.helpers import getenv, prod, round_up, next_power2
from enum import Enum, auto

class InvalidTypeMetaClass(type):
@@ -101,6 +101,15 @@ class ImageDType(PtrDType):
    assert addrspace == AddrSpace.GLOBAL, "images can't be local"
    return self
  def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
  @property
  def pitch(self):
    imgw, imgh, itemsize_log = self.shape[1], self.shape[0], int(math.log2(self.itemsize))
    pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
    align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))

    granularity = 128 if self.itemsize == 4 else 256
    pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
    return round_up(imgw * 4 * self.itemsize, 1 << pitchalign) + pitch_add

class dtypes:
  @staticmethod

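As a hand-check of the pitch formula above, a sketch for a hypothetical 16-wide, 1-high image with itemsize 4 (so itemsize_log=2; the intermediate values are worked out from the property's own expressions, not from hardware docs):

# pitchalign=6 since imgh==1, align_up=max(1, (8//2+1) - 1//32)=5, granularity=128
# pitch_add=0 since imgw (16) is not > granularity//2 (64), so pitch = round_up(16*4*4, 64) = 256
from tinygrad.helpers import round_up, next_power2
imgw, imgh, itemsize = 16, 1, 4
pitchalign, align_up, granularity = 6, 5, 128
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
assert round_up(imgw * 4 * itemsize, 1 << pitchalign) + pitch_add == 256
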
@@ -125,7 +125,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:

# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
  (UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
  (UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
  (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
  (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
    if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \

@@ -38,18 +38,18 @@ def ansilen(s:str): return len(ansistrip(s))
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
def fully_flatten(l):
  if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
    if hasattr(l, "shape") and l.shape == (): return [l[()]]
    flattened = []
    for li in l: flattened.extend(fully_flatten(li))
    return flattened
  return [l]
  if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
  return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fst[-1] == ')' and _is_balanced(fst[1:-1]) else fst
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
def ceildiv(num, amt):
  # use (num + amt - 1) // amt when num is a UOp and non-negative to avoid C/Python division mismatch
  if hasattr(num, 'vmin') and num.vmin >= 0 and (amt > 0 if isinstance(amt, int) else amt.vmin > 0): return (num + amt - 1) // amt
  return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
# cstyle div and mod
def cdiv(x:int, y:int) -> int: return abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0
def cmod(x:int, y:int) -> int: return x-cdiv(x,y)*y

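A quick demonstration of the helpers above (plain-Python values, so the non-negative UOp fast path in ceildiv is not taken here):

from tinygrad.helpers import fully_flatten, ceildiv, cdiv, cmod
assert fully_flatten([[1, [2, 3]], "ab", 4]) == [1, 2, 3, "ab", 4]  # strings are leaves, nesting is flattened
assert ceildiv(7, 2) == 4 and ceildiv(-7, 2) == -3                  # ceiling division via -(num//-amt)
assert cdiv(-7, 2) == -3 and cmod(-7, 2) == -1                      # C-style: truncate toward zero
assert -7 // 2 == -4                                                # Python floor division differs, hence the comment above
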
@@ -87,9 +87,7 @@ def word_wrap(x, wrap=80):
  while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
  return x[:i] + "\n" + word_wrap(x[i:], wrap)
def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align)
def panic(e:Exception|None=None):
  if e is None: raise RuntimeError("PANIC!")
  raise e
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")

@functools.cache
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:

@@ -149,9 +147,7 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str:
  return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()

def stderr_log(msg):
  sys.stderr.write(msg)
  sys.stderr.flush()
def stderr_log(msg:str): print(msg, end='', file=sys.stderr, flush=True)

class Context(contextlib.ContextDecorator):
  def __init__(self, **kwargs): self.kwargs = kwargs

@@ -512,7 +508,9 @@ class tqdm(Generic[T]):
      if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
    def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
    def SI(x):
      return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
      if not x: return '0.00'
      v = f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')
      return (f"{x/1000**(int(g)+1):.3f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)+1]) if v == "1000" else v+' kMGTPEZY'[int(g)].strip()
    prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
    est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
    it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"

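The boundary behavior the new SI gives, shown on a standalone transcription (in the codebase SI is a closure inside tqdm; these are the same values test_si_boundary exercises earlier in this diff):

import math
def SI(x):
  if not x: return '0.00'
  v = f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')
  return (f"{x/1000**(int(g)+1):.3f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)+1]) if v == "1000" else v+' kMGTPEZY'[int(g)].strip()
assert SI(999) == '999' and SI(999.4) == '999'       # below the rounding boundary, no prefix
assert SI(999.5) == '1.00k' and SI(1000) == '1.00k'  # at/above it, roll over to the next SI prefix
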
@@ -133,7 +133,7 @@ string_rewrite = PatternMatcher([
  (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
  (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
  (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
  (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
  (UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
  (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
])

@@ -180,7 +180,7 @@ class PTXRenderer(Renderer):
    self.uops = uops

    def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
      nonlocal c, r
      nonlocal c
      prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
      c[prefix] += 1
      return f"%{prefix}{c[prefix]-1}"
@@ -230,7 +230,7 @@ class PTXRenderer(Renderer):
          [ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
        r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
      prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
        Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
        Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
        Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
      if prefix: r[u] = ssa(prefix, u, dtype)

@@ -22,7 +22,7 @@ class HCQGraph(MultiGraphRunner):

    for (j,i), input_idx in self.input_replace.items():
      x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
      self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, image=self.hcq_bufs[j][i].image) # Create fake buffer with variable

    # Allocate kernel args.
    kernargs_size: dict[Compiled, int] = collections.defaultdict(int)

@@ -834,11 +834,11 @@ class PCIIface(PCIIfaceBase):

    if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA:
      assert idx <= 3, "only 4 SDMA queues supported in am"
      pv = self.dev_impl.sdma.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
        doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0 + idx * 0xA * 4), pipe=0, queue=idx)
      pv, doorbell_index = self.dev_impl.sdma.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr,
        wptr_addr=gart.va_addr+wptr, pipe=0, queue=idx)
    else:
      pv = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
        eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0,
      pv, doorbell_index = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr,
        wptr_addr=gart.va_addr+wptr, eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, pipe=0,
        queue=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)

    return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],

@@ -10,7 +10,7 @@ from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.renderer.nir import IR3Renderer
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC
from tinygrad.runtime.support.system import System
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import

@@ -25,7 +25,7 @@ def _qreg_exec(__reg, __val=0, **kwargs):
  return __val
qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'})

def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
def ctz(v): return (v & -v).bit_length() - 1

def parity(val: int):
  for i in range(4,1,-1): val ^= val >> (1 << i)

@@ -192,9 +192,8 @@ class QCOMArgsState(HCQArgsState):
    super().__init__(buf, prg, bufs, vals=vals)
    ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)

    ubos, uavs = [b for b in bufs if b.texture_info is None], [b for b in bufs if b.texture_info is not None]
    ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None]
    ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:])

    for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')

    if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
@@ -205,8 +204,15 @@ class QCOMArgsState(HCQArgsState):
    for i, b in enumerate(ubos): self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=prg.buf_offs[i])
    for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)])

    self.bind_sints_to_buf(*flatten([b.texture_info.desc + ([0] * 8) for b in texs]), buf=self.buf, fmt='I', offset=prg.tex_off)
    self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off)
    def _tex(b, ibo=False):
      fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
      return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt),
        qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]),
        qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr),
        qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0]

    self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off)
    self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off)

class QCOMProgram(HCQProgram):
  def __init__(self, dev: QCOMDevice, name: str, lib: bytes):

@@ -305,28 +311,10 @@ class QCOMTextureInfo:
    self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo

class QCOMAllocator(HCQAllocatorBase):
  def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
  def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer:
    # Recalculate real size for texture
    if options.image is not None:
      imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
      pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
      align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))

      granularity = 128 if options.image.itemsize == 4 else 256
      pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
      pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
      size = pitch * imgh

    buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)

    if options.image is not None:
      tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
      desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
        qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
        *data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]

      buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
    return buf
    if opts.image is not None: size = opts.image.pitch * opts.image.shape[0]
    return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image)

  def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
    with cpu_profile(prof_text, self.dev.device, is_copy=True):

@@ -335,13 +323,13 @@ class QCOMAllocator(HCQAllocatorBase):
      src_off, dest_off = src_off+src_stride, dest_off+dest_stride

  def _copyin(self, dest:HCQBuffer, src:memoryview):
    stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
    stride, pitch = (dest.image.shape[1] * 4 * dest.image.itemsize, dest.image.pitch) if dest.image else (src.nbytes, src.nbytes)
    self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}")

  def _copyout(self, dest:memoryview, src:HCQBuffer):
    self.dev.synchronize()

    stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
    stride, pitch = (src.image.shape[1] * 4 * src.image.itemsize, src.image.pitch) if src.image else (src.size, src.size)
    self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY")

  def _as_buffer(self, src:HCQBuffer) -> memoryview:

@@ -388,7 +376,7 @@ class QCOMDevice(HCQCompiled):
    super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
                     functools.partial(QCOMComputeQueue, self), None)

  def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
  def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False, **kwargs) -> HCQBuffer:
    flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
    if uncached: flags |= flag("KGSL_CACHEMODE", kgsl.KGSL_CACHEMODE_UNCACHED)

@@ -396,15 +384,15 @@ class QCOMDevice(HCQCompiled):
    va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000)

    if fill_zeroes: ctypes.memset(va_addr, 0, size)
    return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self)
    return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self, **kwargs)

  def _gpu_map(self, ptr:int, size:int) -> HCQBuffer:
  def _gpu_map(self, ptr:int, size:int, **kwargs) -> HCQBuffer:
    ptr_aligned, size_aligned = (ptr & ~0xfff), round_up(size + (ptr & 0xfff), 0x1000)
    try:
      mapinfo = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
      return HCQBuffer(mapinfo.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mapinfo, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
      mi = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
      return HCQBuffer(mi.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mi, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
    except OSError as e:
      if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
      if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
      raise RuntimeError("Failed to map external pointer to GPU memory") from e

  def _gpu_free(self, mem:HCQBuffer):

@@ -281,9 +281,10 @@ class AM_GFX(AM_IP):
      self._grbm_select(inst=xcc)
    for xcc in range(self.xccs): self.adev.regGCVM_CONTEXT0_CNTL.write(0, inst=xcc)

  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int,
                 aql:bool) -> int:
  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, pipe:int, queue:int,
                 aql:bool) -> tuple[int, int]:
    self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0)
    doorbell = am.AMDGPU_NAVI10_DOORBELL_MEC_RING0
    restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1)
    restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0
    if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.")
@@ -327,7 +328,7 @@ class AM_GFX(AM_IP):
      self._grbm_select(inst=xcc)

      self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1, inst=xcc)
    return restore_ptr // 16
    return restore_ptr // 16, doorbell

  def set_clockgating_state(self):
    if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
@@ -447,9 +448,9 @@ class AM_SDMA(AM_IP):
    time.sleep(0.01)
    self.adev.regGRBM_SOFT_RESET.write(0x0)

  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int) -> int:
    # Setup the ring
  def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, pipe:int, queue:int) -> tuple[int, int]:
    reg, inst = ("regSDMA_GFX", pipe+queue*4) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
    doorbell = am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0 + (pipe+queue*4) * 0xA
    self.sdma_reginst.append((reg, inst))

    self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x1, inst=inst)
@@ -464,7 +465,7 @@ class AM_SDMA(AM_IP):
    self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP][:2]!=(4,4) else {}),
      rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4, rb_enable=1, rb_priv=1, rb_size=(ring_size//4).bit_length()-1, inst=inst)
    self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=1, inst=inst)
    return self.adev.reg(f"{reg}_RB_WPTR").read(inst=inst) | (self.adev.reg(f"{reg}_RB_WPTR_HI").read(inst=inst) << 32)
    return self.adev.reg(f"{reg}_RB_WPTR").read(inst=inst) | (self.adev.reg(f"{reg}_RB_WPTR_HI").read(inst=inst) << 32), doorbell

class AM_PSP(AM_IP):
  def init_sw(self):

@@ -8,6 +8,7 @@ from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEve
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.dtype import ImageDType

class MMIOInterface:
  def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
@@ -455,14 +456,14 @@ class HCQCompiled(Compiled, Generic[SignalType]):
    if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()

class HCQBuffer:
  def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
  def __init__(self, va_addr:sint, size:int, image:ImageDType|None=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
               owner:HCQCompiled|None=None):
    self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view
    self.va_addr, self.size, self.image, self.meta, self._base, self.view = va_addr, size, image, meta, _base, view
    self._devs, self.owner = ([owner] if owner is not None else []), owner
    self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices

  def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer:
    return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta,
    return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, image=self.image, meta=self.meta,
      _base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))

  def cpu_view(self) -> MMIOInterface:

@@ -127,7 +127,7 @@ class Tensor(OpMixin):

    # create a UOp from the different types of inputs
    if isinstance(data, UOp):
      assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
      assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
      # if data is dtype.index that means that this is a symbolic int and we need to lower it to something we can make a Tensor out of
      if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
      if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
@@ -3637,11 +3637,13 @@ class Tensor(OpMixin):
    Q = Tensor.eye(m, dtype=self.dtype).reshape((1,) * len(b_shape) + (m, m)).expand(b_shape + (m, m)).contiguous()
    for i in range(min(m, n)):
      x = R[..., i:m, i].contiguous() # TODO: without contiguous this can silently be wrong, should at least assert
      s = -x[..., 0].sign()
      u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
      w = x.unsqueeze(-1) / u1.reshape(b_shape + (1, 1))
      norm = x.square().sum(-1).sqrt()
      s = (x[..., 0] != 0).where(-x[..., 0].sign(), -1)
      u1 = x[..., 0] - s * norm
      w = x.unsqueeze(-1) / (norm != 0).where(u1, 1).reshape(b_shape + (1, 1))
      w[..., 0, 0] = 1
      tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + (1, 1))
      tau = (-s * u1 / (norm != 0).where(norm, 1)).reshape(b_shape + (1, 1))
      tau = (norm != 0).reshape(b_shape + (1, 1)).where(tau, 0)
      R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
      Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1)
    return Q,R

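The added guards avoid 0/0 when a pivot column is entirely zero (the new test_qr_zero_column case earlier in this diff). A minimal numpy sketch of one guarded Householder step mirroring that logic (names are illustrative, not tinygrad API):

import numpy as np

def householder_step(R, i):
  x = R[i:, i].astype(float).copy()
  norm = np.sqrt((x**2).sum())
  s = -np.sign(x[0]) if x[0] != 0 else -1.0      # sign guard: a zero pivot defaults to -1
  u1 = x[0] - s * norm
  w = x / (u1 if norm != 0 else 1.0)             # avoid 0/0 on an all-zero column
  w[0] = 1.0
  tau = (-s * u1 / norm) if norm != 0 else 0.0   # tau=0 turns the reflection into a no-op
  R[i:, :] = R[i:, :] - tau * np.outer(w, w @ R[i:, :])
  return R
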
@@ -3668,8 +3670,10 @@ class Tensor(OpMixin):
      #compute the jacobi rotations for each pairing
      gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
      alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
      tau = (beta - alpha) / (2 * gamma)
      t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
      rot = gamma != 0
      tau = (beta - alpha) / (2 * rot.where(gamma, 1))
      t = (tau != 0).where(tau.sign(), 1) / (tau.abs() + (1 + tau.square()).sqrt())
      t = rot.where(t, 0)
      c = 1 / (1 + t.square()).sqrt()
      s = c * t
      #apply the rotations
@@ -3686,9 +3690,9 @@ class Tensor(OpMixin):
    for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
    #extract singular values and sort. construct U from Q
    S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
    new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
    new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
    U, V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
    new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
    U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
    V = V.gather(-1, new_indices).realize()

    padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
    padded_u[..., 0:num, 0:num] = U

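The Jacobi update gets the same zero guards: gamma == 0 means the column pair is already orthogonal, and tau == 0 needs an explicit sign so t doesn't collapse to zero. A scalar sketch of the guarded rotation (illustrative, not the tensorized tinygrad code):

import math

def jacobi_cs(alpha, beta, gamma):
  if gamma == 0: return 1.0, 0.0                        # already orthogonal: identity rotation
  tau = (beta - alpha) / (2 * gamma)
  sgn = math.copysign(1.0, tau) if tau != 0 else 1.0    # sign(0)=0 would wrongly zero out t
  t = sgn / (abs(tau) + math.sqrt(1 + tau * tau))
  c = 1 / math.sqrt(1 + t * t)
  return c, c * t                                       # cosine and sine of the rotation angle
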
@@ -223,26 +223,26 @@ def xlog2(d:UOp) -> UOp:
  Paper: https://arxiv.org/pdf/2001.09258 5.5
  """
  assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
  # TODO: float16 denormal need float32 to achieve precision
  if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
  FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
  # float16 uses 2^10 for denormal scaling (2^64 overflows), float32/64 use 2^64
  denormal_exp = 10 if d.dtype.scalar() == dtypes.float16 else 64
  FLT_MIN = d.const_like({dtypes.float16: 6.1e-5, dtypes.float32: 1e-4, dtypes.float64: 1e-4}[d.dtype.scalar()])
  is_denormal = d<FLT_MIN
  a = is_denormal.where(d * (2 ** 64), d)
  a = is_denormal.where(d * (2 ** denormal_exp), d)

  e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
  m = ldexp3k(a, -e)
  e = is_denormal.where(e - 64, e)
  e = is_denormal.where(e - denormal_exp, e)

  x = (m - 1.0) / (m + 1.0)
  x2 = x * x
  if d.dtype.scalar() == dtypes.float64:
    t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
      0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
    s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
    r = t * (x * x2) + e + x * 2.885390081777926774
  else:
    t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
    s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
  r = t * (x * x2) + (s_hi + s_lo)
    # s_lo term (x*3.27e-08) only for float32 - underflows in float16
    r = t * (x * x2) + e + x * 2.8853900432586669922 + (x * 3.2734474483568488616e-08 if d.dtype.scalar() == dtypes.float32 else 0)

  # log2(Inf) = Inf
  r = d.ne(math.inf).where(r, r.const_like(math.inf))

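The scaling trick above rests on the exact identity log2(x) = log2(x * 2**k) - k; k only has to bring a denormal into the normal range without overflowing the dtype. A plain-float sanity check (2**64 would overflow float16's max of about 65504, hence k=10 there):

import math
x, k = 1e-40, 64                        # a float32-denormal-sized input
assert math.isclose(math.log2(x), math.log2(x * 2**k) - k)
assert 2.0**10 * 6.1e-5 < 65504         # float16: scaling by 2**10 stays finite
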
@@ -11,9 +11,8 @@ try:
  # IDIV is truncated division but z3 does euclidean division (floor if b>0, ceil otherwise); mod by power of two sometimes uses Ops.AND
  def z3_cdiv(a, b): return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
  def z3_xor(a,b):
    if isinstance(a, z3.BoolRef): return a^b
    assert a==-1 or b==-1, "xor can only be used in indexing if one of the arguments is -1"
    return -a-1 if b==-1 else -b-1
    assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}"
    return a^b
  z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
    Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
    Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
@@ -34,7 +33,6 @@ try:
    (UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0].ctx), None)),
    (UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0].ctx), None)),
    # casts from floats create new variables
    (UPat(Ops.CAST, dtypes.bool, src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx: (z3.Bool(f"cast{len(ctx[1])}",ctx=ctx[0].ctx), None)),
    (UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
      create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
    # A comparison between floats introduces a new bool variable
@@ -67,11 +65,10 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
  # We can use UOp min/max to do a faster check, but it can give a false positive since it's not an exact bound and doesn't consider the mask
  if 0<=idx.vmin and idx.vmax<sz: return True

  # WEBGPU has a BITCAST in the index. TODO: fix
  if any(x.op is Ops.BITCAST for x in idx.toposort() | gate.toposort()): return True

  # PTX uses absolute addresses (pointer cast to long), skip validation
  if any(x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType) for x in idx.toposort()): return True
  # TODO: validate these
  # WEBGPU has a BITCAST in the index, PTX casts pointer to long
  for x in idx.toposort() | gate.toposort():
    if x.op is Ops.BITCAST or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True

  if not z3_imported: raise ImportError("bounds checking requires z3 >= 4.12.4, use IGNORE_OOB=1 to disable, or \"pip install 'z3-solver>=4.12.4'\"")
  solver = z3.Solver(ctx=z3.Context())

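A quick concrete check that z3_cdiv reproduces C-style truncated division (a sketch; requires z3-solver, and the expected values match helpers.cdiv earlier in this diff):

import z3
def z3_cdiv(a, b): return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
for x, y in [(7, 2), (-7, 2), (7, -2), (-7, -2)]:
  got = z3.simplify(z3_cdiv(z3.IntVal(x), z3.IntVal(y))).as_long()
  expected = abs(x)//abs(y)*(1,-1)[x*y<0]   # truncate toward zero, like C
  assert got == expected, (x, y, got, expected)
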