Merge origin/master

This commit is contained in:
George Hotz
2026-01-01 17:15:45 +00:00
45 changed files with 6643 additions and 41334 deletions

View File

@@ -49,19 +49,19 @@ jobs:
- name: Print macOS version
run: sw_vers
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
run: BENCHMARK_LOG=stable_diffusion JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run Stable Diffusion without fp16
run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
run: BENCHMARK_LOG=stable_diffusion_fp32 JIT=1 ASSERT_MIN_STEP_TIME=720 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing
- name: Run Stable Diffusion v2
# TODO: very slow step time
run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=4500 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
run: BENCHMARK_LOG=stable_diffusion_v2 JIT=1 ASSERT_MIN_STEP_TIME=4500 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing
# process replay can't capture this, the graph is too large
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=5000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=5000 CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing
- name: Run model inference benchmark
run: METAL=1 NOCLANG=1 python3.11 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: BIG=2 MPS=1 python3.11 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
run: BIG=2 MPS=1 python3.11 test/speed/external_test_speed_v_torch.py
- name: Test tensor cores
run: METAL=1 python3.11 test/opt/test_tensor_cores.py
- name: Test AMX tensor cores
@@ -71,84 +71,59 @@ jobs:
DEBUG=2 CPU=1 CPU_LLVM=0 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
DEBUG=2 CPU=1 CPU_LLVM=1 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
- name: Run Tensor Core GEMM (float)
run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (half)
run: DEBUG=2 SHOULD_USE_TC=1 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
run: DEBUG=2 SHOULD_USE_TC=1 HALF=1 python3.11 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (bfloat16)
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
run: DEBUG=2 SHOULD_USE_TC=1 BFLOAT16=1 python3.11 extra/gemm/simple_matmul.py
- name: Fuzz Padded Tensor Core GEMM
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
BENCHMARK_LOG=llama_nojit JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
run: BENCHMARK_LOG=llama_beam JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run quantized LLaMA
run: |
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
BENCHMARK_LOG=llama_int8 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8
BENCHMARK_LOG=llama_nf4 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4
- name: Run quantized LLaMA3
run: |
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8 | tee llama3_int8.txt
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4 | tee llama3_nf4.txt
BENCHMARK_LOG=llama3_int8 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize int8
BENCHMARK_LOG=llama3_nf4 python3.11 examples/llama3.py --size 8B --temperature 0 --benchmark --quantize nf4
#- name: Run LLaMA 7B on 4 (virtual) GPUs
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
# run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2_nojit JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 JIT=1 ASSERT_MIN_STEP_TIME=13 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: BENCHMARK_LOG=gpt2_half HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: BENCHMARK_LOG=gpt2_half_beam HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run OLMoE
run: BENCHMARK_LOG=olmoe python3.11 examples/olmoe.py
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py
# NOTE: this is failing in CI. it is not failing on my machine and I don't really have a way to debug it
# the error is "RuntimeError: Internal Error (0000000e:Internal Error)"
#- name: Run 10 CIFAR training steps
# run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=3000 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
# run: BENCHMARK_LOG=cifar_10steps JIT=1 ASSERT_MIN_STEP_TIME=3000 STEPS=10 python3.11 examples/hlb_cifar10.py
#- name: Run 10 CIFAR training steps w HALF
# run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=3000 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
# run: BENCHMARK_LOG=cifar_10steps_half JIT=2 ASSERT_MIN_STEP_TIME=3000 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py
#- name: Run 10 CIFAR training steps w BF16
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py
# TODO: too slow
# - name: Run 10 CIFAR training steps w winograd
# run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
# run: BENCHMARK_LOG=cifar_10steps_wino JIT=1 ASSERT_MIN_STEP_TIME=150 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py
- uses: actions/upload-artifact@v4
with:
name: Speed (Mac)
path: |
onnx_inference_speed.csv
torch_speed.txt
llama_unjitted.txt
llama_jitted.txt
llama_beam.txt
llama_int8.txt
llama_nf4.txt
llama3_int8.txt
llama3_nf4.txt
llama_four_gpu.txt
gpt2_unjitted.txt
gpt2_jitted.txt
gpt2_half.txt
gpt2_half_beam.txt
matmul.txt
matmul_half.txt
matmul_bfloat16.txt
sd.txt
sd_no_fp16.txt
sdv2.txt
sdxl.txt
beautiful_mnist.txt
train_cifar.txt
train_cifar_half.txt
train_cifar_bf16.txt
train_cifar_wino.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3.11 process_replay.py
@@ -215,7 +190,7 @@ jobs:
- name: Run model inference benchmark
run: NV=1 CAPTURE_PROCESS_REPLAY=0 NOCLANG=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: NV=1 CAPTURE_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
run: NV=1 CAPTURE_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py
- name: Test speed vs theoretical
run: NV=1 IGNORE_BEAM_CACHE=1 CCACHE=0 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
- name: Test benchmark allreduce
@@ -226,79 +201,58 @@ jobs:
NV=1 NV_PTX=1 ALLOW_TF32=1 python3 test/opt/test_tensor_cores.py
- name: Run Tensor Core GEMM (CUDA)
run: |
CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_bfloat16.txt
CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_tf32.txt
CUDA=1 SHOULD_USE_TC=1 FP8E4M3=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_fp8.txt
CUDA=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
CUDA=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
CUDA=1 SHOULD_USE_TC=1 ALLOW_TF32=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
CUDA=1 SHOULD_USE_TC=1 FP8E4M3=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (PTX)
run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
run: NV=1 NV_PTX=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (NV)
run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
run: NV=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
- name: Test NV=1
run: DEBUG=2 NV=1 python -m pytest -rA test/test_tiny.py
- name: Test CUDA=1
run: DEBUG=2 CUDA=1 python -m pytest -rA test/test_tiny.py
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
run: BENCHMARK_LOG=stable_diffusion NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
# TODO: too slow
# - name: Run SDXL
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
# run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=2000 CAPTURE_PROCESS_REPLAY=0 NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA
run: |
BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
BENCHMARK_LOG=llama NV=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
BENCHMARK_LOG=llama_nojit NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama NV=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA with BEAM
run: BENCHMARK_LOG=llama_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
run: BENCHMARK_LOG=llama_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
run: BENCHMARK_LOG=llama3_beam_4gpu NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run quantized LLaMA3
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8 | tee llama3_fp8.txt
run: BENCHMARK_LOG=llama3_fp8 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --temperature 0 --benchmark --quantize fp8
# - name: Run LLaMA-3 8B on 6 GPUs
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-2 70B
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
run: time BENCHMARK_LOG=mixtral NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2_nojit NV=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 NV=1 JIT=1 ASSERT_MIN_STEP_TIME=4 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: BENCHMARK_LOG=gpt2_half NV=1 HALF=1 ASSERT_MIN_STEP_TIME=6 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: BENCHMARK_LOG=gpt2_half_beam NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- uses: actions/upload-artifact@v4
with:
name: Speed (NVIDIA)
path: |
onnx_inference_speed.csv
torch_speed.txt
matmul.txt
matmul_bfloat16.txt
matmul_tf32.txt
matmul_ptx.txt
matmul_nv.txt
sd.txt
sdxl.txt
llama_unjitted.txt
llama_jitted.txt
llama_beam.txt
llama3_beam.txt
llama3_four_gpu.txt
llama3_six_gpu.txt
llama3_fp8.txt
llama_2_70B.txt
mixtral.txt
gpt2_unjitted.txt
gpt2_jitted.txt
gpt2_half.txt
gpt2_half_beam.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -339,42 +293,28 @@ jobs:
- name: HEVC Decode Benchmark
run: VALIDATE=1 MAX_FRAMES=100 NV=1 PYTHONPATH=. python3 extra/hevc/decode.py
- name: Train MNIST
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
run: time PYTHONPATH=. NV=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 python3 examples/hlb_cifar10.py
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=110 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=110 NV=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
- name: Run 10 CIFAR training steps w BF16
run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=120 NV=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
# - name: Run 10 CIFAR training steps w winograd
# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=350 NV=1 CAPTURE_PROCESS_REPLAY=0 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=350 NV=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu CAPTURE_PROCESS_REPLAY=0 NV=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
run: time BENCHMARK_LOG=cifar_6gpu CAPTURE_PROCESS_REPLAY=0 NV=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run MLPerf resnet eval on training data
run: time BENCHMARK_LOG=resnet_eval NV=1 MODEL=resnet python3 examples/mlperf/model_eval.py
- name: Run 10 MLPerf ResNet50 training steps (1 gpu)
run: BENCHMARK_LOG=resnet_10steps NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
run: BENCHMARK_LOG=resnet_10steps NV=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf ResNet50 training steps (6 gpu)
run: BENCHMARK_LOG=resnet_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
run: BENCHMARK_LOG=resnet_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf Bert training steps (6 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (NVIDIA Training)
path: |
beautiful_mnist.txt
train_cifar.txt
train_cifar_half.txt
train_cifar_bf16.txt
train_cifar_wino.txt
train_cifar_one_gpu.txt
train_cifar_six_gpu.txt
train_resnet.txt
train_resnet_one_gpu.txt
train_bert.txt
run: BENCHMARK_LOG=bert_10steps_6gpu NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -426,7 +366,7 @@ jobs:
#- name: Test speed vs torch
# run: |
# python3 -c "import torch; print(torch.__version__)"
# LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py | tee torch_speed.txt
# LD_PRELOAD="/opt/rocm/lib/libhsa-runtime64.so" HSA=1 BIG=2 TORCHCUDA=1 python3 test/speed/external_test_speed_v_torch.py
- name: Test speed vs theoretical
run: AMD=1 IGNORE_BEAM_CACHE=1 CCACHE=0 BEAM_DEBUG=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py --durations=20
- name: Test tensor cores AMD_LLVM=0
@@ -437,7 +377,7 @@ jobs:
- name: Run Tensor Core GEMM (AMD)
run: |
AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
- name: Test AMD=1
run: DEBUG=2 AMD=1 python -m pytest -rA test/test_tiny.py
#- name: Test HIP=1
@@ -452,61 +392,39 @@ jobs:
- name: Test AM warm start time
run: time AMD=1 python3 test/test_tiny.py TestTiny.test_plus
- name: Run Stable Diffusion
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
run: BENCHMARK_LOG=stable_diffusion ASSERT_MIN_STEP_TIME=550 AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing
- name: Run SDXL
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: BENCHMARK_LOG=stable_diffusion_xl ASSERT_MIN_STEP_TIME=3200 CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing
- name: Run LLaMA 7B
run: |
BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
BENCHMARK_LOG=llama AMD=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
BENCHMARK_LOG=llama_nojit AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=llama AMD=1 JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA 7B with BEAM
run: BENCHMARK_LOG=llama_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
run: BENCHMARK_LOG=llama_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 4 GPUs
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing
# - name: Run LLaMA 7B on 6 GPUs
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run LLaMA-3 8B BEAM
run: BENCHMARK_LOG=llama3_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
run: BENCHMARK_LOG=llama3_beam AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: BENCHMARK_LOG=llama3_beam_4gpu AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
run: BENCHMARK_LOG=llama3_beam_4gpu AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
# - name: Run LLaMA-3 8B on 6 GPUs
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0
#- name: Restore amdgpu
# run: sudo modprobe amdgpu
# - name: Run LLaMA-2 70B
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run Mixtral 8x7B
run: time BENCHMARK_LOG=mixtral AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
run: time BENCHMARK_LOG=mixtral AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing
- name: Run GPT2
run: |
BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
BENCHMARK_LOG=gpt2_nojit AMD=1 JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
BENCHMARK_LOG=gpt2 AMD=1 JIT=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF
run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: BENCHMARK_LOG=gpt2_half AMD=1 HALF=1 ASSERT_MIN_STEP_TIME=5 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run GPT2 w HALF/BEAM
run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD)
path: |
onnx_inference_speed.csv
torch_speed.txt
llama_unjitted.txt
llama_jitted.txt
llama_beam.txt
llama3_beam.txt
llama3_four_gpu.txt
llama3_six_gpu.txt
llama_2_70B.txt
gpt2_unjitted.txt
gpt2_jitted.txt
gpt2_half.txt
gpt2_half_beam.txt
matmul.txt
matmul_amd.txt
sd.txt
sdxl.txt
mixtral.txt
run: BENCHMARK_LOG=gpt2_half_beam AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -543,31 +461,20 @@ jobs:
- name: reset process replay
run: test/external/process_replay/reset.py
- name: Train MNIST
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py
- name: Run 10 CIFAR training steps
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
run: BENCHMARK_LOG=cifar_10steps ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 python3 examples/hlb_cifar10.py
- name: Run 10 CIFAR training steps w HALF
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: BENCHMARK_LOG=cifar_10steps_half ASSERT_MIN_STEP_TIME=200 AMD=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
# - name: Run 10 CIFAR training steps w BF16
# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
# run: BENCHMARK_LOG=cifar_10steps_bf16 ASSERT_MIN_STEP_TIME=288 AMD=1 STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py
# TODO: too slow
# - name: Run 10 CIFAR training steps w winograd
# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
# run: BENCHMARK_LOG=cifar_10steps_half_wino ASSERT_MIN_STEP_TIME=66 AMD=1 WINO=1 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_one_gpu.txt
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run full CIFAR training steps w 6 GPUS
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee train_cifar_six_gpu.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD Training)
path: |
beautiful_mnist.txt
train_cifar.txt
train_cifar_half.txt
train_cifar_bf16.txt
train_cifar_wino.txt
train_cifar_one_gpu.txt
train_cifar_six_gpu.txt
run: time BENCHMARK_LOG=cifar_6gpu AMD=1 DEFAULT_FLOAT=HALF STEPS=350 BS=1536 GPUS=6 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -606,19 +513,12 @@ jobs:
- name: Run MLPerf resnet eval
run: time BENCHMARK_LOG=resnet_eval AMD=1 MODEL=resnet python3 examples/mlperf/model_eval.py
- name: Run 10 MLPerf ResNet50 training steps (1 gpu)
run: BENCHMARK_LOG=resnet_10steps AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet_one_gpu.txt
run: BENCHMARK_LOG=resnet_10steps AMD=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf ResNet50 training steps (6 gpu)
run: BENCHMARK_LOG=resnet_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py | tee train_resnet.txt
run: BENCHMARK_LOG=resnet_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=1536 GPUS=6 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf Bert training steps (6 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee train_bert.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD MLPerf)
path: |
train_resnet.txt
train_resnet_one_gpu.txt
train_bert.txt
run: BENCHMARK_LOG=bert_10steps_6gpu AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=72 GPUS=6 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -708,7 +608,7 @@ jobs:
# AMD=1 AMD_LLVM=1 python3 test/test_linearizer.py test/opt/test_tensor_cores.py
# AMD=1 SHOULD_USE_TC=1 BFLOAT16=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
- name: Run Tensor Core GEMM (AMD)
run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py | tee am_matmul_amd.txt
run: AMD=1 SHOULD_USE_TC=1 HALF=1 DEBUG=2 ATOL=2e-2 python3 extra/gemm/simple_matmul.py
- name: Test AMD=1
run: DEBUG=2 AMD=1 python -m pytest -rA test/test_tiny.py
- name: Test DISK copy time
@@ -718,20 +618,12 @@ jobs:
AMD=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyDefaulttoCPUJit
AMD=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyCPUtoDefaultJit
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee am_train_cifar_one_gpu.txt
run: time BENCHMARK_LOG=cifar AMD=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
# - name: Run 10 MLPerf ResNet50 training steps (1 gpu)
# run: BENCHMARK_LOG=resnet_10steps AMD=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee am_train_resnet_one_gpu.txt
# run: BENCHMARK_LOG=resnet_10steps AMD=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf Bert training steps (1 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee am_train_bert_one_gpu.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AM Driver)
path: |
am_matmul_amd.txt
am_train_cifar_one_gpu.txt
am_train_resnet_one_gpu.txt
am_train_bert_one_gpu.txt
run: BENCHMARK_LOG=bert_10steps AMD=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -778,21 +670,13 @@ jobs:
NV=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyDefaulttoCPUJit
NV=1 GRAPH_ONE_KERNEL=1 PYTHONPATH=. NSZ=8192 python3 test/speed/external_test_copy_speed.py TestCopySpeed.testCopyCPUtoDefaultJit
- name: Test LLAMA-3
run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --benchmark --temperature 0 | tee nv_llama3_beam.txt
run: BENCHMARK_LOG=llama3_beam NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --benchmark --temperature 0
- name: Run full CIFAR training w 1 GPU
run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py | tee nv_train_cifar_one_gpu.txt
run: time BENCHMARK_LOG=cifar NV=1 DEFAULT_FLOAT=HALF STEPS=1000 TARGET_EVAL_ACC_PCT=93.0 python3 examples/hlb_cifar10.py
- name: Run 10 MLPerf ResNet50 training steps (1 gpu)
run: BENCHMARK_LOG=resnet_10steps NV=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py | tee nv_train_resnet_one_gpu.txt
run: BENCHMARK_LOG=resnet_10steps NV=1 MNISTMOCK=1 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=256 GPUS=1 MODEL=resnet python3 examples/mlperf/model_train.py
- name: Run 10 MLPerf Bert training steps (1 gpu)
# TODO: remove BERT_LAYERS once scheduler is fast
run: BENCHMARK_LOG=bert_10steps NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py | tee nv_train_bert_one_gpu.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (NV Driver)
path: |
nv_llama3_beam.txt
nv_train_cifar_one_gpu.txt
nv_train_resnet_one_gpu.txt
nv_train_bert_one_gpu.txt
run: BENCHMARK_LOG=bert_10steps NV=1 CAPTURE_PROCESS_REPLAY=0 DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 GPUS=1 BERT_LAYERS=2 MODEL=bert python3 examples/mlperf/model_train.py
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View File

