Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
unify cpu and llvm (#11982)
* try unify cpu and llvm
* fixes
* fix
* ops
* no llvm
* fix
* rm
* lvmm is ot
* oops
* override
* no llvm
* ignore
* skip llvm
* ooops
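The net effect of this change: the standalone LLVM device is folded into the CPU backend, and the LLVM path is now chosen with the new CPU_LLVM context variable (see the CPU_LLVM ContextVar and CPUDevice hunks below). A minimal sketch of the new selection, assuming a local tinygrad checkout with clang and LLVM available:

import os
# CPU_LLVM=1 picks the LLVM renderer/compiler inside the CPU backend;
# CPU_LLVM=0 (the default) keeps the Clang JIT path. Set these before importing tinygrad.
os.environ["CPU"] = "1"
os.environ["CPU_LLVM"] = "1"

from tinygrad import Tensor, Device
assert Device.DEFAULT == "CPU"           # there is no separate "LLVM" device anymore
print((Tensor.arange(4) + 1).numpy())    # runs through the LLVM-backed CPU device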
.github/workflows/benchmark.yml (vendored): 10 changed lines
@@ -68,10 +68,10 @@ jobs:
       run: METAL=1 python3.11 test/opt/test_tensor_cores.py
     - name: Test AMX tensor cores
       run: |
-        DEBUG=2 CPU=1 AMX=1 python3.11 test/opt/test_tensor_cores.py
-        DEBUG=2 LLVM=1 AMX=1 python3.11 test/opt/test_tensor_cores.py
-        DEBUG=2 CPU=1 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
-        DEBUG=2 LLVM=1 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
+        DEBUG=2 CPU=1 CPU_LLVM=0 AMX=1 python3.11 test/opt/test_tensor_cores.py
+        DEBUG=2 CPU=1 CPU_LLVM=1 AMX=1 python3.11 test/opt/test_tensor_cores.py
+        DEBUG=2 CPU=1 CPU_LLVM=0 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
+        DEBUG=2 CPU=1 CPU_LLVM=1 AMX=1 python3.11 test/opt/test_gen_float4.py TestFloat4.test_float4_multidim_amx TestFloat4.test_float4_multidim_unaligned_load_amx
     - name: Run Tensor Core GEMM (float)
       run: DEBUG=2 SHOULD_USE_TC=1 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
     - name: Run Tensor Core GEMM (half)
@@ -626,7 +626,7 @@ jobs:
         # generate quantized weights
         ln -s /data/home/tiny/tinygrad/extra/datasets/imagenet extra/datasets/imagenet
         ln -s /data/home/tiny/tinygrad/testsig-*.so .
-        PYTHONPATH=. CC=clang-19 CPU=1 QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
+        PYTHONPATH=. CC=clang-19 CPU=1 CPU_LLVM=0 QUANT=1 CNT=0 python3 examples/test_onnx_imagenet.py https://github.com/xamcat/mobcat-samples/raw/refs/heads/master/onnx_runtime/InferencingSample/InferencingSample/mobilenetv2-7.onnx /tmp/model.quant.onnx
         # benchmark on DSP with NOOPT=1, the devectorizer has issues
         PYTHONPATH=. CC=clang-19 DSP=1 DONT_REALIZE_EXPAND=1 NOOPT=1 CNT=2 DEBUG=2 python3 examples/test_onnx_imagenet.py /tmp/model.quant.onnx
     - name: Run process replay tests
.github/workflows/test.yml (vendored): 52 changed lines
@@ -33,9 +33,9 @@ jobs:
     - name: External Benchmark Schedule
       run: python3 test/external/external_benchmark_schedule.py
     - name: Speed Test
-      run: LLVM=1 python3 test/speed/external_test_speed_v_torch.py
+      run: CPU=1 CPU_LLVM=1 python3 test/speed/external_test_speed_v_torch.py
     - name: Speed Test (BEAM=2)
-      run: BEAM=2 LLVM=1 python3 test/speed/external_test_speed_v_torch.py
+      run: BEAM=2 CPU=1 CPU_LLVM=1 python3 test/speed/external_test_speed_v_torch.py

   docs:
     name: Docs
@@ -85,7 +85,7 @@ jobs:
       run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
     - name: Compile EfficientNet to C and test it
       run: |
-        CPU=1 python examples/compile_efficientnet.py > recognize.c
+        CPU=1 CPU_LLVM=0 python examples/compile_efficientnet.py > recognize.c
         clang -O2 recognize.c -lm -o recognize
         cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
@@ -122,11 +122,11 @@ jobs:
     - name: Test one op in torch tests
      run: DEBUG=2 python3 extra/torch_backend/torch_tests.py TestTinyBackendPRIVATEUSE1.test_unary_log_tiny_float32
     - name: Test Ops with TINY_BACKEND
-      run: LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20
+      run: CPU=1 CPU_LLVM=1 LLVMOPT=0 TINY_BACKEND=1 python3 -m pytest -n auto test/test_ops.py --durations=20
     - name: Test in-place operations on views
       run: TORCH_DEBUG=1 python3 extra/torch_backend/test_inplace.py
     - name: Test multi-gpu
-      run: LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py
+      run: CPU=1 CPU_LLVM=1 GPUS=4 TORCH_DEBUG=1 python3 extra/torch_backend/test_multigpu.py

   torchbackendmore:
     name: Torch Backend Tests More
@@ -148,7 +148,7 @@ jobs:
         sudo apt update || true
         sudo apt install -y --no-install-recommends ninja-build
     - name: Test beautiful_mnist in torch with TINY_BACKEND
-      run: SPLIT_REDUCEOP=0 FUSE_ARANGE=1 LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
+      run: SPLIT_REDUCEOP=0 FUSE_ARANGE=1 CPU=1 CPU_LLVM=1 TARGET_EVAL_ACC_PCT=96.0 TINY_BACKEND=1 python3 examples/other_mnist/beautiful_mnist_torch.py
     - name: Test some torch tests (expect failure)
       run: python3 -m pytest extra/torch_backend/torch_tests.py -v --tb=no || true
@@ -405,7 +405,7 @@ jobs:
     # - name: Test openpilot simple_plan vision model correctness (float32)
     #   run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
     - name: Test openpilot LLVM compile
-      run: LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - name: Test openpilot compile4
       run: NOLOCALS=1 GPU=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
     - name: Run process replay tests
@@ -429,15 +429,15 @@ jobs:
           python-version: '3.11'
           llvm: 'true'
     - name: Test ONNX (CPU)
-      run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+      run: CPU=1 CPU_LLVM=0 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Test ONNX (LLVM)
-      run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+      run: CPU=1 CPU_LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Test ONNX Runner (CPU)
-      run: CPU=1 python3 test/external/external_test_onnx_runner.py
+      run: CPU=1 CPU_LLVM=0 python3 test/external/external_test_onnx_runner.py
     - name: Test Additional ONNX Ops (CPU)
-      run: CPU=1 python3 test/external/external_test_onnx_ops.py
+      run: CPU=1 CPU_LLVM=0 python3 test/external/external_test_onnx_ops.py
     - name: Test Quantize ONNX
-      run: CPU=1 python3 test/test_quantize_onnx.py
+      run: CPU=1 CPU_LLVM=0 python3 test/test_quantize_onnx.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
@@ -505,11 +505,11 @@ jobs:
           opencl: 'true'
           llvm: 'true'
     - name: Test models (llvm)
-      run: LLVM=1 python -m pytest -n=auto test/models --durations=20
+      run: CPU=1 CPU_LLVM=1 python -m pytest -n=auto test/models --durations=20
     - name: Test models (gpu)
       run: GPU=1 python -m pytest -n=auto test/models --durations=20
     - name: Test models (cpu)
-      run: CPU=1 python -m pytest -n=auto test/models --durations=20
+      run: CPU=1 CPU_LLVM=0 python -m pytest -n=auto test/models --durations=20
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
@@ -531,15 +531,15 @@ jobs:
       # test_symbolic_arange_sym_step is passing now
       # test_threefry_doesnt_use_long is because there's a contig after the long now
       run: |
-        CPU=1 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
+        CPU=1 CPU_LLVM=0 RANGEIFY=1 python3 -m pytest -n auto --durations 20 \
          -k "not test_symbolic_arange_sym_step and not test_threefry_doesnt_use_long" \
          test/test_tiny.py test/test_rangeify.py test/test_ops.py test/test_tensor_variable.py \
          test/test_outerworld_range.py test/test_sample.py test/test_randomness.py
     - name: Test CPU=1 RANGEIFY=2
-      run: CPU=1 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
+      run: CPU=1 CPU_LLVM=0 RANGEIFY=2 python3 -m pytest -n auto test/test_tiny.py test/test_rangeify.py test/test_ops.py --durations 20
     # slow (and still wrong on beautiful_mnist)
     #- name: Test LLVM=1 RANGEIFY=1 (slow tests)
-    #  run: LLVM=1 RANGEIFY=1 python3 -m pytest -n auto test/models/test_mnist.py --durations 20
+    #  run: CPU=1 CPU_LLVM=1 RANGEIFY=1 python3 -m pytest -n auto test/models/test_mnist.py --durations 20

   testdevectorize:
     name: Linux (devectorize)
@@ -558,11 +558,11 @@ jobs:
           pydeps: "pillow"
           llvm: "true"
     - name: Test LLVM=1 DEVECTORIZE=0
-      run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+      run: CPU=1 CPU_LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
     - name: Test LLVM=1 DEVECTORIZE=0 for model
-      run: LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
+      run: CPU=1 CPU_LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
     - name: Test CPU=1 DEVECTORIZE=0
-      run: CPU=1 DEVECTORIZE=0 FUSE_ARANGE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+      run: CPU=1 CPU_LLVM=0 DEVECTORIZE=0 FUSE_ARANGE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"

   testdsp:
     name: Linux (DSP)
@@ -728,10 +728,10 @@ jobs:
           opencl: ${{ matrix.backend == 'gpu' && 'true' }}
           llvm: ${{ matrix.backend == 'llvm' && 'true' }}
     - name: Set env
-      run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
+      run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
     - name: Check Device.DEFAULT and print some source
       run: |
-        python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CPU','GPU'], Device.DEFAULT"
+        python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','GPU'], Device.DEFAULT"
         DEBUG=5 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
     - name: Run pytest (${{ matrix.backend }})
       run: python -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
@@ -952,10 +952,10 @@ jobs:
           pydeps: "capstone"
           llvm: ${{ matrix.backend == 'llvm' && 'true' }}
     - name: Set env
-      run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
+      run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'metal' && 'METAL=1'}}" >> $GITHUB_ENV
     - name: Check Device.DEFAULT and print some source
       run: |
-        python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
+        python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
         DEBUG=4 python3 test/test_tiny.py TestTiny.test_plus
     - name: Run pytest (${{ matrix.backend }})
       run: python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
@@ -989,7 +989,7 @@ jobs:
           pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
     - name: Set env
       shell: bash
-      run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
+      run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
     - name: Run unit tests
       if: matrix.backend=='llvm'
       # test_newton_schulz hits RecursionError
@@ -997,5 +997,5 @@ jobs:
     - name: Run pytest (${{ matrix.backend }})
       shell: bash
       run: |
-        python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
+        python -c "from tinygrad import Device; assert Device.DEFAULT == {'LLVM':'CPU'}.get(x:='${{ matrix.backend }}'.upper(), x), Device.DEFAULT"
         python -m pytest -n=auto test/test_tiny.py test/test_ops.py --durations=20
@@ -80,8 +80,7 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full version
 tinygrad already supports numerous accelerators, including:

 - [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
-- [x] [CPU (C Code)](tinygrad/runtime/ops_cpu.py)
-- [x] [LLVM](tinygrad/runtime/ops_llvm.py)
+- [x] [CPU](tinygrad/runtime/ops_cpu.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
 - [x] [AMD](tinygrad/runtime/ops_amd.py)
@@ -36,8 +36,7 @@ CUDA | [1] | enable CUDA backend
 AMD | [1] | enable AMD backend
 NV | [1] | enable NV backend
 METAL | [1] | enable Metal backend (for Mac M1 and after)
-CPU | [1] | enable CPU (Clang) backend
-LLVM | [1] | enable LLVM backend
+CPU | [1] | enable CPU backend
 BEAM | [#] | number of beams in kernel beam search
 DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32
 IMAGE | [1-2] | enable 2d specific optimizations
@@ -75,7 +75,7 @@ class TestHCQ(unittest.TestCase):
     TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
     TestHCQ.d0.timeline_value += 1

-  @unittest.skipIf(MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"}, "Can't handle async update on MOCKGPU for now")
+  @unittest.skipIf(MOCKGPU or Device.DEFAULT in {"CPU"}, "Can't handle async update on MOCKGPU for now")
   def test_wait_late_set(self):
     for queue_type in [TestHCQ.d0.hw_compute_queue_t, TestHCQ.d0.hw_copy_queue_t]:
       if queue_type is None: continue
@@ -137,7 +137,7 @@ class TestHCQ(unittest.TestCase):
     val = TestHCQ.a.uop.buffer.as_buffer().cast("f")[0]
     assert val == 200.0, f"got val {val}"

-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
   def test_exec_update(self):
     sint_global = (Variable("sint_global", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.global_size[1:])
     sint_local = (Variable("sint_local", 0, 0xffffffff, dtypes.uint32),) + tuple(TestHCQ.runner.p.local_size[1:])
@@ -155,7 +155,7 @@ class TestHCQ(unittest.TestCase):
     val = TestHCQ.b.uop.buffer.as_buffer().cast("f")[1]
     assert val == 0.0, f"got val {val}, should not be updated"

-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "No globals/locals on LLVM/CPU")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"}, "No globals/locals on LLVM/CPU")
   def test_exec_update_fuzz(self):
     virt_val = Variable("sig_val", 0, 0xffffffff, dtypes.uint32)
     virt_local = [Variable(f"local_{i}", 0, 0xffffffff, dtypes.uint32) for i in range(3)]
@@ -336,7 +336,7 @@ class TestHCQ(unittest.TestCase):
       et = float(sig_en.timestamp - sig_st.timestamp)

       print(f"exec kernel time: {et:.2f} us")
-      assert 0.1 <= et <= (100000 if MOCKGPU or Device.DEFAULT in {"CPU", "LLVM"} else 100)
+      assert 0.1 <= et <= (100000 if MOCKGPU or Device.DEFAULT in {"CPU"} else 100)

   def test_speed_copy_bandwidth(self):
     if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
@@ -114,7 +114,7 @@ class TestRealWorld(unittest.TestCase):

     helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93)

-  @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU", "LLVM"}, "slow")
+  @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU"}, "slow")
   def test_train_cifar(self):
     with Tensor.train():
       model = SpeedyResNet(Tensor.ones((12,3,2,2)))
@@ -1,7 +1,7 @@
 import unittest
 import pathlib
 from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
-from tinygrad.helpers import CI, fetch
+from tinygrad.helpers import CI, fetch, CPU_LLVM
 from tinygrad import Device, dtypes
 from tinygrad.device import is_dtype_supported
@@ -16,7 +16,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc
 TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
 TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501

-@unittest.skipIf(Device.DEFAULT in ["CPU", "LLVM"], "slow")
+@unittest.skipIf(Device.DEFAULT in ["CPU"], "slow")
 @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
 class TestWhisper(unittest.TestCase):
   @classmethod
@@ -33,11 +33,11 @@ class TestWhisper(unittest.TestCase):
   def test_transcribe_file1(self):
     self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_1), TRANSCRIPTION_1)

-  @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI")
+  @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too many tests for CI")
   def test_transcribe_file2(self):
     self.assertEqual(transcribe_file(self.model, self.enc, TEST_FILE_2), TRANSCRIPTION_2)

-  @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too many tests for CI")
+  @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too many tests for CI")
   def test_transcribe_batch12(self):
     waveforms = [load_file_waveform(TEST_FILE_1), load_file_waveform(TEST_FILE_2)]
     transcriptions = transcribe_waveform(self.model, self.enc, waveforms)
@@ -52,13 +52,13 @@ class TestWhisper(unittest.TestCase):
     self.assertEqual(TRANSCRIPTION_2, transcriptions[0])
     self.assertEqual(TRANSCRIPTION_1, transcriptions[1])

-  @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too long for CI")
+  @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too long for CI")
   def test_transcribe_long(self):
     waveform = [load_file_waveform(fetch(TEST_FILE_3_URL))]
     transcription = transcribe_waveform(self.model, self.enc, waveform)
     self.assertEqual(TRANSCRIPTION_3, transcription)

-  @unittest.skipIf(CI or Device.DEFAULT == "LLVM", "too long for CI")
+  @unittest.skipIf(CI or (Device.DEFAULT == "CPU" and CPU_LLVM), "too long for CI")
   def test_transcribe_long_no_batch(self):
     waveforms = [load_file_waveform(fetch(TEST_FILE_3_URL)), load_file_waveform(TEST_FILE_1)]
@@ -29,7 +29,7 @@ class TestFloat4(unittest.TestCase):

     assert TestFloat4.count_float4(program.uops) == (2, 1)

-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
   def test_float4_multidim(self):
     a = Tensor.empty(2, 8).realize()
     b = Tensor.empty(2, 8).realize()
@@ -39,7 +39,7 @@
     uops = get_program(s.ast, opts=[Opt(op=OptOps.UPCAST, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=2)]).uops
     assert TestFloat4.count_float4(uops) == (4, 2)

-  @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
   def test_float4_multidim_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.empty(2, size).realize()
@@ -69,7 +69,7 @@ class TestFloat4(unittest.TestCase):

     assert TestFloat4.count_float4(program.uops) == (0, 1)

-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load(self):
     a = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
     b = Tensor.empty(2, 9).realize().shrink(((0, 2), (1, 9),))
@@ -80,7 +80,7 @@ class TestFloat4(unittest.TestCase):

     assert TestFloat4.count_float4(uops) == (0, 2)

-  @unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
+  @unittest.skipUnless(Device.DEFAULT in {"CPU"} and AMX, "Only CPU with AMX upcasts float up to size 16")
   def test_float4_multidim_unaligned_load_amx(self):
     def kernel_for_shape(size, shift):
       a = Tensor.empty(2, size).realize().shrink(((0, 2), (1, size),))
@@ -7,7 +7,7 @@ from tinygrad.tensor import _to_np_dtype
 from tinygrad.uop.ops import Ops
 from tinygrad.dtype import DType
 from tinygrad.device import is_dtype_supported
-from tinygrad.helpers import AMX, CI, AMD_LLVM
+from tinygrad.helpers import AMX, CI, AMD_LLVM, CPU_LLVM
 from tinygrad.engine.realize import CompiledRunner, get_program
 from tinygrad.codegen.opt import Opt, OptOps, KernelOptError
@@ -69,7 +69,7 @@ class TestTensorCores(unittest.TestCase):
     a, b = Tensor.rand(m, k, dtype=tc.dtype_in), Tensor.rand(k, n, dtype=tc.dtype_in)
     r = a.matmul(b, dtype=tc.dtype_out)
     prg = get_program(r.schedule()[-1].ast, opts=[Opt(op=OptOps.TC, axis=0, arg=(-1, 2, 1))])
-    if Device.DEFAULT == "LLVM":
+    if Device.DEFAULT == "CPU" and CPU_LLVM:
       assert "0x201000" in prg.src
     elif Device.DEFAULT == "AMD" and AMD_LLVM:
       assert "@llvm.amdgcn.wmma" in prg.src
@@ -160,7 +160,7 @@ class TestTensorCores(unittest.TestCase):

   @unittest.skipIf(Device.DEFAULT == "PYTHON", "slow on EMULATED device")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"}, "CPU does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi(self):
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
     x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
@@ -174,7 +174,7 @@ class TestTensorCores(unittest.TestCase):

   @unittest.skipIf(Device.DEFAULT == "PYTHON", "slow on EMULATED device")
   @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
+  @unittest.skipIf(Device.DEFAULT in {"CPU"}, "CPU does not support using a different type for accumulation")
   def test_tensor_cores_unroll_casted_phi_with_children(self):
     # all STORE children are outside the loop
     tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
@@ -1,7 +1,7 @@
 import unittest, io
 from contextlib import redirect_stdout
 from tinygrad import Tensor, dtypes, Device
-from tinygrad.helpers import OSX
+from tinygrad.helpers import OSX, CPU_LLVM
 from tinygrad.engine.realize import lower_schedule
 from tinygrad.device import is_dtype_supported
 from tinygrad.engine.realize import get_program
@@ -19,7 +19,7 @@ class TestCompileFailures(unittest.TestCase):

 class TestDisassembly(unittest.TestCase):
   # TODO: fails on llvm. llvm.LLVMGetHostCPUName() returns "generic"
-  @unittest.skipUnless(Device.DEFAULT in ("CPU",) and OSX, "m series cpus support fp16 arithmetic")
+  @unittest.skipUnless(Device.DEFAULT in ("CPU",) and not CPU_LLVM and OSX, "m series cpus support fp16 arithmetic")
   def test_float16_alu(self):
     c = Tensor([1], dtype=dtypes.float16) + Tensor([1], dtype=dtypes.float16)
     s = c.schedule()[-1]
@@ -23,7 +23,7 @@ def _simple_test(add, extract=lambda x: x, N=10):
 class TestJit(unittest.TestCase):

   @settings(deadline=2e4)
-  @unittest.skipUnless(REAL_DEV in ["LLVM", "CPU"], f"no support on {REAL_DEV}")
+  @unittest.skipUnless(REAL_DEV in ["CPU"], f"no support on {REAL_DEV}")
   @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
   def test_approx_jit_timeout(self, op):
     with Context(TRANSCENDENTAL=2):
@@ -791,7 +791,7 @@ class TestJitGraphSplit(unittest.TestCase):
                       hcqgraph=[self.ji_graph(5)])

   def test_jit_multidev_xfer(self):
-    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    if Device.DEFAULT in {"CPU"}: raise unittest.SkipTest("CPU is not a valid default device for this test (zero-copies)")
     if Device.DEFAULT == "METAL" or REAL_DEV == "METAL": raise unittest.SkipTest("Metal is flaky, with multidevice (same as metal llama 4gpu?)")

     try: Device[f"{Device.DEFAULT}:1"]
@@ -816,7 +816,7 @@ class TestJitGraphSplit(unittest.TestCase):

   @unittest.skipIf(getenv("MOCKGPU"), "MockGPU does not support parallel copies")
   def test_jit_multidev_copy(self):
-    if Device.DEFAULT in {"CPU", "LLVM"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")
+    if Device.DEFAULT in {"CPU"}: raise unittest.SkipTest("CPU/LLVM is not a valid default device for this test (zero-copies)")

     @TinyJit
     def f(inp):
@@ -373,7 +373,7 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_allclose(y.numpy(), y_shard.numpy(), atol=1e-6, rtol=1e-6)

   # NOTE: this is failing on LLVM CI, no idea why. Works locally.
-  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "LLVM", "CPU", "AMD"), "slow, and flaky on LLVM/CPU")
+  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
   def test_data_parallel_resnet(self):
     from extra.models.resnet import ResNet18
@@ -409,7 +409,7 @@ class TestMultiTensor(unittest.TestCase):
     # sometimes there is zeros in these grads... why?
     np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

-  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "LLVM", "CPU", "AMD"), "slow, and flaky on LLVM/CPU")
+  @unittest.skipIf(CI and REAL_DEV in ("CUDA", "NV", "CPU", "AMD"), "slow, and flaky on CPU")
   def test_data_parallel_resnet_train_step(self):
     from extra.models.resnet import ResNet18
     fake_image = Tensor.rand((2, 3, 224//16, 224//16))
@@ -2,7 +2,7 @@ import time, math, unittest, functools, platform, warnings
 import numpy as np
 from typing import List, Callable
 import torch
-from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, AMD_LLVM
+from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL, CPU_LLVM, AMD_LLVM
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
@@ -1304,7 +1304,7 @@ class TestOps(unittest.TestCase):
                          np.arange(64,128,dtype=np.float32).reshape(8,8)])
   def test_small_gemm_eye(self):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
-  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE
+  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "GPU", "CUDA"] or (Device.DEFAULT == "CPU" and CPU_LLVM) or IMAGE
                    or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
   def test_gemm_fp16(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
@@ -2262,7 +2262,7 @@
       lambda x,w: torch.nn.functional.conv2d(x,w,stride=2),
       lambda x,w: Tensor.conv2d(x,w,stride=2))

-  @unittest.skipIf(Device.DEFAULT != "LLVM", "DEVECTORIZE=0 only for LLVM")
+  @unittest.skipUnless(Device.DEFAULT == "CPU" and CPU_LLVM, "DEVECTORIZE=0 only for LLVM")
   def test_strided_conv2d_simple_vec(self):
     with Context(DEVECTORIZE=0): self.test_strided_conv2d_simple()
@@ -1,6 +1,6 @@
 import unittest
 from tinygrad import Tensor, Device
-from tinygrad.helpers import RANGEIFY
+from tinygrad.helpers import RANGEIFY, CPU_LLVM
 from tinygrad.codegen.opt import Opt, OptOps
 from tinygrad.engine.realize import get_program
@@ -13,7 +13,7 @@ class TestOpts(unittest.TestCase):
     out = (a+b).contiguous(arg=opts)
     s = out.schedule()
     self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
-    if Device.DEFAULT in {"CPU", "GPU", "METAL"}:
+    if Device.DEFAULT in {"CPU", "GPU", "METAL"} and not CPU_LLVM:
       prg = get_program(s[-1].ast)
       self.assertIn('float4', prg.src)
@@ -31,7 +31,7 @@ def helper_profile_filter_device(profile, device:str):
   return [x for x in profile if getattr(x, "device", None) == device], dev_events[0]

 # TODO: support in HCQCompiled
-is_cpu_hcq = Device.DEFAULT in {"CPU", "LLVM"}
+is_cpu_hcq = Device.DEFAULT in {"CPU"}

 @unittest.skipUnless((issubclass(type(Device[Device.DEFAULT]), HCQCompiled) and not is_cpu_hcq) or Device.DEFAULT in {"METAL"}, "Dev not supported")
 class TestProfiler(unittest.TestCase):
@@ -13,7 +13,7 @@ def time_tensor_numpy(out:Tensor):

 N = 4096
 class TestZeroCopy(unittest.TestCase):
-  @unittest.skipIf(Device.DEFAULT not in {"CPU", "LLVM", "METAL"}, "device isn't zero copy")
+  @unittest.skipIf(Device.DEFAULT not in {"CPU", "METAL"}, "device isn't zero copy")
   def test_zero_copy_from_default_to_cpu(self):
     demo = Tensor.rand(1).realize()
     t1 = time_tensor_numpy(demo)
@@ -3,14 +3,14 @@ from dataclasses import dataclass, replace
 from collections import defaultdict
 from typing import Any, Generic, TypeVar, Iterator
 import importlib, inspect, functools, pathlib, os, platform, contextlib, sys, re, atexit, pickle, decimal
-from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, \
+from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, PROFILE, temp, colored, CPU_LLVM, \
   Context, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, cpu_events, ProfileEvent, ProfilePointEvent, dedup
 from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
 from tinygrad.renderer import Renderer

 # **************** Device ****************

-ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"]
+ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "DSP", "WEBGPU"]
 class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@@ -303,7 +303,7 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
   if dtype == dtypes.bfloat16:
     if device == "METAL": return not CI
     if device in {"CUDA", "NV"}: return not CI and not getenv("PTX")
-    if device in {"CPU", "LLVM"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
+    if device in {"CPU"}: return not CI and platform.machine() in {"arm", "arm64", "aarch64", "x86_64", "amd64"}
     return device in {"AMD", "PYTHON"}
   if dtype in dtypes.fp8s:
     # not supported yet - in progress
@@ -317,7 +317,7 @@
   if dtype == dtypes.half:
     if device == "GPU": return not CI and not OSX
     if device in ["CUDA", "NV"]: return not CI
-    if device == "LLVM": return OSX
+    if device == "CPU" and CPU_LLVM: return OSX
     if device == "PYTHON": return sys.version_info >= (3, 12)
   if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
   return True
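Given the is_dtype_supported changes above, half and bfloat16 support on CPU now hinge on the CPU_LLVM flag instead of a separate LLVM device. A small sketch of querying it (a hedged example; the actual result depends on the OS, CI, and the CPU_LLVM setting):

from tinygrad import dtypes
from tinygrad.device import is_dtype_supported

# Per the hunks above: with CPU_LLVM=1 the LLVM path reports fp16 only on OSX, and
# bfloat16 on CPU needs a non-CI run on an arm/arm64/aarch64/x86_64/amd64 machine.
print(is_dtype_supported(dtypes.half, "CPU"))
print(is_dtype_supported(dtypes.bfloat16, "CPU"))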
@@ -44,7 +44,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
     match ji.prg:
       case CompiledRunner(): ji_graph_dev = ji.prg.dev
       case BufferXfer(): ji_graph_dev = Device[unwrap(ji.bufs[0]).device]
-      case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device not in {"CPU", "LLVM"}), None)
+      case BufferCopy(): ji_graph_dev = next((Device[unwrap(b).device] for b in ji.bufs if unwrap(b).device != "CPU"), None)
       case ViewOp(): continue # ViewOps are just ignored
       case _: ji_graph_dev = None # Everything else is not graphed and flushes existing graph if it's being constructed
@@ -139,10 +139,11 @@ DISABLE_COMPILER_CACHE, BLOCK_REORDER = ContextVar("DISABLE_COMPILER_CACHE", 0),
 DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
 QUANTIZE, VALIDATE_WITH_CPU, DISABLE_FAST_IDIV = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("DISABLE_FAST_IDIV", 0)
 CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
-ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE, AMD_LLVM = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0), ContextVar("AMD_LLVM", 1)
+ALLOW_DEVICE_USAGE, MAX_BUFFER_SIZE = ContextVar("ALLOW_DEVICE_USAGE", 1), ContextVar("MAX_BUFFER_SIZE", 0)
 RANGEIFY, FUSE_ATTENTION = ContextVar("RANGEIFY", 0), ContextVar("FUSE_ATTENTION", 0)
 EMULATE = ContextVar("EMULATE", "")
 CPU_COUNT = ContextVar("CPU_COUNT", max(1, (os.cpu_count() or 1) // (4 if ARCH_X86 else 2))) # take 1/2 of the cores, accounting HT
+CPU_LLVM, AMD_LLVM = ContextVar("CPU_LLVM", 0), ContextVar("AMD_LLVM", 1)

 @dataclass(frozen=True)
 class Metadata:
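CPU_LLVM is a regular ContextVar, so like the flags above it is read from the environment and can be overridden in a scoped way. A short sketch; note that CPUDevice reads the flag once at device creation (see the CPUDevice hunk below), so for backend selection the environment variable has to be set before the device is first opened:

from tinygrad.helpers import CPU_LLVM, Context

print(bool(CPU_LLVM))        # 0 unless CPU_LLVM=1 was set in the environment
with Context(CPU_LLVM=1):    # scoped override, same mechanism as DEBUG, BEAM, etc.
  print(bool(CPU_LLVM))      # 1 inside the context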
@@ -117,7 +117,7 @@ base_rewrite = PatternMatcher([
 ])

 class LLVMRenderer(Renderer):
-  device = "LLVM"
+  device = "CPU"
   abi = 'win64cc' if sys.platform == 'win32' else None
   supports_float4 = True
   has_local = False
@@ -173,7 +173,7 @@ class LLVMRenderer(Renderer):
     elif u.op in (Ops.DEFINE_LOCAL, Ops.DEFINE_REG):
       r[u] = f"%{'local' if u.op is Ops.DEFINE_LOCAL else 'reg'}_{str(u.arg).replace('(', '').replace(')', '').replace(',', '_').replace(' ', '')}"
       assert isinstance(u.dtype, PtrDType)
-      if self.device == "LLVM" or u.op is Ops.DEFINE_REG:
+      if self.device == "CPU" or u.op is Ops.DEFINE_REG:
         kernel.append(f" {r[u]} = alloca [{u.dtype.size} x {ldt(u.dtype.base)}]")
       else:
         local_args.append(f"@{r[u][1:]} = internal unnamed_addr addrspace(3) global [{u.dtype.size} x {ldt(u.dtype)}] undef, align 16")
@@ -1,32 +1,17 @@
 from __future__ import annotations
-import platform, subprocess, sys, ctypes, functools, time, mmap, threading, queue
-from tinygrad.helpers import capstone_flatdump, getenv, from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, suppress_finalizing
-from tinygrad.device import Compiler, BufferSpec, DMACPURef
+import platform, sys, ctypes, functools, time, mmap, threading, queue
+from tinygrad.helpers import from_mv, to_mv, OSX, WIN, mv_address, wait_cond, cpu_profile, CPU_LLVM, suppress_finalizing
+from tinygrad.device import BufferSpec, DMACPURef
 from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocatorBase, HCQBuffer, HWQueue, HCQArgsState, HCQSignal, HCQProgram, MMIOInterface
-from tinygrad.runtime.support.elf import jit_loader
 from tinygrad.renderer.cstyle import ClangRenderer
+from tinygrad.renderer.llvmir import LLVMRenderer
+from tinygrad.runtime.support.compiler_cpu import HostLLVMCompiler, ClangJITCompiler
 from tinygrad.uop.ops import sint

 class CPUSignal(HCQSignal):
   def _sleep(self, time_spent_waiting_ms:int):
     if self.is_timeline and self.owner is not None: self.owner.tasks.join()

-class ClangJITCompiler(Compiler):
-  def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
-
-  def compile(self, src:str) -> bytes:
-    # -fno-math-errno is required for __builtin_sqrt to become an instruction instead of a function call
-    # x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm, don't use it
-    target = 'x86_64' if sys.platform == 'win32' else platform.machine()
-    # on arm march means "runs on this arch and superset" instead of "optimize for this arch". x86 march == arm mcpu
-    arch = '-march=native' if platform.machine() in ('x86_64', 'AMD64') else '-mcpu=native'
-    args = [arch, f'--target={target}-none-unknown-elf', '-O2', '-fPIC', '-ffreestanding', '-fno-math-errno', '-nostdlib', '-fno-ident']
-    arch_args = ['-ffixed-x18'] if target == 'arm64' else []
-    obj = subprocess.check_output([getenv("CC", 'clang'), '-c', '-x', 'c', *args, *arch_args, '-', '-o', '-'], input=src.encode('utf-8'))
-    return jit_loader(obj)
-
-  def disassemble(self, lib:bytes): return capstone_flatdump(lib)
-
 class CPUWorker(threading.Thread):
   def __init__(self, dev, tasks, thread_id):
     super().__init__()
@@ -131,4 +116,5 @@ class CPUDevice(HCQCompiled):
   def __init__(self, device:str=""):
     self.tasks:queue.Queue = queue.Queue()
     CPUWorker(self, self.tasks, thread_id=0).start()
-    super().__init__(device, CPUAllocator(self), ClangRenderer(), ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
+    super().__init__(device, CPUAllocator(self), LLVMRenderer() if CPU_LLVM else ClangRenderer(),
+                     HostLLVMCompiler() if CPU_LLVM else ClangJITCompiler(), functools.partial(CPUProgram, self), CPUSignal, CPUComputeQueue)
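A quick way to confirm which path the unified CPUDevice picked, sketched under the assumption that the device initialises on the host and exposes its renderer and compiler attributes like other Compiled devices do:

from tinygrad import Device

dev = Device["CPU"]
# Expect LLVMRenderer / HostLLVMCompiler when CPU_LLVM=1,
# ClangRenderer / ClangJITCompiler otherwise (per the hunk above).
print(type(dev.renderer).__name__, type(dev.compiler).__name__)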
@@ -9,7 +9,7 @@ try:
   assert comgr.AMD_COMGR_LANGUAGE_HIP == 3
 except AttributeError: pass # ignore if ROCm isn't installed
 from tinygrad.device import Compiler, CompileError
-from tinygrad.runtime.ops_llvm import LLVMCompiler
+from tinygrad.runtime.support.compiler_cpu import LLVMCompiler
 from tinygrad.helpers import OSX, to_char_p_p

 def amdgpu_disassemble(lib:bytes):
@@ -1,11 +1,25 @@
-import ctypes, platform, functools, queue
+import ctypes, platform, sys, subprocess
 from tinygrad.device import Compiler
-from tinygrad.runtime.support.hcq import HCQCompiled, HCQSignal
-from tinygrad.runtime.ops_cpu import CPUAllocator, CPUProgram, CPUComputeQueue, CPUWorker
 from tinygrad.helpers import OSX, getenv, capstone_flatdump, DEBUG
 from tinygrad.renderer.llvmir import LLVMRenderer
-import tinygrad.runtime.autogen.llvm as llvm
 from tinygrad.runtime.support.elf import jit_loader
+try: import tinygrad.runtime.autogen.llvm as llvm
+except (ImportError, FileNotFoundError): llvm = None #type:ignore[assignment]
+
+class ClangJITCompiler(Compiler):
+  def __init__(self, cachekey="compile_clang_jit"): super().__init__(cachekey)
+
+  def compile(self, src:str) -> bytes:
+    # -fno-math-errno is required for __builtin_sqrt to become an instruction instead of a function call
+    # x18 is a reserved platform register. It is clobbered on context switch in macos and is used to store TEB pointer in windows on arm, don't use it
+    target = 'x86_64' if sys.platform == 'win32' else platform.machine()
+    # on arm march means "runs on this arch and superset" instead of "optimize for this arch". x86 march == arm mcpu
+    arch = '-march=native' if platform.machine() in ('x86_64', 'AMD64') else '-mcpu=native'
+    args = [arch, f'--target={target}-none-unknown-elf', '-O2', '-fPIC', '-ffreestanding', '-fno-math-errno', '-nostdlib', '-fno-ident']
+    arch_args = ['-ffixed-x18'] if target == 'arm64' else []
+    obj = subprocess.check_output([getenv("CC", 'clang'), '-c', '-x', 'c', *args, *arch_args, '-', '-o', '-'], input=src.encode('utf-8'))
+    return jit_loader(obj)
+
+  def disassemble(self, lib:bytes): return capstone_flatdump(lib)
+
 def cerr(): return ctypes.pointer(ctypes.pointer(ctypes.c_char()))
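The ClangJITCompiler moved here is the same Clang JIT path as before, just relocated out of ops_cpu.py. A hedged usage sketch (assumes clang is on PATH; the C source is only an illustration, not part of the diff):

from tinygrad.runtime.support.compiler_cpu import ClangJITCompiler

src = "void add(float* restrict out, float* restrict a, float* restrict b) { out[0] = a[0] + b[0]; }"
lib = ClangJITCompiler().compile(src)   # returns jit_loader()-processed machine code bytes
print(len(lib))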
@@ -70,9 +84,3 @@ class HostLLVMCompiler(LLVMCompiler):
     # +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
     cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
     super().__init__(cpu.decode(), feats.decode())
-
-class LLVMDevice(HCQCompiled):
-  def __init__(self, device:str=""):
-    self.tasks:queue.Queue = queue.Queue()
-    CPUWorker(self, self.tasks, thread_id=0).start()
-    super().__init__(device, CPUAllocator(self), LLVMRenderer(), HostLLVMCompiler(), functools.partial(CPUProgram, self), HCQSignal, CPUComputeQueue)
@@ -450,7 +450,7 @@ class HCQCompiled(Compiled, Generic[SignalType]):
       raise RuntimeError(f"{errs}\nNo interface for {type(self).__name__[:-6]}:{self.device_id} is available:{err_short}\n" \
                          f"\nForce an interface with {type(self).__name__[:-6].upper()}_IFACE={('|'.join(x.__name__[:-5] for x in ifaces))}.")

-  def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] in ("CPU", "LLVM")
+  def _is_cpu(self) -> bool: return hasattr(self, 'device') and self.device.split(":")[0] == "CPU"

   def finalize(self):
     try: self.synchronize() # Try to finalize device in any case.
@@ -235,7 +235,7 @@ def get_disassembly(ctx:list[str]):
     lib = (compiler:=Device[prg.device].compiler).compile(prg.src)
     with redirect_stdout(buf:=io.StringIO()): compiler.disassemble(lib)
     disasm_str = buf.getvalue()
-    from tinygrad.runtime.ops_llvm import llvm, LLVMCompiler
+    from tinygrad.runtime.support.compiler_cpu import llvm, LLVMCompiler
     if isinstance(compiler, LLVMCompiler):
       mtriple = ctypes.string_at(llvm.LLVMGetTargetMachineTriple(tm:=compiler.target_machine)).decode()
       mcpu = ctypes.string_at(llvm.LLVMGetTargetMachineCPU(tm)).decode()