@@ -5,6 +5,7 @@ env:
CAPTURE_PROCESS_REPLAY: 1
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PYTHONPATH: ${{ github.workspace }}
IGNORE_OOB: 0
on:
push:
@@ -36,6 +37,8 @@ jobs:
name: Docs
runs-on: ubuntu-22.04
timeout-minutes: 10
env:
IGNORE_OOB: 1
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -215,7 +218,6 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 10
# TODO: run the pre-commit hook to replace a lot of this
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -230,13 +232,13 @@ jobs:
- name: Lint with ruff
run: |
pip3 install --upgrade --force-reinstall ruff==0.14.10
python3 -m ruff check .
pre-commit run ruff --all-files
python3 -m ruff check examples/mlperf/ --ignore E501
python3 -m ruff check extra/thunder/tiny/ --ignore E501 --ignore F841 --ignore E722
python3 -m ruff check extra/torch_backend/backend.py
- name: Run mypy
run: |
python -m mypy --strict-equality --lineprecision-report .
python -m mypy --lineprecision-report .
cat lineprecision.txt
- name: Run TYPED=1
run: TYPED=1 python -c "import tinygrad"
@@ -307,7 +309,7 @@ jobs:
deps: testing_unit
python-version: '3.14'
- name: Test SPEC=2
run: IGNORE_OOB=0 SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
run: SPEC=2 pytest --maxfail=10 -n auto --durations=30 --ignore=test/models --ignore test/test_custom_kernel.py --ignore test/unit/test_hashing.py --timeout 60 -k "not test_setitem_big" --splits 2 --group ${{ matrix.group }}
fuzzing:
name: Fuzzing
@@ -470,6 +472,8 @@ jobs:
name: Test LLM
runs-on: ubuntu-24.04
timeout-minutes: 15
env:
IGNORE_OOB: 1
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -584,8 +588,6 @@ jobs:
name: Linux (WebGPU)
runs-on: ubuntu-22.04
timeout-minutes: 20
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -822,7 +824,6 @@ jobs:
NV_PTX: 1
NV: 1
FORWARD_ONLY: 1
IGNORE_OOB: 0
run: |
python3 -m pytest -n=auto test/device/test_hcq.py test/test_tiny.py --durations=20
- name: Run process replay tests
@@ -832,8 +833,6 @@ jobs:
name: MacOS (WebGPU)
runs-on: macos-14
timeout-minutes: 10
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4
@@ -911,8 +910,6 @@ jobs:
name: Windows (${{ matrix.backend }})
runs-on: windows-latest
timeout-minutes: 15
env:
IGNORE_OOB: 0
steps:
- name: Checkout Code
uses: actions/checkout@v4

View File

@@ -16,7 +16,7 @@ repos:
pass_filenames: false
- id: mypy
name: mypy
entry: python3 -m mypy tinygrad/ --strict-equality
entry: python3 -m mypy
language: system
always_run: true
pass_filenames: false

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,48 +1,56 @@
# library for RDNA3 assembly DSL
# mypy: ignore-errors
from __future__ import annotations
import struct, math
import struct, math, re
from enum import IntEnum
from functools import cache, cached_property
from typing import overload, Annotated, TypeVar, Generic
from extra.assembly.amd.autogen.rdna3.enum import (VOP1Op, VOP2Op, VOP3Op, VOP3SDOp, VOP3POp, VOPCOp, VOPDOp, SOP1Op, SOP2Op,
SOPCOp, SOPKOp, SOPPOp, SMEMOp, DSOp, FLATOp, MUBUFOp, MTBUFOp, MIMGOp, VINTERPOp)
# Common masks and bit conversion functions
MASK32, MASK64 = 0xffffffff, 0xffffffffffffffff
def _f32(i): return struct.unpack("<f", struct.pack("<I", i & MASK32))[0]
_struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
_struct_e, _struct_H = struct.Struct("<e"), struct.Struct("<H")
_struct_d, _struct_Q = struct.Struct("<d"), struct.Struct("<Q")
def _f32(i): return _struct_f.unpack(_struct_I.pack(i & MASK32))[0]
def _i32(f):
if isinstance(f, int): f = float(f)
if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
try: return struct.unpack("<I", struct.pack("<f", f))[0]
try: return _struct_I.unpack(_struct_f.pack(f))[0]
except (OverflowError, struct.error): return 0x7f800000 if f > 0 else 0xff800000
def _sext(v, b): return v - (1 << b) if v & (1 << (b - 1)) else v
def _f16(i): return struct.unpack("<e", struct.pack("<H", i & 0xffff))[0]
def _f16(i): return _struct_e.unpack(_struct_H.pack(i & 0xffff))[0]
def _i16(f):
if math.isnan(f): return 0x7e00
if math.isinf(f): return 0x7c00 if f > 0 else 0xfc00
try: return struct.unpack("<H", struct.pack("<e", f))[0]
try: return _struct_H.unpack(_struct_e.pack(f))[0]
except (OverflowError, struct.error): return 0x7c00 if f > 0 else 0xfc00
def _f64(i): return struct.unpack("<d", struct.pack("<Q", i & MASK64))[0]
def _f64(i): return _struct_d.unpack(_struct_Q.pack(i & MASK64))[0]
def _i64(f):
if math.isnan(f): return 0x7ff8000000000000
if math.isinf(f): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
try: return struct.unpack("<Q", struct.pack("<d", f))[0]
try: return _struct_Q.unpack(_struct_d.pack(f))[0]
except (OverflowError, struct.error): return 0x7ff0000000000000 if f > 0 else 0xfff0000000000000
# Instruction spec - register counts and dtypes derived from instruction names
import re
_REGS = {'B32': 1, 'B64': 2, 'B96': 3, 'B128': 4, 'B256': 8, 'B512': 16,
'F32': 1, 'I32': 1, 'U32': 1, 'F64': 2, 'I64': 2, 'U64': 2,
'F16': 1, 'I16': 1, 'U16': 1, 'B16': 1, 'I8': 1, 'U8': 1, 'B8': 1}
_CVT_RE = re.compile(r'CVT_([FIUB]\d+)_([FIUB]\d+)$')
_MAD_MUL_RE = re.compile(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$')
_PACK_RE = re.compile(r'PACK_([FIUB]\d+)_([FIUB]\d+)$')
_DST_SRC_RE = re.compile(r'_([FIUB]\d+)_([FIUB]\d+)$')
_SINGLE_RE = re.compile(r'_([FIUB](?:32|64|16|8|96|128|256|512))$')
@cache
def _suffix(name: str) -> tuple[str | None, str | None]:
name = name.upper()
if m := re.search(r'CVT_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'(?:MAD|MUL)_([IU]\d+)_([IU]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'PACK_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
# Generic dst_src pattern: S_BCNT0_I32_B64, S_BITREPLICATE_B64_B32, V_FREXP_EXP_I32_F64, etc.
if m := re.search(r'_([FIUB]\d+)_([FIUB]\d+)$', name): return m.group(1), m.group(2)
if m := re.search(r'_([FIUB](?:32|64|16|8|96|128|256|512))$', name): return m.group(1), m.group(1)
if m := _CVT_RE.search(name): return m.group(1), m.group(2)
if m := _MAD_MUL_RE.search(name): return m.group(1), m.group(2)
if m := _PACK_RE.search(name): return m.group(1), m.group(2)
if m := _DST_SRC_RE.search(name): return m.group(1), m.group(2)
if m := _SINGLE_RE.search(name): return m.group(1), m.group(1)
return None, None
_SPECIAL_REGS = {
'V_LSHLREV_B64': (2, 1, 2, 1), 'V_LSHRREV_B64': (2, 1, 2, 1), 'V_ASHRREV_I64': (2, 1, 2, 1),
@@ -75,27 +83,33 @@ _SPECIAL_DTYPE = {
# RDNA4 CVT_PK_F32 instructions: source is 8-bit packed as 16-bit operand
'V_CVT_PK_F32_BF8': ('F32', 'B16', None, None), 'V_CVT_PK_F32_FP8': ('F32', 'B16', None, None),
}
@cache
def spec_regs(name: str) -> tuple[int, int, int, int]:
name = name.upper()
if name in _SPECIAL_REGS: return _SPECIAL_REGS[name]
if 'SAD' in name and 'U8' in name and 'QSAD' not in name and 'MQSAD' not in name: return 1, 1, 1, 1
uname = name.upper()
if uname in _SPECIAL_REGS: return _SPECIAL_REGS[uname]
if 'SAD' in uname and 'U8' in uname and 'QSAD' not in uname and 'MQSAD' not in uname: return 1, 1, 1, 1
dst_suf, src_suf = _suffix(name)
return _REGS.get(dst_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1), _REGS.get(src_suf, 1)
@cache
def spec_dtype(name: str) -> tuple[str | None, str | None, str | None, str | None]:
name = name.upper()
if name in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[name]
if 'SAD' in name and ('U8' in name or 'U16' in name) and 'QSAD' not in name and 'MQSAD' not in name: return 'U32', 'U32', 'U32', 'U32'
if '_CMP_' in name or '_CMPX_' in name:
uname = name.upper()
if uname in _SPECIAL_DTYPE: return _SPECIAL_DTYPE[uname]
if 'SAD' in uname and ('U8' in uname or 'U16' in uname) and 'QSAD' not in uname and 'MQSAD' not in uname: return 'U32', 'U32', 'U32', 'U32'
if '_CMP_' in uname or '_CMPX_' in uname:
dst_suf, src_suf = _suffix(name)
return 'EXEC' if '_CMPX_' in name else 'VCC', src_suf, src_suf, None
return 'EXEC' if '_CMPX_' in uname else 'VCC', src_suf, src_suf, None
dst_suf, src_suf = _suffix(name)
return dst_suf, src_suf, src_suf, src_suf
_F16_RE = re.compile(r'_[FIUB]16(?:_|$)')
_F64_RE = re.compile(r'_[FIUB]64(?:_|$)')
@cache
def spec_is_16bit(name: str) -> bool:
name = name.upper()
if 'SAD' in name or 'PACK' in name or '_PK_' in name or 'SAT_PK' in name or 'DOT2' in name: return False
if '_F32' in name or '_I32' in name or '_U32' in name or '_B32' in name: return False # mixed ops like V_DOT2ACC_F32_F16
return bool(re.search(r'_[FIUB]16(?:_|$)', name))
def spec_is_64bit(name: str) -> bool: return bool(re.search(r'_[FIUB]64(?:_|$)', name.upper()))
uname = name.upper()
if 'SAD' in uname or 'PACK' in uname or '_PK_' in uname or 'SAT_PK' in uname or 'DOT2' in uname: return False
if '_F32' in uname or '_I32' in uname or '_U32' in uname or '_B32' in uname: return False
return bool(_F16_RE.search(uname))
@cache
def spec_is_64bit(name: str) -> bool: return bool(_F64_RE.search(name.upper()))
_3SRC = {'FMA', 'MAD', 'MIN3', 'MAX3', 'MED3', 'DIV_FIX', 'DIV_FMAS', 'DIV_SCALE', 'SAD', 'LERP', 'ALIGN', 'CUBE', 'BFE', 'BFI',
'PERM_B32', 'PERMLANE', 'CNDMASK', 'XOR3', 'OR3', 'ADD3', 'LSHL_OR', 'AND_OR', 'LSHL_ADD', 'ADD_LSHL', 'XAD', 'MAXMIN',
'MINMAX', 'MAXIMUMMINIMUM', 'MINIMUMMAXIMUM', 'MAXIMUM3', 'MINIMUM3', 'DOT2', 'DOT4', 'DOT8', 'WMMA', 'CVT_PK_U8', 'MULLIT', 'CO_CI'}
@@ -535,21 +549,25 @@ class Inst:
assert cls_name in self._enum_map, f"no enum map for {cls_name}"
return self._enum_map[cls_name](val)
@property
@cached_property
def op_name(self) -> str:
op = self.op
return op.name if hasattr(op, 'name') else ''
def dst_regs(self) -> int: return spec_regs(self.op_name)[0]
def src_regs(self, n: int) -> int: return spec_regs(self.op_name)[n + 1]
@cached_property
def _spec_regs(self) -> tuple[int, int, int, int]: return spec_regs(self.op_name)
@cached_property
def _spec_dtype(self) -> tuple[str | None, str | None, str | None, str | None]: return spec_dtype(self.op_name)
def dst_regs(self) -> int: return self._spec_regs[0]
def src_regs(self, n: int) -> int: return self._spec_regs[n + 1]
def num_srcs(self) -> int: return spec_num_srcs(self.op_name)
def dst_dtype(self) -> str | None: return spec_dtype(self.op_name)[0]
def src_dtype(self, n: int) -> str | None: return spec_dtype(self.op_name)[n + 1]
def is_src_16(self, n: int) -> bool: return self.src_regs(n) == 1 and is_dtype_16(self.src_dtype(n))
def is_src_64(self, n: int) -> bool: return self.src_regs(n) == 2
def dst_dtype(self) -> str | None: return self._spec_dtype[0]
def src_dtype(self, n: int) -> str | None: return self._spec_dtype[n + 1]
def is_src_16(self, n: int) -> bool: return self._spec_regs[n + 1] == 1 and is_dtype_16(self._spec_dtype[n + 1])
def is_src_64(self, n: int) -> bool: return self._spec_regs[n + 1] == 2
def is_16bit(self) -> bool: return spec_is_16bit(self.op_name)
def is_64bit(self) -> bool: return spec_is_64bit(self.op_name)
def is_dst_16(self) -> bool: return self.dst_regs() == 1 and is_dtype_16(self.dst_dtype())
def is_dst_16(self) -> bool: return self._spec_regs[0] == 1 and is_dtype_16(self._spec_dtype[0])
class Inst32(Inst): pass
class Inst64(Inst): pass

View File

@@ -1,8 +1,9 @@
# RDNA3 emulator - executes compiled pseudocode from AMD ISA PDF
# mypy: ignore-errors
from __future__ import annotations
import ctypes, struct
from extra.assembly.amd.dsl import Inst, RawImm, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
import ctypes
from extra.assembly.amd.dsl import Inst, unwrap, FLOAT_ENC, MASK32, MASK64, _f32, _i32, _sext, _f16, _i16, _f64, _i64
from extra.assembly.amd.pcode import Reg
from extra.assembly.amd.asm import detect_format
from extra.assembly.amd.autogen.rdna3.gen_pcode import get_compiled_functions
from extra.assembly.amd.autogen.rdna3.ins import (SOP1, SOP2, SOPC, SOPK, SOPP, SMEM, VOP1, VOP2, VOP3, VOP3SD, VOP3P, VOPC, DS, FLAT, VOPD,
@@ -178,24 +179,21 @@ def exec_scalar(st: WaveState, inst: Inst) -> int:
s0 = st.rsrc64(ssrc0, 0) if inst.is_src_64(0) else (st.rsrc(ssrc0, 0) if not isinstance(inst, (SOPK, SOPP)) else (st.rsgpr(inst.sdst) if isinstance(inst, SOPK) else 0))
s1 = st.rsrc64(inst.ssrc1, 0) if inst.is_src_64(1) else (st.rsrc(inst.ssrc1, 0) if isinstance(inst, (SOP2, SOPC)) else inst.simm16 if isinstance(inst, SOPK) else 0)
d0 = st.rsgpr64(sdst) if inst.dst_regs() == 2 and sdst is not None else (st.rsgpr(sdst) if sdst is not None else 0)
exec_mask = st.exec_mask
literal = inst.simm16 if isinstance(inst, (SOPK, SOPP)) else st.literal
# Execute compiled function - pass PC in bytes for instructions that need it
# For wave32, mask VCC and EXEC to 32 bits since only the lower 32 bits are relevant
pc_bytes = st.pc * 4
vcc32, exec32 = st.vcc & MASK32, exec_mask & MASK32
result = fn(s0, s1, 0, d0, st.scc, vcc32, 0, exec32, literal, None, {}, pc=pc_bytes)
# Create Reg objects for compiled function - mask VCC/EXEC to 32 bits for wave32
result = fn(Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc & MASK32), 0, Reg(st.exec_mask & MASK32), literal, None, PC=Reg(st.pc * 4))
# Apply results
if sdst is not None:
(st.wsgpr64 if result.get('d0_64') else st.wsgpr)(sdst, result['d0'])
if 'scc' in result: st.scc = result['scc']
if 'exec' in result: st.exec_mask = result['exec']
if 'new_pc' in result:
# Apply results - extract values from returned Reg objects
if sdst is not None and 'D0' in result:
(st.wsgpr64 if inst.dst_regs() == 2 else st.wsgpr)(sdst, result['D0']._val)
if 'SCC' in result: st.scc = result['SCC']._val & 1
if 'EXEC' in result: st.exec_mask = result['EXEC']._val
if 'PC' in result:
# Convert absolute byte address to word delta
# new_pc is where we want to go, st.pc is current position, inst._words will be added after
new_pc_words = result['new_pc'] // 4
pc_val = result['PC']._val
new_pc = pc_val if pc_val < 0x8000000000000000 else pc_val - 0x10000000000000000
new_pc_words = new_pc // 4
return new_pc_words - st.pc - 1 # -1 because emulator adds inst_words (1 for scalar)
return 0
@@ -260,24 +258,25 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
vdsty = (inst.vdsty << 1) | ((inst.vdstx & 1) ^ 1)
inputs = [(inst.opx, st.rsrc(inst.srcx0, lane), V[inst.vsrcx1], V[inst.vdstx], inst.vdstx),
(inst.opy, st.rsrc(inst.srcy0, lane), V[inst.vsrcy1], V[vdsty], vdsty)]
results = [(dst, fn(s0, s1, 0, d0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'])
for vopd_op, s0, s1, d0, dst in inputs if (op := _VOPD_TO_VOP.get(vopd_op)) and (fn := compiled.get(type(op), {}).get(op))]
for dst, val in results: V[dst] = val
def exec_vopd(vopd_op, s0, s1, d0):
op = _VOPD_TO_VOP[vopd_op]
return compiled[type(op)][op](Reg(s0), Reg(s1), None, Reg(d0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)['D0']._val
for vopd_op, s0, s1, d0, dst in inputs: V[dst] = exec_vopd(vopd_op, s0, s1, d0)
return
# VOP3SD: has extra scalar dest for carry output
if isinstance(inst, VOP3SD):
fn = compiled.get(VOP3SDOp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
fn = compiled[VOP3SDOp][inst.op]
# Read sources based on register counts from inst properties
def rsrc_n(src, regs): return st.rsrc64(src, lane) if regs == 2 else st.rsrc(src, lane)
s0, s1, s2 = rsrc_n(inst.src0, inst.src_regs(0)), rsrc_n(inst.src1, inst.src_regs(1)), rsrc_n(inst.src2, inst.src_regs(2))
# Carry-in ops use src2 as carry bitmask instead of VCC
vcc = st.rsgpr64(inst.src2) if 'CO_CI' in inst.op_name else st.vcc
result = fn(s0, s1, s2, V[inst.vdst], st.scc, vcc, lane, st.exec_mask, st.literal, None, {})
V[inst.vdst] = result['d0'] & MASK32
if result.get('d0_64'): V[inst.vdst + 1] = (result['d0'] >> 32) & MASK32
if result.get('vcc_lane') is not None: st.pend_sgpr_lane(inst.sdst, lane, result['vcc_lane'])
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(V[inst.vdst]), Reg(st.scc), Reg(vcc), lane, Reg(st.exec_mask), st.literal, None)
d0_val = result['D0']._val
V[inst.vdst] = d0_val & MASK32
if inst.dst_regs() == 2: V[inst.vdst + 1] = (d0_val >> 32) & MASK32
if 'VCC' in result: st.pend_sgpr_lane(inst.sdst, lane, (result['VCC']._val >> lane) & 1)
return
# Get op enum and sources (None means "no source" for that operand)
@@ -317,8 +316,7 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
if abs_ & (1<<i): srcs[i] = abs(srcs[i])
if neg & (1<<i): srcs[i] = -srcs[i]
result = srcs[0] * srcs[1] + srcs[2]
V = st.vgpr[lane]
V[inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
st.vgpr[lane][inst.vdst] = _i32(result) if inst.op == VOP3POp.V_FMA_MIX_F32 else _dst16(V[inst.vdst], _i16(result), inst.op == VOP3POp.V_FMA_MIXHI_F16)
return
# VOP3P packed ops: opsel selects halves for lo, opsel_hi for hi; neg toggles f16 sign
raws = [st.rsrc_f16(inst.src0, lane), st.rsrc_f16(inst.src1, lane), st.rsrc_f16(inst.src2, lane) if inst.src2 is not None else 0]
@@ -327,15 +325,13 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
hi_sels = [opsel_hi & 1, opsel_hi & 2, opsel_hi2]
srcs = [((_src16(raws[i], hi_sels[i]) ^ (0x8000 if neg_hi & (1<<i) else 0)) << 16) |
(_src16(raws[i], opsel & (1<<i)) ^ (0x8000 if neg & (1<<i) else 0)) for i in range(3)]
fn = compiled.get(VOP3POp, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op.name} not in pseudocode")
st.vgpr[lane][inst.vdst] = fn(srcs[0], srcs[1], srcs[2], 0, st.scc, st.vcc, lane, st.exec_mask, st.literal, None, {})['d0'] & MASK32
result = compiled[VOP3POp][inst.op](Reg(srcs[0]), Reg(srcs[1]), Reg(srcs[2]), Reg(0), Reg(st.scc), Reg(st.vcc), lane, Reg(st.exec_mask), st.literal, None)
st.vgpr[lane][inst.vdst] = result['D0']._val & MASK32
return
else: raise NotImplementedError(f"Unknown vector type {type(inst)}")
op_cls = type(inst.op)
fn = compiled.get(op_cls, {}).get(inst.op)
if fn is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
if (fn := compiled.get(op_cls, {}).get(inst.op)) is None: raise NotImplementedError(f"{inst.op_name} not in pseudocode")
# Read sources (with VOP3 modifiers if applicable)
neg, abs_ = (getattr(inst, 'neg', 0), getattr(inst, 'abs', 0)) if isinstance(inst, VOP3) else (0, 0)
@@ -377,24 +373,27 @@ def exec_vector(st: WaveState, inst: Inst, lane: int, lds: bytearray | None = No
# Execute compiled function - pass src0_idx and vdst_idx for lane instructions
# For VGPR access: src0 index is the VGPR number (src0 - 256 if VGPR, else src0 for SGPR)
src0_idx = (src0 - 256) if src0 is not None and src0 >= 256 else (src0 if src0 is not None else 0)
result = fn(s0, s1, s2, d0, st.scc, vcc_for_fn, lane, st.exec_mask, st.literal, st.vgpr, {}, src0_idx, vdst)
result = fn(Reg(s0), Reg(s1), Reg(s2), Reg(d0), Reg(st.scc), Reg(vcc_for_fn), lane, Reg(st.exec_mask), st.literal, st.vgpr, src0_idx, vdst)
# Apply results
# Apply results - extract values from returned Reg objects
if 'vgpr_write' in result:
# Lane instruction wrote to VGPR: (lane, vgpr_idx, value)
wr_lane, wr_idx, wr_val = result['vgpr_write']
st.vgpr[wr_lane][wr_idx] = wr_val
if 'vcc_lane' in result:
if 'VCC' in result:
# VOP2 carry ops write to VCC implicitly; VOPC/VOP3 write to vdst
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, result['vcc_lane'])
if 'exec_lane' in result:
# V_CMPX instructions write to EXEC per-lane
st.pend_sgpr_lane(EXEC_LO, lane, result['exec_lane'])
if 'd0' in result and op_cls is not VOPCOp and 'vgpr_write' not in result:
st.pend_sgpr_lane(VCC_LO if isinstance(inst, VOP2) and 'CO_CI' in inst.op_name else vdst, lane, (result['VCC']._val >> lane) & 1)
if 'EXEC' in result:
# V_CMPX instructions write to EXEC per-lane (not to vdst)
st.pend_sgpr_lane(EXEC_LO, lane, (result['EXEC']._val >> lane) & 1)
elif op_cls is VOPCOp:
# VOPC comparison result stored in D0 bitmask, extract lane bit (non-CMPX only)
st.pend_sgpr_lane(vdst, lane, (result['D0']._val >> lane) & 1)
if op_cls is not VOPCOp and 'vgpr_write' not in result:
writes_to_sgpr = 'READFIRSTLANE' in inst.op_name or 'READLANE' in inst.op_name
d0_val = result['d0']
d0_val = result['D0']._val
if writes_to_sgpr: st.wsgpr(vdst, d0_val & MASK32)
elif result.get('d0_64'): V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.dst_regs() == 2: V[vdst], V[vdst + 1] = d0_val & MASK32, (d0_val >> 32) & MASK32
elif inst.is_dst_16(): V[vdst] = _dst16(V[vdst], d0_val, bool(opsel & 8) if isinstance(inst, VOP3) else dst_hi)
else: V[vdst] = d0_val & MASK32

View File

@@ -43,7 +43,10 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
'vscnt', 'vmcnt', 'expcnt', 'lgkmcnt',
'CVT_OFF_TABLE', 'ThreadMask',
'S1[i', 'C.i32', 'S[i]', 'in[',
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST'] # Malformed pseudocode from PDF
'if n.', 'DST.u32', 'addrd = DST', 'addr = DST',
'BARRIER_STATE', 'ReallocVgprs',
'GPR_IDX', 'VSKIP', 'specified in', 'TTBL',
'fp6', 'bf6'] # Malformed pseudocode from PDF
# ═══════════════════════════════════════════════════════════════════════════════
# COMPILER: pseudocode -> Python (minimal transforms)
@@ -51,6 +54,7 @@ UNSUPPORTED = ['SGPR[', 'V_SWAP', 'eval ', 'FATAL_HALT', 'HW_REGISTERS',
def compile_pseudocode(pseudocode: str) -> str:
"""Compile pseudocode to Python. Transforms are minimal - most syntax just works."""
pseudocode = re.sub(r'\bpass\b', 'pass_', pseudocode) # 'pass' is Python keyword
raw_lines = pseudocode.strip().split('\n')
joined_lines: list[str] = []
for line in raw_lines:
@@ -113,7 +117,7 @@ def compile_pseudocode(pseudocode: str) -> str:
break
else:
lhs, rhs = line.split('=', 1)
lhs_s, rhs_s = lhs.strip(), rhs.strip()
lhs_s, rhs_s = _expr(lhs.strip()), rhs.strip()
stmt = _assign(lhs_s, _expr(rhs_s))
if in_first_match_loop and rhs_s == 'i' and (lhs_s == 'tmp' or lhs_s == 'D0.i32'):
stmt += "; break"
@@ -556,52 +560,57 @@ def _apply_pseudocode_fixes(op, code: str) -> str:
def _generate_function(cls_name: str, op, pc: str, code: str) -> tuple[str, str]:
"""Generate a single compiled pseudocode function."""
is_64 = any(p in pc for p in ['D0.u64', 'D0.b64', 'D0.f64', 'D0.i64', 'D1.u64', 'D1.b64', 'D1.f64', 'D1.i64'])
has_d1 = '{ D1' in pc
if has_d1: is_64 = True
is_cmp = (cls_name in ('VOPCOp', 'VOP3Op')) and 'D0.u64[laneId]' in pc
is_cmpx = (cls_name in ('VOPCOp', 'VOP3Op')) and 'EXEC.u64[laneId]' in pc
is_div_scale = 'DIV_SCALE' in op.name
has_sdst = cls_name == 'VOP3SDOp' and ('VCC.u64[laneId]' in pc or is_div_scale)
has_pc = 'PC' in pc
combined = code + pc
fn_name = f"_{cls_name}_{op.name}"
lines = [f"def {fn_name}(s0, s1, s2, d0, scc, vcc, lane, exec_mask, literal, VGPR, _vars, src0_idx=0, vdst_idx=0, pc=0):"]
for pc_line in pc.split('\n'): lines.append(f" # {pc_line}")
# Function accepts Reg objects directly (uppercase names), laneId is passed directly as int
lines = [f"def {fn_name}(S0, S1, S2, D0, SCC, VCC, laneId, EXEC, literal, VGPR, src0_idx=0, vdst_idx=0, PC=None):"]
regs = [('S0', 'Reg(s0)'), ('S1', 'Reg(s1)'), ('S2', 'Reg(s2)'),
('D0', 'Reg(s0)' if is_div_scale else 'Reg(d0)'), ('D1', 'Reg(0)'),
('SCC', 'Reg(scc)'), ('VCC', 'Reg(vcc)'), ('EXEC', 'Reg(exec_mask)'),
('tmp', 'Reg(0)'), ('saveexec', 'Reg(exec_mask)'), ('laneId', 'lane'),
('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)'), ('PC', 'Reg(pc)')]
used = {name for name, _ in regs if name in combined}
if 'EXEC_LO' in combined or 'EXEC_HI' in combined: used.add('EXEC')
if 'VCCZ' in combined: used.add('VCC')
if 'EXECZ' in combined: used.add('EXEC')
for name, init in regs:
if name in used: lines.append(f" {name} = {init}")
if 'EXEC_LO' in combined: lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in combined: lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in combined: lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in combined: lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
lines.append(" # --- compiled pseudocode ---")
for line in code.split('\n'): lines.append(f" {line}")
lines.append(" # --- end pseudocode ---")
d0_val, scc_val = ("D0._val" if 'D0' in used else "d0"), ("SCC._val & 1" if 'SCC' in used else "scc & 1")
lines.append(f" result = {{'d0': {d0_val}, 'scc': {scc_val}}}")
if has_sdst: lines.append(" result['vcc_lane'] = (VCC._val >> lane) & 1")
elif 'VCC' in used: lines.append(" if VCC._val != vcc: result['vcc_lane'] = (VCC._val >> lane) & 1")
if is_cmpx: lines.append(" result['exec_lane'] = (EXEC._val >> lane) & 1")
elif 'EXEC' in used: lines.append(" if EXEC._val != exec_mask: result['exec'] = EXEC._val")
if is_cmp: lines.append(" result['vcc_lane'] = (D0._val >> lane) & 1")
if is_64: lines.append(" result['d0_64'] = True")
if has_d1: lines.append(" result['d1'] = D1._val & 1")
if has_pc:
lines.append(" _pc = PC._val if PC._val < 0x8000000000000000 else PC._val - 0x10000000000000000")
lines.append(" result['new_pc'] = _pc")
lines.append(" return result\n")
# Registers that need special handling (not passed directly)
# Only init if used but not first assigned as `name = Reg(...)` in the compiled code
def needs_init(name): return name in combined and not re.search(rf'^\s*{name}\s*=\s*Reg\(', code, re.MULTILINE)
special_regs = [('D1', 'Reg(0)'), ('SIMM16', 'Reg(literal)'), ('SIMM32', 'Reg(literal)'),
('SRC0', 'Reg(src0_idx)'), ('VDST', 'Reg(vdst_idx)')]
if needs_init('tmp'): special_regs.insert(0, ('tmp', 'Reg(0)'))
if needs_init('saveexec'): special_regs.insert(0, ('saveexec', 'Reg(EXEC._val)'))
used = {name for name, _ in special_regs if name in combined}
# Detect which registers are modified (not just read) - look for assignments
modifies_d0 = is_div_scale or bool(re.search(r'\bD0\b[.\[]', combined))
modifies_exec = is_cmpx or bool(re.search(r'EXEC\.(u32|u64|b32|b64)\s*=', combined))
modifies_vcc = has_sdst or bool(re.search(r'VCC\.(u32|u64|b32|b64)\s*=|VCC\.u64\[laneId\]\s*=', combined))
modifies_scc = bool(re.search(r'\bSCC\s*=', combined))
modifies_pc = bool(re.search(r'\bPC\s*=', combined))
# Build init code for special registers
init_lines = []
if is_div_scale: init_lines.append(" D0 = Reg(S0._val)")
for name, init in special_regs:
if name in used: init_lines.append(f" {name} = {init}")
if 'EXEC_LO' in code: init_lines.append(" EXEC_LO = SliceProxy(EXEC, 31, 0)")
if 'EXEC_HI' in code: init_lines.append(" EXEC_HI = SliceProxy(EXEC, 63, 32)")
if 'VCCZ' in code and not re.search(r'^\s*VCCZ\s*=', code, re.MULTILINE): init_lines.append(" VCCZ = Reg(1 if VCC._val == 0 else 0)")
if 'EXECZ' in code and not re.search(r'^\s*EXECZ\s*=', code, re.MULTILINE): init_lines.append(" EXECZ = Reg(1 if EXEC._val == 0 else 0)")
code_lines = [line for line in code.split('\n') if line.strip()]
if init_lines:
lines.extend(init_lines)
if code_lines: lines.append(" # --- compiled pseudocode ---")
for line in code_lines:
lines.append(f" {line}")
# Build result dict - only include registers that are modified
result_items = []
if modifies_d0: result_items.append("'D0': D0")
if modifies_scc: result_items.append("'SCC': SCC")
if modifies_vcc: result_items.append("'VCC': VCC")
if modifies_exec: result_items.append("'EXEC': EXEC")
if has_d1: result_items.append("'D1': D1")
if modifies_pc: result_items.append("'PC': PC")
lines.append(f" return {{{', '.join(result_items)}}}\n")
return fn_name, '\n'.join(lines)
# ═══════════════════════════════════════════════════════════════════════════════

View File

@@ -229,17 +229,18 @@ class TestPseudocodeRegressions(unittest.TestCase):
"""Regression tests for pseudocode instruction emulation bugs."""
def test_v_div_scale_f32_vcc_always_returned(self):
"""V_DIV_SCALE_F32 must always return vcc_lane, even when VCC=0 (no scaling needed).
Bug: when VCC._val == vcc (both 0), vcc_lane wasn't returned, so VCC bits weren't written.
"""V_DIV_SCALE_F32 must always return VCC, even when VCC=0 (no scaling needed).
Bug: when VCC._val == vcc (both 0), VCC wasn't returned, so VCC bits weren't written.
This caused division to produce wrong results for multiple lanes."""
# Normal case: 1.0 / 3.0, no scaling needed, VCC should be 0
s0 = 0x3f800000 # 1.0
s1 = 0x40400000 # 3.0
s2 = 0x3f800000 # 1.0 (numerator)
result = _VOP3SDOp_V_DIV_SCALE_F32(s0, s1, s2, 0, 0, 0, 0, 0xffffffff, 0, None, {})
# Must always have vcc_lane in result
self.assertIn('vcc_lane', result, "V_DIV_SCALE_F32 must always return vcc_lane")
self.assertEqual(result['vcc_lane'], 0, "vcc_lane should be 0 when no scaling needed")
S0 = Reg(0x3f800000) # 1.0
S1 = Reg(0x40400000) # 3.0
S2 = Reg(0x3f800000) # 1.0 (numerator)
D0, SCC, VCC, EXEC = Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOP3SDOp_V_DIV_SCALE_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
# Must always have VCC in result
self.assertIn('VCC', result, "V_DIV_SCALE_F32 must always return VCC")
self.assertEqual(result['VCC']._val & 1, 0, "VCC lane 0 should be 0 when no scaling needed")
def test_v_cmp_class_f32_detects_quiet_nan(self):
"""V_CMP_CLASS_F32 must correctly identify quiet NaN vs signaling NaN.
@@ -248,18 +249,22 @@ class TestPseudocodeRegressions(unittest.TestCase):
signal_nan = 0x7f800001 # signaling NaN: exponent=255, bit22=0
# Test quiet NaN detection (bit 1 in mask)
s1_quiet = 0b0000000010 # bit 1 = quiet NaN
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 1, "Should detect quiet NaN with quiet NaN mask")
S0, S1, S2, D0, SCC, VCC, EXEC = Reg(quiet_nan), Reg(s1_quiet), Reg(0), Reg(0), Reg(0), Reg(0), Reg(0xffffffff)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect quiet NaN with quiet NaN mask")
# Test signaling NaN detection (bit 0 in mask)
s1_signal = 0b0000000001 # bit 0 = signaling NaN
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 1, "Should detect signaling NaN with signaling NaN mask")
S0, S1 = Reg(signal_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 1, "Should detect signaling NaN with signaling NaN mask")
# Test that quiet NaN doesn't match signaling NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(quiet_nan, s1_signal, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 0, "Quiet NaN should not match signaling NaN mask")
S0, S1 = Reg(quiet_nan), Reg(s1_signal)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Quiet NaN should not match signaling NaN mask")
# Test that signaling NaN doesn't match quiet NaN mask
result = _VOPCOp_V_CMP_CLASS_F32(signal_nan, s1_quiet, 0, 0, 0, 0, 0, 0xffffffff, 0, None, {})
self.assertEqual(result['vcc_lane'], 0, "Signaling NaN should not match quiet NaN mask")
S0, S1 = Reg(signal_nan), Reg(s1_quiet)
result = _VOPCOp_V_CMP_CLASS_F32(S0, S1, S2, D0, SCC, VCC, 0, EXEC, 0, None)
self.assertEqual(result['D0']._val & 1, 0, "Signaling NaN should not match quiet NaN mask")
def test_isnan_with_typed_view(self):
"""_isnan must work with TypedView objects, not just Python floats.

View File

@@ -1,6 +1,7 @@
// ** global buffers
s_load_dwordx2 s[28:29], s[0:1], 0x0 // C
s_load_dwordx4 s[32:35], s[0:1], 0x8 // A, B
s_load_dwordx2 s[34:35], s[0:1], 0x08 // A
s_load_dwordx2 s[32:33], s[0:1], 0x10 // B
// ** others kernel args
s_load_dword s24, s[0:1], 0x18 // N
s_load_dword s54, s[0:1], 0x1C // num work groups

View File

@@ -52,7 +52,7 @@ def get_asm_prg() -> ProgramSpec:
lib = Device[Device.DEFAULT].compiler.compile(src)
return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(B).uop.buffer, from_torch(A).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
prg=CompiledRunner(get_asm_prg())))
with Context(DEBUG=2):

View File

@@ -1,12 +1,12 @@
# unpack the complete kernel descriptor of an amdgpu ELF of for gfx950
# unpack the complete kernel descriptor of an amdgpu ELF
# https://rocm.docs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPUUsage.html#code-object-v3-kernel-descriptor
import struct, pathlib
import struct, pathlib, sys
from tinygrad.runtime.support.elf import elf_loader
def bits(x, lo, hi):
  """Return the bit-field of ``x`` covering bit positions ``lo`` through ``hi``, both ends inclusive."""
  width = hi - lo + 1
  mask = (1 << width) - 1
  # shift the field down to bit 0, then drop everything above it
  return (x >> lo) & mask
def assert_zero(x, lo, hi):
  """Assert that bits ``lo`` through ``hi`` (inclusive) of ``x`` are all zero.

  Raises AssertionError naming the bit range and the non-zero value found, so a failing
  reserved-field check can be diagnosed without re-running under a debugger.
  """
  field = bits(x, lo, hi)
  assert field == 0, f"expected bits {lo}..{hi} to be zero, got 0x{field:x}"
with open(fp:=pathlib.Path(__file__).parent/"lib", "rb") as f:
with open(sys.argv[1], "rb") as f:
lib = f.read()
image, sections, relocs = elf_loader(lib)
@@ -49,7 +49,7 @@ print("COMPUTE_PGM_RSRC3: 0x%08x" % pgm_rsrc3)
print("COMPUTE_PGM_RSRC1: 0x%08x" % pgm_rsrc1)
print("COMPUTE_PGM_RSRC2: 0x%08x" % pgm_rsrc2)
# rsrc 3
# rsrc 3 (gfx950)
accum_offset_raw = bits(pgm_rsrc3, 0, 5)
assert_zero(pgm_rsrc3, 6, 15)
@@ -169,10 +169,10 @@ assert_zero(desc, 458, 459)
uses_dynamic_stack = bits(desc, 459, 460)
print("DESC.USES_DYNAMIC_STACK:", uses_dynamic_stack)
# gfx950 only
assert_zero(desc, 460, 463)
kernarg_preload_spec_length = bits(desc, 464, 470)
print("DESC.KERNARG_PRELOAD_SPEC_LENGTH:", kernarg_preload_spec_length)
kernarg_preload_spec_offset = bits(desc, 471, 479)
print("DESC.KERNARG_PRELOAD_SPEC_OFFSET:", kernarg_preload_spec_offset)

View File

@@ -1,9 +1,13 @@
import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing, CI, Context
from tinygrad import Variable, Tensor
from tinygrad.helpers import BEAM, Timing, CI, prod
from tinygrad import Variable, Device, Tensor
from tinygrad.nn import Conv2d
from tinygrad.uop.ops import AxisType
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.opt.postrange import Scheduler
from tinygrad.codegen.opt.search import get_kernel_actions
def rand(*shape):
return Tensor(np.random.rand(*shape).astype(np.float32))
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
a = (a + a) * a
a.realize()
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tc_up(self):
  """With a tensor-core opt applied, get_kernel_actions must still yield at least one UPCAST/UNROLL action
  when max_up is set to the product of the axes that are already upcasted or unrolled."""
  renderer = Device[Device.DEFAULT].renderer
  tc = renderer.tensor_cores[0]
  size = 8 * max(tc.dims[0], tc.dims[1])
  a = Tensor.rand(size, size, dtype=tc.dtype_in)
  b = Tensor.rand(size, size, dtype=tc.dtype_in)
  sched = Scheduler(a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast, renderer)
  sched.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
  up_axes = (AxisType.UPCAST, AxisType.UNROLL)
  up_ops = (OptOps.UPCAST, OptOps.UNROLL)
  # total size of the axes that are upcasted/unrolled after the TC opt
  up = prod([dim for dim, axis_type in zip(sched.full_shape, sched.axis_types) if axis_type in up_axes])
  actions = get_kernel_actions(sched, include_0=False, max_up=int(up))
  upcasted = []
  for candidate in actions.values():
    if any(opt.op in up_ops for opt in candidate.applied_opts): upcasted.append(candidate)
  assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"
def test_max_up(self):
  """Every action containing an UPCAST/UNROLL opt must have no opt arg above max_up and at least one at or below it."""
  ast = Tensor.rand(16, 16).schedule()[-1].ast
  sched = Scheduler(ast, Device[Device.DEFAULT].renderer)
  up_ops = (OptOps.UPCAST, OptOps.UNROLL)
  for max_up in (2, 4):
    actions = get_kernel_actions(sched, include_0=False, max_up=max_up)
    # only the candidates that actually upcast or unroll are relevant to the limit
    relevant = [cand.applied_opts for cand in actions.values() if any(o.op in up_ops for o in cand.applied_opts)]
    for opts in relevant:
      over_limit = [o for o in opts if o.arg > max_up]
      assert len(over_limit) == 0 and len([o for o in opts if o.arg <= max_up]) > 0
if __name__ == '__main__':
unittest.main()

View File

@@ -3,18 +3,21 @@
import numpy as np
import unittest
import subprocess, struct, math, textwrap
import subprocess, struct, math, textwrap, functools
from tinygrad import Tensor, dtypes, Device, UOp
from tinygrad.uop.ops import Ops
from tinygrad.uop.ops import Ops, KernelInfo
from tinygrad.helpers import getenv
from tinygrad.runtime.support.compiler_amd import amdgpu_disassemble
from tinygrad.renderer import ProgramSpec
from tinygrad.engine.realize import CompiledRunner
from extra.assembly.amd.autogen.rdna3.ins import *
from extra.assembly.amd.asm import waitcnt
from test.testextra.test_cfg_viz import template
def custom_src(out:UOp, src:str, device:str, n_threads:int=1, n_workgroups:int=1) -> UOp:
  """Build a PROGRAM UOp named "test" that carries the source string `src` for `device`.

  The sink holds `out` plus an "lidx0" special of size `n_threads` and a "gidx0" special of size `n_workgroups`;
  the LINEAR node lists the sink's sources followed by the sink itself.
  """
  local_idx = UOp.special(n_threads, "lidx0")
  global_idx = UOp.special(n_workgroups, "gidx0")
  sink = UOp.sink(out, local_idx, global_idx, arg=KernelInfo(name="test"))
  device_uop = UOp(Ops.DEVICE, arg=device)
  linear_uop = UOp(Ops.LINEAR, src=(*sink.src, sink))
  source_uop = UOp(Ops.SOURCE, arg=src)
  return UOp(Ops.PROGRAM, src=(sink, device_uop, linear_uop, source_uop))
def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
out = Tensor([0]*n_threads, dtype=dtypes.uint32).realize()
src = "\n".join(inst.disasm() for inst in [
@@ -26,11 +29,9 @@ def get_output(asm:list, n_threads:int=1, vdst:VGPR=v[1]):
global_store_b32(addr=v[0], data=vdst, saddr=s[0:1]),
s_endpgm()
])
prg = ProgramSpec("test", template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src)), Device.DEFAULT, UOp(Ops.SINK),
global_size=[1, 1, 1], local_size=[n_threads, 1, 1], globals=[0])
car = CompiledRunner(prg)
if getenv("PRINT_ASM"): amdgpu_disassemble(car.lib)
car([out.uop.buffer], {}, wait=True)
src = template.replace("fn_name", "test").replace("INSTRUCTION", textwrap.dedent(src))
out = Tensor.custom_kernel(out, fxn=functools.partial(custom_src, src=src, device=out.device, n_threads=n_threads))[0]
out.realize()
return out.tolist()
def f16_to_bits(x:float) -> int: return struct.unpack('<H', struct.pack('<e', x))[0]

View File

@@ -135,9 +135,14 @@ check_untyped_defs = true
explicit_package_bases = true
warn_unreachable = true
warn_redundant_casts = true
strict_equality = true
# NOTE: had to comment this out to make mypy pass on both CI and OSX
#warn_unused_ignores = true
[[tool.mypy.overrides]]
module = "extra.*"
follow_imports = "skip"
[tool.pytest.ini_options]
norecursedirs = [
"extra",

View File

@@ -1,74 +0,0 @@
#!/usr/bin/env python
import unittest
from tinygrad.device import Device, BufferSpec
from tinygrad.dtype import dtypes
@unittest.skipUnless(Device.DEFAULT == "QCOM", "QCOM device required to run")
class TestQcom(unittest.TestCase):
def test_image_pitch(self):
dev = Device["QCOM"]
def __validate(imgdt, expected_pitch):
img = dev.allocator.alloc(imgdt.shape[0] * imgdt.shape[1] * 16, options:=BufferSpec(image=imgdt))
pitch = img.texture_info.pitch
assert pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{pitch:X}, expected 0x{expected_pitch:X}"
dev.allocator.free(img, imgdt.shape[0] * imgdt.shape[1] * 16, options)
# Match opencl pitches for perf
__validate(dtypes.imageh((1, 201)), 0x680)
__validate(dtypes.imageh((16, 216)), 0x700)
__validate(dtypes.imageh((16, 9)), 0x80)
__validate(dtypes.imageh((48, 64)), 0x200)
__validate(dtypes.imageh((32, 128)), 0x400)
__validate(dtypes.imageh((96, 128)), 0x400)
__validate(dtypes.imageh((64, 256)), 0x840)
__validate(dtypes.imageh((64, 9)), 0x80)
__validate(dtypes.imageh((192, 256)), 0x840)
__validate(dtypes.imageh((64, 768)), 0x1840)
__validate(dtypes.imageh((256, 49)), 0x1C0)
__validate(dtypes.imageh((128, 9)), 0x80)
__validate(dtypes.imageh((16, 1024)), 0x2080)
__validate(dtypes.imageh((64, 512)), 0x1040)
__validate(dtypes.imageh((16, 512)), 0x1080)
__validate(dtypes.imageh((132, 64)), 0x200)
__validate(dtypes.imageh((4, 512)), 0x1200)
__validate(dtypes.imageh((8, 512)), 0x1100)
__validate(dtypes.imageh((128, 128)), 0x400)
__validate(dtypes.imageh((32, 512)), 0x1040)
__validate(dtypes.imageh((26, 64)), 0x200)
__validate(dtypes.imageh((32, 516)), 0x1040)
__validate(dtypes.imageh((32, 1024)), 0x2040)
__validate(dtypes.imageh((16, 2048)), 0x4080)
__validate(dtypes.imageh((8, 2048)), 0x4100)
__validate(dtypes.imageh((4, 4096)), 0x8200)
__validate(dtypes.imagef((16, 49)), 0x380)
__validate(dtypes.imagef((16, 1024)), 0x4080)
__validate(dtypes.imagef((256, 64)), 0x400)
__validate(dtypes.imagef((64, 512)), 0x2040)
__validate(dtypes.imagef((16, 512)), 0x2080)
__validate(dtypes.imagef((132, 64)), 0x400)
__validate(dtypes.imagef((4, 512)), 0x2200)
__validate(dtypes.imagef((4, 16)), 0x200)
__validate(dtypes.imagef((2, 16)), 0x400)
__validate(dtypes.imagef((8, 512)), 0x2100)
__validate(dtypes.imagef((12, 64)), 0x400)
__validate(dtypes.imagef((3, 32)), 0x400)
__validate(dtypes.imagef((128, 128)), 0x840)
__validate(dtypes.imagef((32, 512)), 0x2040)
__validate(dtypes.imagef((8, 3072)), 0xC100)
__validate(dtypes.imagef((4, 2048)), 0x8200)
__validate(dtypes.imagef((4, 1024)), 0x4200)
__validate(dtypes.imagef((4, 4096)), 0x10200)
__validate(dtypes.imagef((10, 384)), 0x1900)
__validate(dtypes.imagef((24, 64)), 0x400)
__validate(dtypes.imagef((128, 12)), 0xC0)
__validate(dtypes.imagef((10, 24)), 0x200)
__validate(dtypes.imagef((1, 129)), 0x840)
__validate(dtypes.imagef((1, 32)), 0x200)
__validate(dtypes.imagef((1, 64)), 0x400)
__validate(dtypes.imagef((1, 1239)), 0x4D80)
__validate(dtypes.imagef((1, 1)), 0x40)
if __name__ == "__main__":
unittest.main()

View File

@@ -44,6 +44,66 @@ class TestImageCopy(unittest.TestCase):
@unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
class TestImageDType(unittest.TestCase):
def test_image_pitch(self):
  """Check the `pitch` of image dtypes against a table of expected values, one entry per image shape."""
  # Match opencl pitches for perf
  half_cases = [
    ((1, 201), 0x680), ((16, 216), 0x700), ((16, 9), 0x80), ((48, 64), 0x200), ((32, 128), 0x400),
    ((96, 128), 0x400), ((64, 256), 0x840), ((64, 9), 0x80), ((192, 256), 0x840), ((64, 768), 0x1840),
    ((256, 49), 0x1C0), ((128, 9), 0x80), ((16, 1024), 0x2080), ((64, 512), 0x1040), ((16, 512), 0x1080),
    ((132, 64), 0x200), ((4, 512), 0x1200), ((8, 512), 0x1100), ((128, 128), 0x400), ((32, 512), 0x1040),
    ((26, 64), 0x200), ((32, 516), 0x1040), ((32, 1024), 0x2040), ((16, 2048), 0x4080), ((8, 2048), 0x4100),
    ((4, 4096), 0x8200),
  ]
  float_cases = [
    ((16, 49), 0x380), ((16, 1024), 0x4080), ((256, 64), 0x400), ((64, 512), 0x2040), ((16, 512), 0x2080),
    ((132, 64), 0x400), ((4, 512), 0x2200), ((4, 16), 0x200), ((2, 16), 0x400), ((8, 512), 0x2100),
    ((12, 64), 0x400), ((3, 32), 0x400), ((128, 128), 0x840), ((32, 512), 0x2040), ((8, 3072), 0xC100),
    ((4, 2048), 0x8200), ((4, 1024), 0x4200), ((4, 4096), 0x10200), ((10, 384), 0x1900), ((24, 64), 0x400),
    ((128, 12), 0xC0), ((10, 24), 0x200), ((1, 129), 0x840), ((1, 32), 0x200), ((1, 64), 0x400),
    ((1, 1239), 0x4D80), ((1, 1), 0x40),
  ]
  # imageh cases run first, then imagef, in table order; the first mismatch aborts the test
  for make_dtype, cases in ((dtypes.imageh, half_cases), (dtypes.imagef, float_cases)):
    for shape, expected_pitch in cases:
      imgdt = make_dtype(shape)
      assert imgdt.pitch == expected_pitch, f"Failed pitch for image: {imgdt}. Got 0x{imgdt.pitch:X}, expected 0x{expected_pitch:X}"
def test_image_and_back(self):
data = Tensor.randn(9*27*4).realize()
tst = data.numpy()

View File

@@ -288,17 +288,22 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=0)
def test_conv2d_ceildiv_edge_case(self):
v = Variable('v', 11, 50_000)
val = 39601
x = Tensor.randn(1, 22, 50_000)[:, :, :v.bind(val)]
weight = Tensor.randn(256, 22, 12)
# tests symbolic ceildiv in conv2d output shape calculation
# val=79 triggers the edge case where old ceildiv simplifies incorrectly: old gives floor=12, correct ceildiv=13
v = Variable('v', 11, 100)
val = 79
x_full = Tensor.randn(1, 8, 100)
weight = Tensor.randn(16, 8, 12)
result = x.conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
# symbolic version
result = x_full[:, :, :v.bind(val)].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
var_val = {v.expr: val}
shape = tuple(sym_infer(s, var_val) for s in result.shape)
with self.assertRaises(AssertionError):
self.assertEqual(shape, (1, 256, 6600)) # TODO: fails if ceildiv is incorrect
# TODO: test output is correct
self.assertEqual(shape, (1, 16, 13))
# concrete version for comparison
expected = x_full[:, :, :val].conv2d(weight=weight, groups=1, stride=6, dilation=1, padding=(3, 3))
np.testing.assert_allclose(result[:, :, :13].numpy(), expected.numpy(), atol=1e-5, rtol=1e-5)
if __name__ == '__main__':
unittest.main()

View File

@@ -1,117 +0,0 @@
#!/usr/bin/env python
import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import run_schedule
from tinygrad.uop.ops import UOp
from tinygrad.helpers import SPLIT_REDUCEOP
class TestTensorUOp(unittest.TestCase):
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
b = Tensor(a).uop
assert b.shape == a.shape
np.testing.assert_equal(a, Tensor(b).numpy())
for ndims in range(1, 4):
a = np.random.randn(*(4,)*ndims).astype(np.float32)
for stride in [-2, 1, 2]:
for start in [0, 1]:
helper(a[(slice(start, None, stride),)*ndims])
def test_shuffle_pad_ops_cmpeq(self):
y = Tensor([1]).cat(Tensor([1]) == 0).numpy()
z = Tensor([1, 0]).numpy()
np.testing.assert_allclose(y, z)
def test_shuffle_pad_ops_div(self):
y = Tensor([1]).cat(Tensor([1]).div(Tensor([2.0]))).numpy()
z = Tensor([1, 0.5]).numpy()
np.testing.assert_allclose(y, z)
def test_shuffle_pad_ops_log(self):
y = Tensor([1]).cat(Tensor([1]).log()).numpy()
z = Tensor([1, 0]).numpy()
np.testing.assert_allclose(y, z)
def test_shuffle_pad_ops_exp(self):
y = Tensor([1]).cat(Tensor([1]).exp()).numpy()
z = Tensor([1, np.e]).numpy()
np.testing.assert_allclose(y, z)
def test_device_0_is_the_same_device(self):
a = Tensor([1, 2, 3], f"{Device.DEFAULT}")
b = Tensor([1, 2, 3], f"{Device.DEFAULT}:0")
assert a.device == b.device
def test_shrink_const_into_zero(self):
# regression test to make sure the shapetracker is preserved
a = Tensor.zeros(4,4,4).shrink((None, (0,0), None))
b = Tensor.zeros(4,1,4)
c = a.cat(b, dim=1)
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))
def test_shrink_const_then_cast(self):
# regression test to make sure the shapetracker is preserved
a = Tensor.zeros(4,4,4).shrink((None, (0,0), None)).cast(dtypes.int32)
b = Tensor.zeros(4,1,4)
c = a.cat(b, dim=1)
np.testing.assert_allclose(c.numpy(), np.concatenate((a.numpy(), b.numpy()), axis=1))
def test_const_dtype(self):
lb: UOp = Tensor([1], dtype=dtypes.int).uop
assert lb.const_like(1).base.arg == 1
assert type(lb.const_like(1).base.arg) is int
lb: UOp = Tensor([1], dtype=dtypes.float).uop
assert lb.const_like(1).base.arg == 1.0
assert type(lb.const_like(1).base.arg) is float
def test_contiguous_alu(self):
a = Tensor.randn(2, 2).realize()
b = Tensor.randn(2, 2).realize()
add = (a+b).contiguous()
out = add+2
sched = out.schedule()
self.assertEqual(len(sched), 2)
run_schedule(sched)
np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2)
# NOTE: contiguous on a buffer collapses
@unittest.skip("contiguous on a buffer no longer collapses")
def test_contiguous_empty(self):
empty = Tensor.empty(1).contiguous()
sched = empty.schedule()
self.assertEqual(len(sched), 0)
def test_contiguous_folded_alu(self):
a = Tensor.empty(8, 8)
# NOTE: the buffer for mul_0 late folds to just a CONST
mul_0 = a*0
out = mul_0.shrink(((4, 8), (0, 8))).contiguous()
out.realize()
self.assertEqual(out.tolist(), Tensor.zeros(4, 8).tolist())
@unittest.skipUnless(SPLIT_REDUCEOP, "only for SPLIT_REDUCEOP")
class TestReduceOp(unittest.TestCase):
def test_no_split_reduce_kernel(self):
a = Tensor.rand(4, 4).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 1
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 2
if __name__ == "__main__":
unittest.main()

View File

@@ -73,8 +73,6 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(vv.bind(4), 7)
self.assertListEqual(ret[:3].tolist(), [4,5,6])
# TODO: add vmin/vmax pattern for symbolic denominator
@unittest.expectedFailure
def test_symbolic_arange_sym_step(self):
vv = Variable("step", 1, 3)
ret = Tensor.arange(0, 10, vv.bind(2))
@@ -86,6 +84,18 @@ class TestTensorVariable(unittest.TestCase):
ret = Tensor.arange(begin.bind(4), end.bind(7))
self.assertListEqual(ret[:3].tolist(), [4,5,6])
def test_symbolic_arange_three_vars(self):
  """arange with bound symbolic start, stop and step: the first four values must be 2, 5, 8, 11."""
  start = Variable("b", 0, 5).bind(2)
  stop = Variable("e", 10, 20).bind(14)
  stride = Variable("s", 1, 3).bind(3)
  values = Tensor.arange(start, stop, stride)[:4].tolist()
  self.assertListEqual(values, [2,5,8,11])
def test_symbolic_full(self):
  """Tensor.full with a bound Variable as fill value must produce that value in every element."""
  bound = Variable("x", 1, 10).bind(5)
  filled = Tensor.full((3,), bound).tolist()
  self.assertListEqual(filled, [5,5,5])
def test_variable_empty(self):
v = Variable("i", 1, 10)
# TODO: Tensor creation from unbound variable should assert

View File

@@ -101,6 +101,44 @@ class TestFromFuzzer(unittest.TestCase):
_test_value(0)
_test_value(0.0000009)
class TestFloat16Log2(unittest.TestCase):
  """log2 on float16 tensors under TRANSCENDENTAL=2 (native float16 implementation, no float32 cast)."""

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  def test_float16_log2_basic(self):
    # ordinary inputs, compared against numpy's log2 of the same float16 value
    with Context(TRANSCENDENTAL=2):
      for x in (1.0, 2.0, 4.0, 0.5, 0.25, 10.0, 100.0, 1000.0):
        got = Tensor([x], dtype=dtypes.float16).log2().numpy()[0]
        want = np.log2(np.float16(x))
        np.testing.assert_allclose(got, want, rtol=1e-3, err_msg=f"log2({x})")

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  @unittest.skipIf(Device.DEFAULT == "WEBGPU" and CI, "Nan handling differs on Vulkan")
  def test_float16_log2_special(self):
    # inf, zero, negative and nan inputs
    def log2_f16(x): return Tensor([x], dtype=dtypes.float16).log2().numpy()[0]
    with Context(TRANSCENDENTAL=2), np.errstate(all='ignore'):
      # log2(inf) = inf
      assert np.isinf(log2_f16(np.inf))
      # log2(0) = -inf
      assert log2_f16(0.0) == -np.inf
      # log2(negative) = nan
      assert np.isnan(log2_f16(-1.0))
      # log2(nan) = nan
      assert np.isnan(log2_f16(np.nan))

  @unittest.skipUnless(is_dtype_supported(dtypes.float16, Device.DEFAULT), f"no float16 on {Device.DEFAULT}")
  def test_float16_log2_denormal(self):
    # inputs near and below the float16 min normal (6.1e-5), which exercise the denormal path with 2^10 scaling
    with Context(TRANSCENDENTAL=2):
      for x in (1e-4, 6e-5, 1e-5):
        got = Tensor([x], dtype=dtypes.float16).log2().numpy()[0]
        want = np.log2(np.float16(x))
        # looser tolerance: denormals carry less precision in float16
        np.testing.assert_allclose(got, want, rtol=5e-2, err_msg=f"log2({x})")
class TestTranscendentalSchedule(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.ulong), "Needs ulong")
def test_transcendental_sin_fusion(self):

View File

@@ -478,143 +478,6 @@ class TestUOpGraph(unittest.TestCase):
for u in uops:
self.assertNotEqual(u.dtype, dtypes.long)
def test_in_out_of_bounds_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 0), ptr=True),))
to_uops_list([ld0])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 15), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 7), ptr=True),))
to_uops_list([ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 42), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 1, 10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(Variable("i", 0, 20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_gated_store(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), src=(), arg=0)
v = Variable("v", 0, 20)
st0 = UOp(Ops.STORE, dtypes.void, src=(glbl0.index(v.valid(v<16)), UOp.const(dtypes.int, 0)))
to_uops_list([st0])
st1 = UOp(Ops.STORE, dtypes.void, (glbl0.index(v.valid(v<20)), v))
with self.assertRaises(RuntimeError): to_uops_list([st1])
@unittest.skip("if not allowed in graph")
def test_in_bounds_access_gated_local(self):
with Context(IGNORE_OOB=0):
# Define buffers
gbuf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
sbuf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")
# Define indices, valids and barrier
gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
gate = (gidx<400) & (lidx<8)
local_store = UOp(Ops.STORE, dtypes.void, (sbuf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
barrier = UOp(Ops.BARRIER, dtypes.void, (local_store,))
if_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))
# Load from local memory (after the IF/barrier)
local_load = UOp(Ops.LOAD, dtypes.uint, (sbuf.index(lidx, ptr=True), if_barrier))
# Store to global memory
global_store = UOp(Ops.STORE, dtypes.void, (gbuf.index(gidx), local_load))
to_uops_list([global_store])
def test_load_with_float_in_index(self):
with Context(IGNORE_OOB=0):
ridx = UOp.range(20, 0)
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = (ridx.cast(dtypes.float)*0.68).trunc().cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid((0<=i)&(i<16)), ptr=True),))
to_uops_list([ld0])
glblfloat = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(20), (), 0)
ldfloat = UOp(Ops.LOAD, dtypes.float, (glblfloat.index(ridx),))
i = (ldfloat+3.14).cast(dtypes.int)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i, ((0<=i)&(i<16)), ptr=True),))
def test_load_cast_to_bool(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(ridx.valid(ridx.cast(dtypes.bool).logical_not()), ptr=True),))
to_uops_list([ld0])
@unittest.skip("Bool load is not supported yet")
def test_load_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
ridx = UOp.range(20, 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
to_uops_list([ld0])
def test_out_of_bounds_off_by_one_access(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(UOp.const(dtypes.int, 16), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_bounds_access_with_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid((5<gidx0)&(gidx0<16)), ptr=True),))
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<16), ptr=True),))
to_uops_list([ld0, ld1])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<17), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_symbolic_mask(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
i = Variable("i", 1, 80)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<10), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<15), ptr=True),))
to_uops_list([ld0])
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(i.valid(i<20), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld0])
def test_in_out_of_bounds_access_index_load(self):
with Context(IGNORE_OOB=0):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
gidx0 = UOp.range(42, 0, AxisType.GLOBAL)
ld0 = UOp(Ops.LOAD, dtypes.int, (glbl0.index(gidx0.valid(gidx0<8), ptr=True),)).cast(dtypes.index)
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<32)), ptr=True),))
to_uops_list([ld1])
ld1 = UOp(Ops.LOAD, dtypes.int, (glbl1.index((ld0*2).valid((ld0>=0)&(ld0<64)), ptr=True),))
with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_bounds_with_loaded_bool(self):
  """A bool loaded from memory used as the gate does not bound the index, so lowering must raise.

  gidx0 is built from the constant 16 while the int buffer has only 8 elements.
  """
  with Context(IGNORE_OOB=0):
    glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
    # the int buffer gets its own argument index (1) so the two globals do not share slot 0
    glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
    gidx0 = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
    ld0 = glbl0.index(gidx0, ptr=True).load()
    ld1 = glbl1.index(gidx0.valid(ld0), ptr=True).load()
    with self.assertRaises(RuntimeError): to_uops_list([ld1])
def test_fold_gated_load(self):
glbl0 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)

View File

@@ -2,6 +2,7 @@ import ctypes, gzip, unittest, timeit, pickle
from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar, argfix, colored, word_wrap, is_numpy_ndarray, mv_address, get_contraction, count
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.helpers import ceildiv
from tinygrad.tensor import Tensor, get_shape
import numpy as np
@@ -120,6 +121,25 @@ class TestRoundUp(unittest.TestCase):
self.assertEqual(round_up(232, 24984), 24984)
self.assertEqual(round_up(24984, 232), 25056)
class TestCeilDiv(unittest.TestCase):
  """ceildiv on plain ints and on symbolic Variables."""
  def test_int(self):
    # (numerator, divisor, expected ceiling)
    for num, amt, expected in ((10, 3, 4), (9, 3, 3), (0, 5, 0), (1, 5, 1)):
      self.assertEqual(ceildiv(num, amt), expected)

  def test_symbolic(self):
    # a non-negative symbolic numerator should go through the (num + amt - 1) // amt formula
    v = Variable('v', 0, 100)
    rendered = ceildiv(v, 6).render()
    self.assertEqual(rendered, "((v+5)//6)")

  def test_symbolic_negative_offset(self):
    # ceildiv(v-5, 6) is the shape used for conv2d output sizes
    # the old -(x//-y) form was incorrectly simplified to ((v+1)//6-1) here
    # with the new form (v-5+5)//6 collapses to v//6, which is correct
    v = Variable('v', 11, 100)
    rendered = ceildiv(v - 5, 6).render()
    self.assertEqual(rendered, "(v//6)")
class TestCount(unittest.TestCase):
def test_count_basic(self):
c = count(3)

View File

@@ -65,6 +65,30 @@ class TestLinAlg(unittest.TestCase):
orthogonality_helper(Q)
reconstruction_helper([Q,R],a)
def test_qr_zero_column(self):
  """QR of a matrix whose first column is all zeros must stay NaN-free and still reconstruct."""
  a = Tensor([[0.0, 1.0], [0.0, 2.0]]).realize()
  Q,R = a.qr()
  for factor in (Q, R):
    assert not np.isnan(factor.numpy()).any()
  orthogonality_helper(Q)
  reconstruction_helper([Q,R], a)
def test_svd_identity(self):
  """SVD of the 2x2 identity and of the 2x2 zero matrix: no NaNs, and U*diag(S)*V reconstructs."""
  for matrix in (Tensor.eye(2), Tensor.zeros(2, 2)):
    a = matrix.realize()
    U,S,V = a.svd()
    for factor in (U, S, V):
      assert not np.isnan(factor.numpy()).any()
    # spread the singular values onto a diagonal matrix
    s_diag = (S.unsqueeze(-2) * Tensor.eye(2))
    reconstruction_helper([U, s_diag, V], a)
def test_svd_rank1(self):
  """A rank-1 matrix has singular values [sqrt(10), 0] and still reconstructs."""
  a = Tensor([[1.0, 1.0], [2.0, 2.0]]).realize()
  U, S, V = a.svd()
  expected_singular_values = [np.sqrt(10), 0.0]
  np.testing.assert_allclose(S.numpy(), expected_singular_values, atol=1e-4, rtol=1e-4)
  s_diag = S.unsqueeze(-2) * Tensor.eye(2)
  reconstruction_helper([U, s_diag, V], a)
def test_newton_schulz(self):
coefficients = [(2, -1.5, 0.5), (2.0, -1.4, 0.2, 0.2)]#these params map to the sign function
sizes = [(2,2), (3,2), (2,3), (2,2,2)]

View File

@@ -1,13 +1,12 @@
import unittest
from tinygrad.tensor import Tensor
class TestMaskedShapeTracker(unittest.TestCase):
class TestMaskedTensor(unittest.TestCase):
def test_mul_masked(self):
  """Multiplying by a zero-padded tensor keeps the shape and zeroes the padded tail."""
  ones = Tensor([1,1,1,1,1])
  padded = Tensor([1,1]).pad(((0,3),))
  product = ones*padded
  assert product.shape == ones.shape
  values = product.data().tolist()
  assert values == [1.0, 1.0, 0.0, 0.0, 0.0]
@@ -16,7 +15,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
b = Tensor([1,1]).pad(((0,3),))
c = a*b
assert c.shape == a.shape
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [1.0, 1.0, 0.0, 0.0, 0.0]
@@ -24,7 +22,6 @@ class TestMaskedShapeTracker(unittest.TestCase):
a = Tensor([1,1]).pad(((0,2),))
b = Tensor([1,1]).pad(((0,2),))
c = a+b
#assert c.uop.st.views[0].mask is not None
ret = c.data()
assert ret.tolist() == [2.0, 2.0, 0.0, 0.0]

View File

@@ -128,6 +128,26 @@ class TestProgressBar(unittest.TestCase):
self._compare_bars(tinytqdm_output, tqdm_output)
if n > 5: break
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')
def test_si_boundary(self, mock_terminal_size, mock_stderr):
  """Test SI formatting at boundaries (e.g., 999.5 -> 1.00k, not 1000)"""
  ncols = 80
  terminal_size = namedtuple(field_names='columns', typename='terminal_size')
  mock_terminal_size.return_value = terminal_size(ncols)

  # rates straddling the 1000 boundary: 999 stays as "999", 999.5+ becomes "1.00k"
  for rate in [999, 999.4, 999.5, 1000, 1001]:
    # start every iteration with an empty captured stderr
    mock_stderr.truncate(0)
    mock_stderr.seek(0)

    elapsed = 1.0 / rate
    # perf_counter is consumed three times: init st, init update, final update
    with patch('time.perf_counter', side_effect=[0, 0, elapsed]):
      bar = tinytqdm(desc="Test", total=1, unit_scale=True, rate=10**9)
      bar.update(1, close=True)

    last_frame = mock_stderr.getvalue().split("\r")[-1].rstrip()
    reference = tqdm.format_meter(n=1, total=1, elapsed=elapsed, ncols=ncols, prefix="Test", unit_scale=True)
    self._compare_bars(last_frame, reference)
@unittest.skip("this is flaky")
@patch('sys.stderr', new_callable=StringIO)
@patch('shutil.get_terminal_size')

View File

@@ -0,0 +1,179 @@
import unittest
from tinygrad import dtypes, Variable
from tinygrad.dtype import AddrSpace
from tinygrad.helpers import Context
from tinygrad.uop.ops import Ops, UOp, AxisType
from test.test_uops import to_uops_list
class TestValidateOOB(unittest.TestCase):
  """Test z3 validation of index bounds for different ALU ops and patterns."""

  # basic index patterns
  def test_const_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      for ok_idx in (0, 15):  # first and last element are both valid
        to_uops_list([buf.index(UOp.const(dtypes.int, ok_idx), ptr=True).load(dtype=dtypes.int)])
      for bad_idx in (16, 42):  # off by one, then way out
        with self.assertRaises(RuntimeError):
          to_uops_list([buf.index(UOp.const(dtypes.int, bad_idx), ptr=True).load(dtype=dtypes.int)])

  def test_variable_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      to_uops_list([buf.index(Variable("i", 0, 15), ptr=True).load(dtype=dtypes.int)])  # valid
      for lo, hi in ((0, 20), (-5, 10)):  # upper bound too large, then negative lower bound
        with self.assertRaises(RuntimeError):
          to_uops_list([buf.index(Variable("i", lo, hi), ptr=True).load(dtype=dtypes.int)])

  def test_range_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ridx = UOp.range(42, 0, AxisType.GLOBAL)
      gated_ok = ridx.valid(ridx < 16)
      to_uops_list([buf.index(gated_ok, ptr=True).load(dtype=dtypes.int)])  # valid
      with self.assertRaises(RuntimeError):
        gated_oob = ridx.valid(ridx < 17)
        to_uops_list([buf.index(gated_oob, ptr=True).load(dtype=dtypes.int)])  # oob

  def test_variable_with_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", -5, 80)
      both_sides = v.valid((v >= 0) & (v < 16))
      to_uops_list([buf.index(both_sides, ptr=True).load(dtype=dtypes.int)])  # valid
      with self.assertRaises(RuntimeError):
        upper_only = v.valid(v < 20)
        to_uops_list([buf.index(upper_only, ptr=True).load(dtype=dtypes.int)])  # negative not masked

  def test_gated_store(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      v = Variable("v", 0, 20)
      store_ok = buf.index(v.valid(v < 16)).store(0)
      to_uops_list([store_ok])  # valid
      with self.assertRaises(RuntimeError):
        store_oob = buf.index(v.valid(v < 20)).store(0)
        to_uops_list([store_oob])  # oob

  # ALU ops in index
  def test_idiv(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      halved = UOp.range(32, 0, AxisType.GLOBAL) // 2
      to_uops_list([buf.index(halved, ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        halved_wide = UOp.range(34, 0, AxisType.GLOBAL) // 2
        to_uops_list([buf.index(halved_wide, ptr=True).load(dtype=dtypes.int)])  # 0..16 oob

  def test_mod(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ridx = UOp.range(100, 0, AxisType.GLOBAL)
      wrapped = ridx % 16
      to_uops_list([buf.index(wrapped, ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        wrapped_wide = ridx % 20
        to_uops_list([buf.index(wrapped_wide, ptr=True).load(dtype=dtypes.int)])  # 0..19 oob

  def test_shr(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      shifted = UOp.range(64, 0, AxisType.GLOBAL) >> 2
      to_uops_list([buf.index(shifted, ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        shifted_wide = UOp.range(128, 0, AxisType.GLOBAL) >> 2
        to_uops_list([buf.index(shifted_wide, ptr=True).load(dtype=dtypes.int)])  # 0..31 oob

  def test_shl(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 0)
      ridx = UOp.range(8, 0, AxisType.GLOBAL)
      scaled = ridx << 2
      to_uops_list([buf.index(scaled, ptr=True).load(dtype=dtypes.int)])  # 0..28 valid
      with self.assertRaises(RuntimeError):
        scaled_wide = ridx << 4
        to_uops_list([buf.index(scaled_wide, ptr=True).load(dtype=dtypes.int)])  # 0..112 oob

  def test_and(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ridx = UOp.range(100, 0, AxisType.GLOBAL)
      low_bits = ridx & 15
      to_uops_list([buf.index(low_bits, ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        low_bits_wide = ridx & 31
        to_uops_list([buf.index(low_bits_wide, ptr=True).load(dtype=dtypes.int)])  # 0..31 oob

  def test_max(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      clamped = Variable("v", -10, 15).maximum(0)
      to_uops_list([buf.index(clamped, ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        clamped_wide = Variable("v2", -10, 20).maximum(0)
        to_uops_list([buf.index(clamped_wide, ptr=True).load(dtype=dtypes.int)])  # 0..20 oob

  def test_xor_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ridx = UOp.range(32, 0, AxisType.GLOBAL)
      gate_ok = (ridx < 8) ^ ((ridx >= 8) & (ridx < 16))
      to_uops_list([buf.index(ridx.valid(gate_ok), ptr=True).load(dtype=dtypes.int)])  # 0..15 valid
      with self.assertRaises(RuntimeError):
        gate_oob = (ridx < 10) ^ (ridx >= 20)
        to_uops_list([buf.index(ridx.valid(gate_oob), ptr=True).load(dtype=dtypes.int)])  # 0..9,20..31 oob

  # cast patterns
  def test_float_cast_in_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      ridx = UOp.range(20, 0)
      # index goes int -> float -> scaled -> truncated -> int
      idx = (ridx.cast(dtypes.float) * 0.68).trunc().cast(dtypes.int)
      gated = idx.valid((idx >= 0) & (idx < 16))
      to_uops_list([buf.index(gated, ptr=True).load(dtype=dtypes.int)])

  def test_bool_cast_in_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(1), (), 0)
      ridx = UOp.range(20, 0)
      is_zero = ridx.cast(dtypes.bool).logical_not()
      to_uops_list([buf.index(ridx.valid(is_zero), ptr=True).load(dtype=dtypes.int)])  # only r=0 valid

  # load result as index/mask
  def test_load_as_index(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      src_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      dst_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(64), (), 1)
      ridx = UOp.range(42, 0, AxisType.GLOBAL)
      loaded = src_buf.index(ridx.valid(ridx < 8), ptr=True).load(dtype=dtypes.int).cast(dtypes.index)
      gated_ok = (loaded * 2).valid((loaded >= 0) & (loaded < 32))
      to_uops_list([dst_buf.index(gated_ok, ptr=True).load(dtype=dtypes.int)])  # valid
      with self.assertRaises(RuntimeError):
        gated_oob = (loaded * 2).valid((loaded >= 0) & (loaded < 64))
        to_uops_list([dst_buf.index(gated_oob, ptr=True).load(dtype=dtypes.int)])  # oob

  def test_load_bool_as_mask(self):
    with Context(IGNORE_OOB=0, SPEC=2):
      mask_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(8), (), 1)
      gidx = UOp(Ops.SPECIAL, dtypes.index, (UOp.const(dtypes.index, 16),), "gidx0")
      loaded_gate = mask_buf.index(gidx, ptr=True).load()
      with self.assertRaises(RuntimeError):
        to_uops_list([data_buf.index(gidx.valid(loaded_gate), ptr=True).load()])  # gidx 0..15, buf_int size 8

  # skipped tests (moved from test_uop_graph.py)
  @unittest.skip("if not allowed in graph")
  def test_in_bounds_access_gated_local(self):
    with Context(IGNORE_OOB=0):
      # a 400-element global buffer and an 8-element local buffer
      global_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.uint.ptr(400), (), 0)
      local_buf = UOp(Ops.DEFINE_LOCAL, dtypes.uint.ptr(8, addrspace=AddrSpace.LOCAL), (), "temp0")

      # launch indices, the combined gate, and the gated local store behind a barrier
      gidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 416),), "gidx0")
      lidx = UOp(Ops.SPECIAL, dtypes.int, (UOp.const(dtypes.int, 10),), "lidx0")
      gate = (gidx<400) & (lidx<8)
      store_local = UOp(Ops.STORE, dtypes.void, (local_buf.index(lidx, lidx<8), UOp.const(dtypes.uint, 1)))
      barrier = UOp(Ops.BARRIER, dtypes.void, (store_local,))
      gated_barrier = UOp(Ops.IF, dtypes.void, (gate, barrier))

      # read the local buffer back once the IF/barrier has run
      load_local = UOp(Ops.LOAD, dtypes.uint, (local_buf.index(lidx, ptr=True), gated_barrier))

      # write the value out to global memory
      store_global = UOp(Ops.STORE, dtypes.void, (global_buf.index(gidx), load_local))
      to_uops_list([store_global])

  @unittest.skip("Bool load is not supported yet")
  def test_load_mask(self):
    with Context(IGNORE_OOB=0):
      data_buf = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(16), (), 0)
      mask = UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(16), (), 0)
      ridx = UOp.range(20, 0)
      # NOTE(review): `ridx<16&mask` parses as `ridx < (16 & mask)` because & binds tighter than <,
      # and the parenthesised src below has no trailing comma so it is not a tuple - fix before unskipping
      masked_load = UOp(Ops.LOAD, dtypes.int, (data_buf.index(UOp.const(ridx, ridx<16&mask), ptr=True)))
      to_uops_list([masked_load])
if __name__ == "__main__":
unittest.main()

View File

@@ -166,8 +166,10 @@ def get_program(ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> Program
if ast.arg is None: ast = ast.replace(arg=KernelInfo())
# rewrite to prg
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
if ast.op is Ops.PROGRAM: prg = ast
else:
full_sink = full_rewrite_to_sink(ast, renderer, optimize=ast.tag is None)
prg = UOp(Ops.PROGRAM, src=(full_sink, UOp(Ops.DEVICE, arg=renderer.device)))
prg = graph_rewrite(prg, pm_to_program, ctx=renderer, name="linearize/render")
# create the ProgramSpec

View File

@@ -45,6 +45,7 @@ class Scheduler:
ret = Scheduler(self.ast, self.ren)
ret.dont_use_locals = self.dont_use_locals
ret.applied_opts = self.applied_opts[:]
if hasattr(self, 'tensor_core'): ret.tensor_core = self.tensor_core
return ret
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
@@ -307,6 +308,7 @@ class Scheduler:
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
self.ast = self.ast.substitute({reduceop: tc_uop})
self.tensor_core = tc
return axes
return None

View File

@@ -93,8 +93,8 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
# *** external API ***
# get dictionary of all possible actions
def get_kernel_actions(s:Scheduler, include_0=True) -> dict[int, Scheduler]:
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256) if max_up is None else max_up, getenv("BEAM_LOCAL_MAX", 1024)
kernel_actions = actions.copy()
for i,a in enumerate(kernel_actions):

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from typing import Final, ClassVar, Callable, Literal
import math, struct, ctypes, functools
from dataclasses import dataclass, fields
from tinygrad.helpers import getenv, prod
from tinygrad.helpers import getenv, prod, round_up, next_power2
from enum import Enum, auto
class InvalidTypeMetaClass(type):
@@ -101,6 +101,15 @@ class ImageDType(PtrDType):
assert addrspace == AddrSpace.GLOBAL, "images can't be local"
return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" + (f'.vec({self.v})' if self.v != 1 else '')
@property
def pitch(self):
  """Row pitch of the image.

  This is the row width `shape[1] * 4 * itemsize` rounded up to a multiple of `1 << pitchalign`,
  plus one extra `1 << pitchalign` for certain widths. The alignment exponent is at least 6 and
  grows for images with few rows.
  """
  imgw, imgh = self.shape[1], self.shape[0]
  itemsize_log = int(math.log2(self.itemsize))
  # single-row images use the minimum alignment; otherwise shorter images get a larger one
  if imgh > 1: pitchalign = max(6, 11 - int(math.log2(imgh)))
  else: pitchalign = 6
  if pitchalign == 6: align_up = max(1, (8 // itemsize_log + 1) - imgh // 32)
  else: align_up = 2 ** (pitchalign - itemsize_log - 2)
  granularity = 128 if self.itemsize == 4 else 256
  # widths within align_up of the next power of two / granularity multiple get one extra alignment unit
  near_boundary = min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw
  pitch_add = (1 << pitchalign) if near_boundary and imgw > granularity//2 else 0
  row_bytes = imgw * 4 * self.itemsize
  return round_up(row_bytes, 1 << pitchalign) + pitch_add
class dtypes:
@staticmethod

View File

@@ -125,7 +125,7 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat((Ops.SINK, Ops.PROGRAM), name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \

View File

@@ -38,18 +38,18 @@ def ansilen(s:str): return len(ansistrip(s))
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
def flatten(l:Iterable[Iterable[T]]):
  """Concatenate the items of every inner iterable into a single list (one level only)."""
  flat = []
  for sublist in l: flat.extend(sublist)
  return flat
def fully_flatten(l):
if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
if hasattr(l, "shape") and l.shape == (): return [l[()]]
flattened = []
for li in l: flattened.extend(fully_flatten(li))
return flattened
return [l]
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
def fromimport(mod, frm):
  """Import module `mod` and return its attribute `frm` (like `from mod import frm`)."""
  module = __import__(mod, fromlist=[frm])
  return getattr(module, frm)
def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fst[-1] == ')' and _is_balanced(fst[1:-1]) else fst
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def strip_parens(fst:str) -> str:
  """Drop one pair of enclosing parentheses when the text between them is itself balanced."""
  # slicing keeps the empty string safe: ''[:1] and ''[-1:] are both ''
  wrapped = fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1])
  return fst[1:-1] if wrapped else fst
def ceildiv(num, amt):
  """Ceiling division of `num` by `amt`.

  Plain numbers use `-(num // -amt)`; a float result is converted to int. When `num` exposes
  a non-negative `vmin` and `amt` is a positive int (or exposes a positive `vmin`), the
  result is built as `(num + amt - 1) // amt` instead.
  """
  # use (num + amt - 1) // amt when num is a UOp and non-negative to avoid C/Python division mismatch
  if hasattr(num, 'vmin') and num.vmin >= 0:
    # a divisor that is neither an int nor has a vmin (e.g. a float) falls through to the
    # generic formula below instead of raising AttributeError on `amt.vmin`
    amt_positive = amt > 0 if isinstance(amt, int) else getattr(amt, 'vmin', 0) > 0
    if amt_positive: return (num + amt - 1) // amt
  return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def round_up(num:int, amt:int) -> int:
  """Round `num` up to a multiple of `amt`."""
  chunks = (num+amt-1)//amt
  return chunks * amt
def round_down(num:int, amt:int) -> int:
  """Round `num` down to a multiple of `amt` by rounding its negation up and negating back."""
  negated_rounded = round_up(-num, amt)
  return -negated_rounded
def next_power2(x):
  """Smallest power of two that is >= `x`; 0 maps to 1."""
  if x == 0: return 1
  return 1 << (x - 1).bit_length()
# cstyle div and mod
def cdiv(x:int, y:int) -> int:
  """C-style integer division: truncates toward zero; dividing by zero yields 0."""
  if y == 0: return 0
  sign = -1 if x*y < 0 else 1
  return abs(x)//abs(y)*sign
def cmod(x:int, y:int) -> int:
  """C-style remainder: whatever is left of `x` after removing `cdiv(x, y)` multiples of `y`."""
  quotient = cdiv(x, y)
  return x - quotient*y
@@ -87,9 +87,7 @@ def word_wrap(x, wrap=80):
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
return x[:i] + "\n" + word_wrap(x[i:], wrap)
def pad_bytes(b:bytes, align:int) -> bytes:
  """Append zero bytes so the length becomes a multiple of `align` (nothing added if it already is)."""
  missing = (align - (len(b) % align)) % align
  return b + b'\x00' * missing
def panic(e:Exception|None=None):
if e is None: raise RuntimeError("PANIC!")
raise e
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")
@functools.cache
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:
@@ -149,9 +147,7 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str:
  """POSIX-style path of `x` inside the system temp directory, optionally suffixed with `.<username>`."""
  name = f"{x}.{getpass.getuser()}" if append_user else x
  return (pathlib.Path(tempfile.gettempdir()) / name).as_posix()
def stderr_log(msg):
sys.stderr.write(msg)
sys.stderr.flush()
def stderr_log(msg:str):
  """Write `msg` to stderr with no trailing newline and flush right away."""
  print(msg, file=sys.stderr, end='', flush=True)
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs
@@ -512,7 +508,9 @@ class tqdm(Generic[T]):
if elapsed and self.i/elapsed > self.rate and self.i: self.skip = max(int(self.i/elapsed)//self.rate,1)
def HMS(t): return ':'.join(f'{x:02d}' if i else str(x) for i,x in enumerate([int(t)//3600,int(t)%3600//60,int(t)%60]) if i or x)
def SI(x):
return (f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)].strip()) if x else '0.00'
if not x: return '0.00'
v = f"{x/1000**int(g:=round(math.log(x,1000),6)):.{int(3-3*math.fmod(g,1))}f}"[:4].rstrip('.')
return (f"{x/1000**(int(g)+1):.3f}"[:4].rstrip('.')+' kMGTPEZY'[int(g)+1]) if v == "1000" else v+' kMGTPEZY'[int(g)].strip()
prog_text = f'{SI(self.n)}{f"/{SI(self.t)}" if self.t else self.unit}' if self.unit_scale else f'{self.n}{f"/{self.t}" if self.t else self.unit}'
est_text = f'<{HMS(elapsed/prog-elapsed) if self.n else "?"}' if self.t else ''
it_text = (SI(self.n/elapsed) if self.unit_scale else f"{self.n/elapsed:5.2f}") if self.n else "?"

View File

@@ -133,7 +133,7 @@ string_rewrite = PatternMatcher([
(UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"),
(UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"),
(UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))),
(UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier),
(UPat(Ops.BARRIER), lambda ctx: ctx.barrier),
(UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"),
])
@@ -180,7 +180,7 @@ class PTXRenderer(Renderer):
self.uops = uops
def ssa(prefix:str, u:UOp|None=None, dtype:str|None=None) -> str:
nonlocal c, r
nonlocal c
prefix += f"_{dtype if dtype is not None else self.types[unwrap(u).dtype.base]}_"
c[prefix] += 1
return f"%{prefix}{c[prefix]-1}"
@@ -230,7 +230,7 @@ class PTXRenderer(Renderer):
[ssa("wmma_acc", dtype="b32") for _ in range(0, len(r[u.src[2]]), 4 // u.dtype.scalar().itemsize)]]
r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)]
prefix, dtype = {Ops.CAST: ("cast", None), Ops.BITCAST: ("cast", None), Ops.END: ("pred", "pred"), Ops.RANGE: ("ridx", None),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local",self.types[dtypes.ulong]),
Ops.DEFINE_VAR: ("dat", None), Ops.CONST: ("const", None), Ops.DEFINE_LOCAL: ("local", self.types[dtypes.ulong]),
Ops.DEFINE_GLOBAL: ("dat", self.types[dtypes.ulong]), **{op: ("alu", None) for op in GroupOp.ALU}}.get(u.op, (None, None))
if prefix: r[u] = ssa(prefix, u, dtype)

View File

@@ -22,7 +22,7 @@ class HCQGraph(MultiGraphRunner):
for (j,i), input_idx in self.input_replace.items():
x = self.input_replace_to_var.setdefault((j,i), UOp.variable(f"input_{input_idx}", 0, 0xffffffffffffffff, dtype=dtypes.uint64))
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, texture_info=self.hcq_bufs[j][i].texture_info) # Create fake buffer with variable
self.hcq_bufs[j][i] = HCQBuffer(x, self.hcq_bufs[j][i].size, image=self.hcq_bufs[j][i].image) # Create fake buffer with variable
# Allocate kernel args.
kernargs_size: dict[Compiled, int] = collections.defaultdict(int)

View File

@@ -834,11 +834,11 @@ class PCIIface(PCIIfaceBase):
if queue_type == kfd.KFD_IOC_QUEUE_TYPE_SDMA:
assert idx <= 3, "only 4 SDMA queues supported in am"
pv = self.dev_impl.sdma.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0 + idx * 0xA * 4), pipe=0, queue=idx)
pv, doorbell_index = self.dev_impl.sdma.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr,
wptr_addr=gart.va_addr+wptr, pipe=0, queue=idx)
else:
pv = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr, wptr_addr=gart.va_addr+wptr,
eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, doorbell=(doorbell_index:=am.AMDGPU_NAVI10_DOORBELL_MEC_RING0), pipe=0,
pv, doorbell_index = self.dev_impl.gfx.setup_ring(ring_addr=ring.va_addr, ring_size=ring.size, rptr_addr=gart.va_addr+rptr,
wptr_addr=gart.va_addr+wptr, eop_addr=eop_buffer.va_addr, eop_size=eop_buffer.size, pipe=0,
queue=int(is_aql:=(queue_type==kfd.KFD_IOC_QUEUE_TYPE_COMPUTE_AQL)), aql=is_aql)
return AMDQueueDesc(ring=ring.cpu_view().view(fmt='I'), doorbells=[self.dev_impl.doorbell64.view(doorbell_index * 8, 8, fmt='Q')],

View File

@@ -10,7 +10,7 @@ from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer
from tinygrad.renderer.nir import IR3Renderer
from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport, cpu_profile, lo32, PROFILE, suppress_finalizing
from tinygrad.helpers import flatten, QCOM_IR3, QCOM_CC
from tinygrad.helpers import next_power2, flatten, QCOM_IR3, QCOM_CC
from tinygrad.runtime.support.system import System
if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
@@ -25,7 +25,7 @@ def _qreg_exec(__reg, __val=0, **kwargs):
return __val
qreg: Any = type("QREG", (object,), {name[4:].lower(): functools.partial(_qreg_exec, name) for name in mesa.__dict__.keys() if name[:4] == 'REG_'})
def next_power2(x): return 1 if x == 0 else 1 << (x - 1).bit_length()
def ctz(v):
  """Count trailing zero bits of `v`; returns -1 for 0."""
  lowest_set_bit = v & -v  # isolates the least-significant 1 bit
  return lowest_set_bit.bit_length() - 1
def parity(val: int):
for i in range(4,1,-1): val ^= val >> (1 << i)
@@ -192,9 +192,8 @@ class QCOMArgsState(HCQArgsState):
super().__init__(buf, prg, bufs, vals=vals)
ctypes.memset(cast(int, self.buf.va_addr), 0, prg.kernargs_alloc_size)
ubos, uavs = [b for b in bufs if b.texture_info is None], [b for b in bufs if b.texture_info is not None]
ubos, uavs = [b for b in bufs if b.image is None], [b for b in bufs if b.image is not None]
ibos, texs = (uavs, []) if prg.tex_cnt == 0 else (uavs[:-prg.tex_cnt], uavs[-prg.tex_cnt:])
for cnst_val,cnst_off,cnst_sz in prg.consts_info: to_mv(self.buf.va_addr + cnst_off, cnst_sz)[:] = cnst_val.to_bytes(cnst_sz, byteorder='little')
if prg.samp_cnt > 0: to_mv(self.buf.va_addr + prg.samp_off, len(prg.samplers) * 4).cast('I')[:] = array.array('I', prg.samplers)
@@ -205,8 +204,15 @@ class QCOMArgsState(HCQArgsState):
for i, b in enumerate(ubos): self.bind_sints_to_buf(b.va_addr, buf=self.buf, fmt='Q', offset=prg.buf_offs[i])
for i, v in enumerate(vals): self.bind_sints_to_buf(v, buf=self.buf, fmt='I', offset=prg.buf_offs[i+len(ubos)])
self.bind_sints_to_buf(*flatten([b.texture_info.desc + ([0] * 8) for b in texs]), buf=self.buf, fmt='I', offset=prg.tex_off)
self.bind_sints_to_buf(*flatten([b.texture_info.ibo + ([0] * 8) for b in ibos]), buf=self.buf, fmt='I', offset=prg.ibo_off)
# Build the list of descriptor words for image buffer `b`, derived from b.image and b.va_addr.
# With ibo=True the first word carries only the format; otherwise it also sets the 0x8 field and an
# identity x/y/z/w swizzle. The list is zero-padded at the end.
# NOTE(review): presumably 16 words in total if data64_le yields two words - confirm.
def _tex(b, ibo=False):
# 4-byte items select the 32-bit float RGBA format, anything else the 16-bit float one
fmt = mesa.FMT6_32_32_32_32_FLOAT if b.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
return [qreg.a6xx_tex_const_0(fmt=fmt) if ibo else qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=fmt),
qreg.a6xx_tex_const_1(width=b.image.shape[1], height=b.image.shape[0]),
qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=b.image.pitch, pitchalign=ctz(b.image.pitch)-6), 0, *data64_le(b.va_addr),
qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13), 0, 0, 0, 0, 0, 0, 0, 0]
self.bind_sints_to_buf(*flatten(map(_tex, texs)), buf=self.buf, fmt='I', offset=prg.tex_off)
self.bind_sints_to_buf(*flatten(map(functools.partial(_tex, ibo=True), ibos)), buf=self.buf, fmt='I', offset=prg.ibo_off)
class QCOMProgram(HCQProgram):
def __init__(self, dev: QCOMDevice, name: str, lib: bytes):
@@ -305,28 +311,10 @@ class QCOMTextureInfo:
self.pitch, self.real_stride, self.desc, self.ibo = pitch, real_stride, desc, ibo
class QCOMAllocator(HCQAllocatorBase):
def _alloc(self, size:int, options:BufferSpec) -> HCQBuffer:
def _alloc(self, size:int, opts:BufferSpec) -> HCQBuffer:
# Recalculate real size for texture
if options.image is not None:
imgw, imgh, itemsize_log = options.image.shape[1], options.image.shape[0], int(math.log2(options.image.itemsize))
pitchalign = max(6, 11 - int(math.log2(imgh))) if imgh > 1 else 6
align_up = max(1, (8 // itemsize_log + 1) - imgh // 32) if pitchalign == 6 else (2 ** (pitchalign - itemsize_log - 2))
granularity = 128 if options.image.itemsize == 4 else 256
pitch_add = (1 << pitchalign) if min(next_power2(imgw), round_up(imgw, granularity)) - align_up + 1 <= imgw and imgw > granularity//2 else 0
pitch = round_up((real_stride:=imgw * 4 * options.image.itemsize), 1 << pitchalign) + pitch_add
size = pitch * imgh
buf = self.dev._gpu_map(options.external_ptr, size) if options.external_ptr else self.dev._gpu_alloc(size)
if options.image is not None:
tex_fmt = mesa.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else mesa.FMT6_16_16_16_16_FLOAT
desc = [qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt), qreg.a6xx_tex_const_1(width=imgw, height=imgh),
qreg.a6xx_tex_const_2(type=mesa.A6XX_TEX_2D, pitch=pitch, pitchalign=pitchalign-6), 0,
*data64_le(buf.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
buf.texture_info = QCOMTextureInfo(pitch, real_stride, desc, [desc[0] & (~0xffff), *desc[1:len(desc)]])
return buf
if opts.image is not None: size = opts.image.pitch* opts.image.shape[0]
return self.dev._gpu_map(opts.external_ptr, size, image=opts.image) if opts.external_ptr else self.dev._gpu_alloc(size, image=opts.image)
def _do_copy(self, src_addr, dest_addr, src_size, real_size, src_stride, dest_stride, prof_text, dest_off=0, src_off=0):
with cpu_profile(prof_text, self.dev.device, is_copy=True):
@@ -335,13 +323,13 @@ class QCOMAllocator(HCQAllocatorBase):
src_off, dest_off = src_off+src_stride, dest_off+dest_stride
def _copyin(self, dest:HCQBuffer, src:memoryview):
stride, pitch = (src.nbytes, src.nbytes) if (ti:=cast(QCOMTextureInfo, dest.texture_info)) is None else (ti.real_stride, ti.pitch)
stride, pitch = (dest.image.shape[1] * 4 * dest.image.itemsize, dest.image.pitch) if dest.image else (src.nbytes, src.nbytes)
self._do_copy(mv_address(src), dest.cpu_view().addr, src.nbytes, stride, stride, pitch, f"TINY -> {self.dev.device}")
def _copyout(self, dest:memoryview, src:HCQBuffer):
self.dev.synchronize()
stride, pitch = (src.size, src.size) if (ti:=cast(QCOMTextureInfo, src.texture_info)) is None else (ti.real_stride, ti.pitch)
stride, pitch = (src.image.shape[1] * 4 * src.image.itemsize, src.image.pitch) if src.image else (src.size, src.size)
self._do_copy(src.cpu_view().addr, mv_address(dest), src.size, stride, pitch, stride, f"{self.dev.device} -> TINY")
def _as_buffer(self, src:HCQBuffer) -> memoryview:
@@ -388,7 +376,7 @@ class QCOMDevice(HCQCompiled):
super().__init__(device, QCOMAllocator(self), compilers, functools.partial(QCOMProgram, self), QCOMSignal,
functools.partial(QCOMComputeQueue, self), None)
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False) -> HCQBuffer:
def _gpu_alloc(self, size:int, flags:int=0, uncached=False, fill_zeroes=False, **kwargs) -> HCQBuffer:
flags |= flag("KGSL_MEMALIGN", alignment_hint:=12) | kgsl.KGSL_MEMFLAGS_USE_CPU_MAP
if uncached: flags |= flag("KGSL_CACHEMODE", kgsl.KGSL_CACHEMODE_UNCACHED)
@@ -396,15 +384,15 @@ class QCOMDevice(HCQCompiled):
va_addr = self.fd.mmap(0, bosz, mmap.PROT_READ | mmap.PROT_WRITE, mmap.MAP_SHARED, alloc.id * 0x1000)
if fill_zeroes: ctypes.memset(va_addr, 0, size)
return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self)
return HCQBuffer(va_addr=va_addr, size=size, meta=(alloc, True), view=MMIOInterface(va_addr, size, fmt='B'), owner=self, **kwargs)
def _gpu_map(self, ptr:int, size:int) -> HCQBuffer:
def _gpu_map(self, ptr:int, size:int, **kwargs) -> HCQBuffer:
ptr_aligned, size_aligned = (ptr & ~0xfff), round_up(size + (ptr & 0xfff), 0x1000)
try:
mapinfo = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
return HCQBuffer(mapinfo.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mapinfo, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
mi = kgsl.IOCTL_KGSL_MAP_USER_MEM(self.fd, hostptr=ptr_aligned, len=size_aligned, memtype=kgsl.KGSL_USER_MEM_TYPE_ADDR)
return HCQBuffer(mi.gpuaddr + (ptr - ptr_aligned), size=size, meta=(mi, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
except OSError as e:
if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self)
if e.errno == 14: return HCQBuffer(va_addr=ptr, size=size, meta=(None, False), view=MMIOInterface(ptr, size, fmt='B'), owner=self, **kwargs)
raise RuntimeError("Failed to map external pointer to GPU memory") from e
def _gpu_free(self, mem:HCQBuffer):

View File

@@ -281,9 +281,10 @@ class AM_GFX(AM_IP):
self._grbm_select(inst=xcc)
for xcc in range(self.xccs): self.adev.regGCVM_CONTEXT0_CNTL.write(0, inst=xcc)
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, doorbell:int, pipe:int, queue:int,
aql:bool) -> int:
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, eop_addr:int, eop_size:int, pipe:int, queue:int,
aql:bool) -> tuple[int, int]:
self._grbm_select(me=1, pipe=pipe, queue=queue, inst=0)
doorbell = am.AMDGPU_NAVI10_DOORBELL_MEC_RING0
restore_queue = aql and self.xccs > 1 and self.adev.partial_boot and (self.adev.regCP_HQD_ACTIVE.read(inst=0) & 1)
restore_ptr = (self.adev.regCP_HQD_PQ_WPTR_LO.read(inst=0) | (self.adev.regCP_HQD_PQ_WPTR_HI.read(inst=0) << 32)) if restore_queue else 0
if DEBUG >= 2 and restore_queue: print(f"am {self.adev.devfmt}: GFX queue already active, continuing from saved state {restore_ptr=:#x}.")
@@ -327,7 +328,7 @@ class AM_GFX(AM_IP):
self._grbm_select(inst=xcc)
self.adev.reg(f"regCP_ME1_PIPE{pipe}_INT_CNTL").update(time_stamp_int_enable=1, generic0_int_enable=1, inst=xcc)
return restore_ptr // 16
return restore_ptr // 16, doorbell
def set_clockgating_state(self):
if hasattr(self.adev, 'regMM_ATC_L2_MISC_CG'): self.adev.regMM_ATC_L2_MISC_CG.write(enable=1, mem_ls_enable=1)
@@ -447,9 +448,9 @@ class AM_SDMA(AM_IP):
time.sleep(0.01)
self.adev.regGRBM_SOFT_RESET.write(0x0)
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, doorbell:int, pipe:int, queue:int) -> int:
# Setup the ring
def setup_ring(self, ring_addr:int, ring_size:int, rptr_addr:int, wptr_addr:int, pipe:int, queue:int) -> tuple[int, int]:
reg, inst = ("regSDMA_GFX", pipe+queue*4) if self.adev.ip_ver[am.SDMA0_HWIP][:2] == (4,4) else (f"regSDMA{pipe}_QUEUE{queue}", 0)
doorbell = am.AMDGPU_NAVI10_DOORBELL_sDMA_ENGINE0 + (pipe+queue*4) * 0xA
self.sdma_reginst.append((reg, inst))
self.adev.reg(f"{reg}_MINOR_PTR_UPDATE").write(0x1, inst=inst)
@@ -464,7 +465,7 @@ class AM_SDMA(AM_IP):
self.adev.reg(f"{reg}_RB_CNTL").write(**({f'{self.sdma_name.lower()}_wptr_poll_enable':1} if self.adev.ip_ver[am.SDMA0_HWIP][:2]!=(4,4) else {}),
rb_vmid=0, rptr_writeback_enable=1, rptr_writeback_timer=4, rb_enable=1, rb_priv=1, rb_size=(ring_size//4).bit_length()-1, inst=inst)
self.adev.reg(f"{reg}_IB_CNTL").update(ib_enable=1, inst=inst)
return self.adev.reg(f"{reg}_RB_WPTR").read(inst=inst) | (self.adev.reg(f"{reg}_RB_WPTR_HI").read(inst=inst) << 32)
return self.adev.reg(f"{reg}_RB_WPTR").read(inst=inst) | (self.adev.reg(f"{reg}_RB_WPTR_HI").read(inst=inst) << 32), doorbell
class AM_PSP(AM_IP):
def init_sw(self):

View File

@@ -8,6 +8,7 @@ from tinygrad.device import BufferSpec, Compiled, LRUAllocator, ProfileDeviceEve
from tinygrad.uop.ops import sym_infer, sint, UOp
from tinygrad.runtime.autogen import libc
from tinygrad.runtime.support.memory import BumpAllocator
from tinygrad.dtype import ImageDType
class MMIOInterface:
def __init__(self, addr:int, nbytes:int, fmt='B'): self.mv, self.addr, self.nbytes, self.fmt = to_mv(addr, nbytes).cast(fmt), addr, nbytes, fmt
@@ -455,14 +456,14 @@ class HCQCompiled(Compiled, Generic[SignalType]):
if hasattr(self, 'iface') and hasattr(self.iface, 'device_fini'): self.iface.device_fini()
class HCQBuffer:
def __init__(self, va_addr:sint, size:int, texture_info:Any=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
def __init__(self, va_addr:sint, size:int, image:ImageDType|None=None, meta:Any=None, _base:HCQBuffer|None=None, view:MMIOInterface|None=None,
owner:HCQCompiled|None=None):
self.va_addr, self.size, self.texture_info, self.meta, self._base, self.view = va_addr, size, texture_info, meta, _base, view
self.va_addr, self.size, self.image, self.meta, self._base, self.view = va_addr, size, image, meta, _base, view
self._devs, self.owner = ([owner] if owner is not None else []), owner
self._mappings:dict[HCQCompiled, HCQBuffer] = {} # mapping to the other devices
def offset(self, offset:int=0, size:int|None=None) -> HCQBuffer:
return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, texture_info=self.texture_info, meta=self.meta,
return HCQBuffer(self.va_addr+offset, size or (self.size - offset), owner=self.owner, image=self.image, meta=self.meta,
_base=self._base or self, view=(self.view.view(offset=offset, size=size) if self.view is not None else None))
def cpu_view(self) -> MMIOInterface:

View File

@@ -127,7 +127,7 @@ class Tensor(OpMixin):
# create a UOp from the different types of inputs
if isinstance(data, UOp):
assert _dtype is None or _dtype==data.dtype, f"dtype doesn't match ({_dtype} vs {data.dtype}), and casting isn't supported"
assert _dtype is None or _dtype==data.dtype or data.dtype==dtypes.index, f"dtype mismatch: {_dtype} vs {data.dtype}"
# if data is dtypes.index, that means this is a symbolic int and we need to lower it to something we can make a Tensor out of
if data.dtype==dtypes.index: data = _index_to_concrete_int(data)
if data.op is Ops.BIND: # type: ignore # mypy type narrowing is bugged here
@@ -3637,11 +3637,13 @@ class Tensor(OpMixin):
Q = Tensor.eye(m, dtype=self.dtype).reshape((1,) * len(b_shape) + (m, m)).expand(b_shape + (m, m)).contiguous()
for i in range(min(m, n)):
x = R[..., i:m, i].contiguous() # TODO: without contiguous this can silently be wrong, should at least assert
s = -x[..., 0].sign()
u1 = x[..., 0] - s * x.square().sum(-1).sqrt()
w = x.unsqueeze(-1) / u1.reshape(b_shape + (1, 1))
norm = x.square().sum(-1).sqrt()
s = (x[..., 0] != 0).where(-x[..., 0].sign(), -1)
u1 = x[..., 0] - s * norm
w = x.unsqueeze(-1) / (norm != 0).where(u1, 1).reshape(b_shape + (1, 1))
w[..., 0, 0] = 1
tau = (-s * u1 / x.square().sum(-1).sqrt()).reshape(b_shape + (1, 1))
tau = (-s * u1 / (norm != 0).where(norm, 1)).reshape(b_shape + (1, 1))
tau = (norm != 0).reshape(b_shape + (1, 1)).where(tau, 0)
R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1)
return Q,R
@@ -3668,8 +3670,10 @@ class Tensor(OpMixin):
#compute the jacobi rotations for each pairing
gamma = (U_left * U_right).sum(-2).reshape(b_shape + (1, num//2))
alpha, beta = U_permuted.square().sum(-2).unsqueeze(-2).split(num//2, -1)
tau = (beta - alpha) / (2 * gamma)
t = tau.sign() / (tau.abs() + (1 + tau.square()).sqrt())
rot = gamma != 0
tau = (beta - alpha) / (2 * rot.where(gamma, 1))
t = (tau != 0).where(tau.sign(), 1) / (tau.abs() + (1 + tau.square()).sqrt())
t = rot.where(t, 0)
c = 1 / (1 + t.square()).sqrt()
s = c * t
#apply the rotations
@@ -3686,9 +3690,9 @@ class Tensor(OpMixin):
for _ in range(max_iterations * iterations_per_round): U, V, permute, inverse_permute = one_round_jacobi(U, V, permute, inverse_permute)
#extract singular values and sort. construct U from Q
S, indices = U.square().sum(-2).sqrt().sort(dim = -1, descending=True)
new_indices = Tensor.arange(num).reshape((1,) * (self.ndim - 1) + (num,)).expand(b_shape + (num, num)).contiguous()
new_indices[..., :num] = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U, V = U.gather(-1, new_indices[...,0:num,0:num]) / S.unsqueeze(-2), V.gather(-1, new_indices[..., 0:num, 0:num]).realize()
new_indices = indices.reshape(b_shape + (1, num)).expand(b_shape + (num, num))
U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
V = V.gather(-1, new_indices).realize()
padded_u = Tensor.eye(q_num, dtype=U.dtype).reshape((1,) * len(b_shape) + (q_num, q_num)).expand(b_shape + (q_num, q_num)).contiguous()
padded_u[..., 0:num, 0:num] = U

View File

@@ -223,26 +223,26 @@ def xlog2(d:UOp) -> UOp:
Paper: https://arxiv.org/pdf/2001.09258 5.5
"""
assert d.dtype.scalar() in TRANSCENDENTAL_DTYPES
# TODO: float16 denormal need float32 to achieve precision
if d.dtype.scalar() == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
FLT_MIN = d.const_like(1e-6 if d.dtype.scalar() == dtypes.float16 else 1e-4)
# float16 uses 2^10 for denormal scaling (2^64 overflows), float32/64 use 2^64
denormal_exp = 10 if d.dtype.scalar() == dtypes.float16 else 64
FLT_MIN = d.const_like({dtypes.float16: 6.1e-5, dtypes.float32: 1e-4, dtypes.float64: 1e-4}[d.dtype.scalar()])
is_denormal = d<FLT_MIN
a = is_denormal.where(d * (2 ** 64), d)
a = is_denormal.where(d * (2 ** denormal_exp), d)
e = ilogb2k(a * (1.0 / 0.75)).cast(a.dtype)
m = ldexp3k(a, -e)
e = is_denormal.where(e - 64, e)
e = is_denormal.where(e - denormal_exp, e)
x = (m - 1.0) / (m + 1.0)
x2 = x * x
if d.dtype.scalar() == dtypes.float64:
t = polyN(x2, [0.2211941750456081490e+0, 0.2200768693152277689e+0, 0.2623708057488514656e+0, 0.3205977477944495502e+0,
0.4121985945485324709e+0, 0.5770780162997058982e+0, 0.96179669392608091449])
s_hi, s_lo = e+x*2.885390081777926774, e.const_like(0)
r = t * (x * x2) + e + x * 2.885390081777926774
else:
t = polyN(x2, [0.4374550283e+0, 0.5764790177e+0, 0.9618012905120])
s_hi, s_lo = e+x*2.8853900432586669922, x*3.2734474483568488616e-08
r = t * (x * x2) + (s_hi + s_lo)
# s_lo term (x*3.27e-08) only for float32 - underflows in float16
r = t * (x * x2) + e + x * 2.8853900432586669922 + (x * 3.2734474483568488616e-08 if d.dtype.scalar() == dtypes.float32 else 0)
# log2(Inf) = Inf
r = d.ne(math.inf).where(r, r.const_like(math.inf))

View File

@@ -11,9 +11,8 @@ try:
# IDIV is truncating division, but z3 does Euclidean division (floor if b>0, ceil otherwise); mod by a power of two sometimes uses Ops.AND
def z3_cdiv(a, b):return z3.If((a<0), z3.If(0<b, (a+(b-1))/b, (a-(b+1))/b), a/b)
def z3_xor(a,b):
if isinstance(a, z3.BoolRef): return a^b
assert a==-1 or b==-1, "xor can only be used in indexing if one of the arguments is -1"
return -a-1 if b==-1 else -b-1
assert isinstance(a, z3.BoolRef), f"{type(a)=}, {a=}"
return a^b
z3_alu: dict[Ops, Callable] = python_alu | {Ops.MOD: lambda a,b: a-z3_cdiv(a,b)*b, Ops.IDIV: z3_cdiv, Ops.SHR: lambda a,b: a/(2**b.as_long()),
Ops.SHL: lambda a,b: a*(2**b.as_long()), Ops.AND: lambda a,b: a%(b+1) if isinstance(b, z3.ArithRef) else a&b, Ops.WHERE: z3.If, Ops.XOR: z3_xor,
Ops.MAX: lambda a,b: z3.If(a<b, b, a),}
@@ -34,7 +33,6 @@ try:
(UPat(Ops.CONST, dtypes.ints+(dtypes.index,), name="x"), lambda x,ctx: (z3.IntVal(x.arg, ctx=ctx[0].ctx), None)),
(UPat(Ops.CONST, dtypes.bool, name="x"), lambda x,ctx: (z3.BoolVal(x.arg, ctx=ctx[0].ctx), None)),
# casts from floats create new variables
(UPat(Ops.CAST, dtypes.bool, src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx: (z3.Bool(f"cast{len(ctx[1])}",ctx=ctx[0].ctx), None)),
(UPat(Ops.CAST, dtypes.ints+(dtypes.index,), src=(UPat(dtype=dtypes.floats),), name="x"), lambda x,ctx:
create_bounded(f"cast{len(ctx[1])}", x.dtype.min, x.dtype.max, ctx[0])),
# A comparison between floats introduces a new bool variable
@@ -67,11 +65,10 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
# We can use UOp min/max to do a faster check, but it can give false positives since it's not an exact bound and doesn't consider the mask
if 0<=idx.vmin and idx.vmax<sz: return True
# WEBGPU has a BITCAST in the index. TODO: fix
if any(x.op is Ops.BITCAST for x in idx.toposort() | gate.toposort()): return True
# PTX uses absolute addresses (pointer cast to long), skip validation
if any(x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType) for x in idx.toposort()): return True
# TODO: validate these
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
for x in idx.toposort() | gate.toposort():
if x.op is Ops.BITCAST or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
if not z3_imported: raise ImportError("bounds checking requires z3 >= 4.12.4, use IGNORE_OOB=1 to disable, or \"pip install 'z3-solver>=4.12.4\"")
solver = z3.Solver(ctx=z3.Context())