mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
CLANG -> CPU (#9189)
This commit is contained in:
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
- name: Test AMX tensor cores
|
||||
run: |
|
||||
DEBUG=2 CLANG=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
DEBUG=2 CPU=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
DEBUG=2 LLVM=1 AMX=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
- name: Run Tensor Core GEMM (float)
|
||||
run: DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||
|
||||
32
.github/workflows/test.yml
vendored
32
.github/workflows/test.yml
vendored
@@ -81,7 +81,7 @@ jobs:
|
||||
run: DEBUG=100 python3 -c "from tinygrad import Tensor; N = 1024; a, b = Tensor.rand(N, N), Tensor.rand(N, N); c = (a.reshape(N, 1, N) * b.T.reshape(1, N, N)).sum(axis=2); print((c.numpy() - (a.numpy() @ b.numpy())).mean())"
|
||||
- name: Compile EfficientNet to C and test it
|
||||
run: |
|
||||
CLANG=1 PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
|
||||
CPU=1 PYTHONPATH="." python examples/compile_efficientnet.py > recognize.c
|
||||
clang -O2 recognize.c -lm -o recognize
|
||||
cat test/models/efficientnet/Chicken.jpg | ./recognize | grep cock
|
||||
|
||||
@@ -355,13 +355,13 @@ jobs:
|
||||
llvm: 'true'
|
||||
- name: Test ONNX (GPU)
|
||||
run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Test ONNX (CLANG)
|
||||
run: CLANG=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Test ONNX (CPU)
|
||||
run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Test ONNX (LLVM)
|
||||
run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
|
||||
- name: Run CLOUD=1 Test
|
||||
run: |
|
||||
CLOUDDEV=CLANG CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=GPU CLOUD=1 python3 test/test_tiny.py
|
||||
CLOUDDEV=GPU IMAGE=2 CLOUD=1 python3 test/test_tiny.py
|
||||
- name: Test Optimization Helpers
|
||||
@@ -378,7 +378,7 @@ jobs:
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
testmodels:
|
||||
name: Models (llvm+clang+gpu)
|
||||
name: Models (llvm+cpu+gpu)
|
||||
runs-on: ubuntu-22.04
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
@@ -395,8 +395,8 @@ jobs:
|
||||
run: LLVM=1 python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (gpu)
|
||||
run: GPU=1 python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (clang)
|
||||
run: CLANG=1 python -m pytest -n=auto test/models --durations=20
|
||||
- name: Test models (cpu)
|
||||
run: CPU=1 python -m pytest -n=auto test/models --durations=20
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
@@ -431,8 +431,8 @@ jobs:
|
||||
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
|
||||
- name: Test LLVM=1 DEVECTORIZE=0
|
||||
run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
|
||||
- name: Test CLANG=1 DEVECTORIZE=0
|
||||
run: CLANG=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
|
||||
- name: Test CPU=1 DEVECTORIZE=0
|
||||
run: CPU=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
|
||||
|
||||
testwebgpu:
|
||||
name: Linux (WebGPU)
|
||||
@@ -464,7 +464,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, clang, gpu, ptx, amd, nv] #, triton]
|
||||
backend: [llvm, cpu, gpu, ptx, amd, nv] #, triton]
|
||||
|
||||
name: Linux (${{ matrix.backend }})
|
||||
runs-on: ubuntu-22.04
|
||||
@@ -482,10 +482,10 @@ jobs:
|
||||
amd: ${{ matrix.backend == 'amd' && 'true' }}
|
||||
cuda: ${{ (matrix.backend == 'ptx' || matrix.backend == 'triton' || matrix.backend == 'nv') && 'true' }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nPTX=1\nMOCKGPU=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nNV=1\nMOCKGPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas' || matrix.backend == 'amd' && 'AMD=1\nMOCKGPU=1\nFORWARD_ONLY=1' || matrix.backend == 'nv' && 'NV=1\nMOCKGPU=1\nFORWARD_ONLY=1' }}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU','AMD','NV'], Device.DEFAULT"
|
||||
PYTHONPATH=${{ github.workspace }} python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CPU','CUDA','GPU','AMD','NV'], Device.DEFAULT"
|
||||
DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
|
||||
- name: Run pytest (not cuda or amd)
|
||||
if: matrix.backend!='ptx' && matrix.backend!='triton' && matrix.backend != 'amd' && matrix.backend != 'nv'
|
||||
@@ -582,7 +582,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [metal, llvm, clang]
|
||||
backend: [metal, llvm, cpu]
|
||||
name: MacOS (${{ matrix.backend }})
|
||||
runs-on: macos-15
|
||||
timeout-minutes: 10
|
||||
@@ -596,7 +596,7 @@ jobs:
|
||||
deps: testing_minimal
|
||||
llvm: ${{ matrix.backend == 'llvm' && 'true' }}
|
||||
- name: Set env
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'metal' && 'METAL=1\nJIT=2'}}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'metal' && 'METAL=1\nJIT=2'}}" >> $GITHUB_ENV
|
||||
- name: Check Device.DEFAULT and print some source
|
||||
run: |
|
||||
python -c "from tinygrad import Device; assert Device.DEFAULT == '${{ matrix.backend }}'.upper(), Device.DEFAULT"
|
||||
@@ -612,7 +612,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, clang]
|
||||
backend: [llvm, cpu]
|
||||
|
||||
name: Windows (${{ matrix.backend }})
|
||||
runs-on: windows-latest
|
||||
@@ -627,7 +627,7 @@ jobs:
|
||||
deps: testing_unit
|
||||
- name: Set env
|
||||
shell: bash
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1'}}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1'}}" >> $GITHUB_ENV
|
||||
- name: Run unit tests
|
||||
if: matrix.backend=='llvm'
|
||||
run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_elf.py --ignore=test/unit/test_tar.py
|
||||
|
||||
@@ -81,7 +81,7 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers
|
||||
tinygrad already supports numerous accelerators, including:
|
||||
|
||||
- [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
|
||||
- [x] [CLANG (C Code)](tinygrad/runtime/ops_clang.py)
|
||||
- [x] [CPU (C Code)](tinygrad/runtime/ops_cpu.py)
|
||||
- [x] [LLVM](tinygrad/runtime/ops_llvm.py)
|
||||
- [x] [METAL](tinygrad/runtime/ops_metal.py)
|
||||
- [x] [CUDA](tinygrad/runtime/ops_cuda.py)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
print("******** first, the runtime ***********")
|
||||
|
||||
from tinygrad.runtime.ops_clang import ClangJITCompiler, MallocAllocator, CPUProgram
|
||||
from tinygrad.runtime.ops_cpu import ClangJITCompiler, MallocAllocator, CPUProgram
|
||||
|
||||
# allocate some buffers
|
||||
out = MallocAllocator.alloc(4)
|
||||
@@ -34,7 +34,7 @@ assert val == 5
|
||||
|
||||
print("******** second, the Device ***********")
|
||||
|
||||
DEVICE = "CLANG" # NOTE: you can change this!
|
||||
DEVICE = "CPU" # NOTE: you can change this!
|
||||
|
||||
import struct
|
||||
from tinygrad.dtype import dtypes
|
||||
@@ -90,7 +90,7 @@ out = a.alu(Ops.ADD, b)
|
||||
|
||||
# schedule the computation as a list of kernels
|
||||
sched, _, becomes_map = create_schedule_with_vars(out.sink())
|
||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CPU
|
||||
# NOTE: UOps are no longer mutable, the scheduler gives you a map to lookup which BUFFER the result was written to
|
||||
out = becomes_map[out]
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ The `Allocator` class is responsible for managing memory on the device. There is
|
||||
|
||||
The `Program` class is created for each loaded program. It is responsible for executing the program on the device. As an example, here is a `CPUProgram` implementation which loads program and runs it.
|
||||
|
||||
::: tinygrad.runtime.ops_clang.CPUProgram
|
||||
::: tinygrad.runtime.ops_cpu.CPUProgram
|
||||
options:
|
||||
members: true
|
||||
|
||||
|
||||
@@ -31,13 +31,13 @@ These control the behavior of core tinygrad even when used as a library.
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
DEBUG | [1-6] | enable debugging output, with 4 you get operations, timings, speed, generated code and more
|
||||
GPU | [1] | enable the GPU backend
|
||||
GPU | [1] | enable the GPU (OpenCL) backend
|
||||
CUDA | [1] | enable CUDA backend
|
||||
AMD | [1] | enable AMD backend
|
||||
NV | [1] | enable NV backend
|
||||
METAL | [1] | enable Metal backend (for Mac M1 and after)
|
||||
METAL_XCODE | [1] | enable Metal using macOS Xcode SDK
|
||||
CLANG | [1] | enable Clang backend
|
||||
CPU | [1] | enable CPU (Clang) backend
|
||||
LLVM | [1] | enable LLVM backend
|
||||
BEAM | [#] | number of beams in kernel beam search
|
||||
DEFAULT_FLOAT | [HALF, ...]| specify the default float dtype (FLOAT32, HALF, BFLOAT16, FLOAT64, ...), default to FLOAT32
|
||||
|
||||
@@ -17,7 +17,7 @@ from tinygrad import Device
|
||||
print(Device.DEFAULT)
|
||||
```
|
||||
|
||||
You will see `CUDA` here on a GPU instance, or `CLANG` here on a CPU instance.
|
||||
You will see `CUDA` here on a GPU instance, or `CPU` here on a CPU instance.
|
||||
|
||||
## A simple model
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Runtimes
|
||||
|
||||
tinygrad supports various runtimes, enabling your code to scale across a wide range of devices. The default runtime can be automatically selected based on the available hardware, or you can force a specific runtime to be default using environment variables (e.g., `CLANG=1`).
|
||||
tinygrad supports various runtimes, enabling your code to scale across a wide range of devices. The default runtime can be automatically selected based on the available hardware, or you can force a specific runtime to be default using environment variables (e.g., `CPU=1`).
|
||||
|
||||
| Runtime | Description | Requirements |
|
||||
|---------|-------------|--------------|
|
||||
@@ -10,6 +10,6 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
|
||||
| [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | M1+ Macs; Metal 3.0+ for `bfloat` support |
|
||||
| [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | NVIDIA GPU with CUDA support |
|
||||
| [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device |
|
||||
| [CLANG (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_clang.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
|
||||
| [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
|
||||
| [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable |
|
||||
| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.1.6). |
|
||||
|
||||
@@ -15,9 +15,9 @@ if __name__ == "__main__":
|
||||
if getenv("WEBGPU"):
|
||||
safe_save(get_state_dict(model), (dirname / "net.safetensors").as_posix())
|
||||
load_state_dict(model, safe_load(str(dirname / "net.safetensors")))
|
||||
mode = "clang" if getenv("CLANG", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
|
||||
mode = "clang" if getenv("CPU", "") != "" else "webgpu" if getenv("WEBGPU", "") != "" else ""
|
||||
prg, inp_sizes, out_sizes, state = export_model(model, mode, Tensor.randn(1,3,224,224))
|
||||
if getenv("CLANG", "") == "":
|
||||
if getenv("CPU", "") == "":
|
||||
ext = "js" if getenv("WEBGPU", "") != "" else "json"
|
||||
with open(dirname / f"net.{ext}", "w") as text_file:
|
||||
text_file.write(prg)
|
||||
@@ -68,6 +68,6 @@ if __name__ == "__main__":
|
||||
else printf("%s\\n", lbls[best_idx]);
|
||||
}""")
|
||||
|
||||
# CLANG=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||
# CPU=1 python3 examples/compile_efficientnet.py | clang -O2 -lm -x c - -o recognize && DEBUG=1 time ./recognize docs/showcase/stable_diffusion_by_tinygrad.jpg
|
||||
# category : 281 (tabby, tabby cat) with 9.452788
|
||||
print('\n'.join(cprog))
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# An example to compile a small Tensorflow model to extremely portable C code
|
||||
|
||||
import os, sys
|
||||
os.environ["CLANG"] = '1'
|
||||
os.environ["CPU"] = '1'
|
||||
os.environ["JIT"] = '2'
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import os
|
||||
if "NOOPT" not in os.environ: os.environ["NOOPT"] = "1"
|
||||
from tinygrad import Device, nn, Tensor, dtypes, Variable
|
||||
Device.DEFAULT = "CLANG"
|
||||
Device.DEFAULT = "CPU"
|
||||
from train_gpt2 import GPT, GPTConfig
|
||||
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name
|
||||
from tinygrad.engine.realize import get_kernel, run_schedule
|
||||
@@ -43,9 +43,9 @@ if __name__ == "__main__":
|
||||
ast_dedup = dedup([si.ast for si in sched if si.ast.op is Ops.SINK])
|
||||
srcs = {}
|
||||
for ast in ast_dedup:
|
||||
k = get_kernel(Device["CLANG"].renderer, ast)
|
||||
k = get_kernel(Device["CPU"].renderer, ast)
|
||||
k.linearize()
|
||||
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
|
||||
src = Device["CPU"].renderer.render(to_function_name(k.name), k.uops)
|
||||
srcs[ast] = (k.name, src)
|
||||
print("functions:", len(srcs))
|
||||
used_buffers = dedup(flatten([si.bufs for si in sched]))
|
||||
|
||||
@@ -170,13 +170,13 @@ def batch_load_resnet(batch_size=64, val=False, shuffle=True, seed=None, pad_fir
|
||||
|
||||
def process_batch_bert(data: List[dict]) -> dict[str, Tensor]:
|
||||
return {
|
||||
"input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CLANG"),
|
||||
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CLANG"),
|
||||
"input_ids": Tensor(np.concatenate([s["input_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
"input_mask": Tensor(np.concatenate([s["input_mask"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
"segment_ids": Tensor(np.concatenate([s["segment_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_positions": Tensor(np.concatenate([s["masked_lm_positions"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_ids": Tensor(np.concatenate([s["masked_lm_ids"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_weights": Tensor(np.concatenate([s["masked_lm_weights"] for s in data], axis=0), dtype=dtypes.float32, device="CPU"),
|
||||
"next_sentence_labels": Tensor(np.concatenate([s["next_sentence_labels"] for s in data], axis=0), dtype=dtypes.int32, device="CPU"),
|
||||
}
|
||||
|
||||
def load_file(file: str):
|
||||
|
||||
@@ -222,11 +222,11 @@ def get_mlperf_bert_model():
|
||||
|
||||
def get_fake_data_bert(BS:int):
|
||||
return {
|
||||
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
|
||||
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
|
||||
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CLANG"),
|
||||
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CLANG"),
|
||||
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CLANG"),
|
||||
"input_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"input_mask": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"segment_ids": Tensor.empty((BS, 512), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.int32, device="CPU"),
|
||||
"masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32, device="CPU"),
|
||||
"next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.int32, device="CPU"),
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad.engine.jit import GraphRunner, GraphException
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.engine.realize import ExecItem, CompiledRunner
|
||||
from tinygrad.ops import Variable
|
||||
from tinygrad.runtime.ops_clang import ClangProgram
|
||||
from tinygrad.runtime.ops_cpu import ClangProgram
|
||||
from tinygrad.renderer.cstyle import ClangRenderer
|
||||
render_dtype = ClangRenderer().render_dtype
|
||||
|
||||
@@ -30,7 +30,7 @@ class ClangGraph(GraphRunner):
|
||||
code.append(f" {cast(CompiledRunner, ji.prg).p.function_name}({','.join(args)});")
|
||||
code.append("}")
|
||||
if DEBUG >= 4: print("\n".join(code))
|
||||
compiler = Device["CLANG"].compiler
|
||||
compiler = Device["CPU"].compiler
|
||||
assert compiler is not None
|
||||
self._prg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from tinygrad.helpers import Context
|
||||
from tinygrad.dtype import dtypes
|
||||
import json
|
||||
|
||||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CLANG", "CUDA", "GPU"]
|
||||
EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
|
||||
@@ -191,7 +191,7 @@ export default {exported_name};
|
||||
"""
|
||||
|
||||
def export_model(model, target:str, *inputs, model_name: Optional[str] = None):
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CLANG, CUDA, GPU, METAL are supported"
|
||||
assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, "only WEBGPU, CPU, CUDA, GPU, METAL are supported"
|
||||
with Context(JIT=2): run,special_names = jit_model(model, *inputs)
|
||||
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
|
||||
state = get_state_dict(model)
|
||||
|
||||
@@ -32,8 +32,8 @@ import os
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
# define the compute
|
||||
A = Tensor.rand(M, K, device="clang")
|
||||
B = Tensor.rand(K, N, device="clang")
|
||||
A = Tensor.rand(M, K, device="CPU")
|
||||
B = Tensor.rand(K, N, device="CPU")
|
||||
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||
|
||||
sched = C.schedule()
|
||||
@@ -42,6 +42,6 @@ from tinygrad.device import CompilerOptions
|
||||
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
|
||||
#lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
from tinygrad.runtime.ops_clang import renderer
|
||||
from tinygrad.runtime.ops_cpu import renderer
|
||||
src = renderer("mmult", lin.uops)
|
||||
print(src)
|
||||
|
||||
2
test/external/external_model_benchmark.py
vendored
2
test/external/external_model_benchmark.py
vendored
@@ -138,7 +138,7 @@ def assert_allclose(tiny_out:dict, onnx_out:dict, rtol=1e-5, atol=1e-5):
|
||||
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tiny_out.keys()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CLANG"]
|
||||
devices = [Device.DEFAULT] if getenv("NOCLANG") else [Device.DEFAULT, "CPU"]
|
||||
if getenv("MODEL", "") != "": benchmark_model(getenv("MODEL", ""), devices, True)
|
||||
else:
|
||||
for m in MODELS: benchmark_model(m, devices, True)
|
||||
|
||||
4
test/external/external_multi_gpu.py
vendored
4
test/external/external_multi_gpu.py
vendored
@@ -19,8 +19,8 @@ if __name__ == "__main__":
|
||||
with Timing("GPU initial sync: "): sync()
|
||||
|
||||
with Timing("CPU creation: ", on_exit=lambda x: f", {(sz*4*2)/x:.2f} GB/sec"):
|
||||
c0 = (Tensor.ones(sz, device="clang")/2).realize()
|
||||
c1 = (Tensor.ones(sz, device="clang")/4).realize()
|
||||
c0 = (Tensor.ones(sz, device="CPU")/2).realize()
|
||||
c1 = (Tensor.ones(sz, device="CPU")/4).realize()
|
||||
print(c0.lazydata.base.realized)
|
||||
print(c1.lazydata.base.realized)
|
||||
|
||||
|
||||
4
test/external/external_test_example.py
vendored
4
test/external/external_test_example.py
vendored
@@ -23,10 +23,10 @@ def multidevice_test(fxn):
|
||||
|
||||
class TestExample(unittest.TestCase):
|
||||
@multidevice_test
|
||||
def test_convert_to_clang(self, device):
|
||||
def test_convert_to_cpu(self, device):
|
||||
a = Tensor([[1,2],[3,4]], device=device)
|
||||
assert a.numpy().shape == (2,2)
|
||||
b = a.to("CLANG")
|
||||
b = a.to("CPU")
|
||||
assert b.numpy().shape == (2,2)
|
||||
|
||||
@multidevice_test
|
||||
|
||||
@@ -181,7 +181,7 @@ class TestIndexing(unittest.TestCase):
|
||||
# self.assertRaises(TypeError, delitem)
|
||||
|
||||
# TODO: LLVM is quite fast, why are other compiled backends slow?
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "GPU", "METAL", "NV", "AMD"], "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CPU", "GPU", "METAL", "NV", "AMD"], "slow")
|
||||
def test_advancedindex(self):
|
||||
# integer array indexing
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class TinyConvNet:
|
||||
x = x.reshape(shape=[x.shape[0], -1])
|
||||
return x.dot(self.l1)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow")
|
||||
class TestMNIST(unittest.TestCase):
|
||||
def test_sgd_onestep(self):
|
||||
np.random.seed(1337)
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
dtypes.default_float = self.old_float
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow, covered by METAL")
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow, covered by METAL")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need dtypes.float16")
|
||||
def test_stable_diffusion(self):
|
||||
params = unet_params
|
||||
@@ -95,7 +95,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
with Context(JIT=0): return model(t, v).realize()
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 137 if CI else 396, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CLANG", "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow")
|
||||
def test_train_mnist(self):
|
||||
from examples.beautiful_mnist import Model
|
||||
with Tensor.train():
|
||||
@@ -113,7 +113,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CLANG", "GPU", "LLVM"}, "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU", "LLVM"}, "slow")
|
||||
def test_train_cifar(self):
|
||||
with Tensor.train():
|
||||
model = SpeedyResNet(Tensor.ones((12,3,2,2)))
|
||||
|
||||
@@ -16,7 +16,7 @@ TRANSCRIPTION_2 = "a slightly longer audio file so that we can test batch transc
|
||||
TEST_FILE_3_URL = 'https://homepage.ntu.edu.tw/~karchung/miniconversations/mc45.mp3'
|
||||
TRANSCRIPTION_3 = "Just lie back and relax. Is the level of pressure about right? Yes, it's fine, and I'd like conditioner please. Sure. I'm going to start the second lathering now. Would you like some Q-tips? How'd you like it cut? I'd like my bangs and the back trimmed, and I'd like the rest thinned out a bit and layered. Where would you like the part? On the left, right about here. Here, have a look. What do you think? It's fine. Here's a thousand anti-dollars. It's 30-ant extra for the rants. Here's your change and receipt. Thank you, and please come again. So how do you like it? It could have been worse, but you'll notice that I didn't ask her for her card. Hmm, yeah. Maybe you can try that place over there next time." # noqa: E501
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CLANG"], "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["CPU"], "slow")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16 support")
|
||||
class TestWhisper(unittest.TestCase):
|
||||
@classmethod
|
||||
|
||||
@@ -24,7 +24,7 @@ class TestCopySpeed(unittest.TestCase):
|
||||
s.unlink()
|
||||
|
||||
def testCopyCPUtoDefault(self):
|
||||
t = Tensor.ones(N, N, device="clang").contiguous().realize()
|
||||
t = Tensor.ones(N, N, device="CPU").contiguous().realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
@@ -35,7 +35,7 @@ class TestCopySpeed(unittest.TestCase):
|
||||
def testCopyCPUtoDefaultFresh(self):
|
||||
print("fresh copy")
|
||||
for _ in range(3):
|
||||
t = Tensor.ones(N, N, device="clang").contiguous().realize()
|
||||
t = Tensor.ones(N, N, device="CPU").contiguous().realize()
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"): # noqa: F821
|
||||
with Timing("queue: "):
|
||||
t.to(Device.DEFAULT).realize()
|
||||
@@ -47,14 +47,14 @@ class TestCopySpeed(unittest.TestCase):
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s"):
|
||||
t.to('clang').realize()
|
||||
t.to('CPU').realize()
|
||||
|
||||
@unittest.skipIf(CI, "CI doesn't have 6 GPUs")
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU")
|
||||
def testCopyCPUto6GPUs(self):
|
||||
from tinygrad.runtime.ops_gpu import CLDevice
|
||||
if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
|
||||
t = Tensor.ones(N, N, device="clang").contiguous().realize()
|
||||
t = Tensor.ones(N, N, device="CPU").contiguous().realize()
|
||||
print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
|
||||
for _ in range(3):
|
||||
with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"):
|
||||
|
||||
@@ -38,7 +38,7 @@ def apply(tor, ten, tor_fn, ten_fn=None):
|
||||
except: ten, ok = None, not ok # noqa: E722
|
||||
return tor, ten, ok
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("CLANG", "NV"), "slow")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ("CPU", "NV"), "slow")
|
||||
class TestShapeOps(unittest.TestCase):
|
||||
@settings.get_profile(__file__)
|
||||
@given(st_shape(), st_int32, st.one_of(st_int32, st.lists(st_int32)))
|
||||
|
||||
@@ -22,7 +22,7 @@ def _simple_test(add, extract=lambda x: x, N=10):
|
||||
class TestJit(unittest.TestCase):
|
||||
|
||||
@settings(deadline=2e4)
|
||||
@unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(Device.DEFAULT in ["LLVM", "CPU"], f"no support on {Device.DEFAULT}")
|
||||
@given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
|
||||
def test_approx_jit_timeout(self, op):
|
||||
with Context(TRANSCENDENTAL=2):
|
||||
@@ -497,8 +497,8 @@ class TestCopyInsideJit(unittest.TestCase):
|
||||
@TinyJit
|
||||
def add(x,y) -> Tensor: return x.to(Device.DEFAULT)+y
|
||||
for _ in range(5):
|
||||
# create a Tensor in CLANG
|
||||
a = Tensor.rand(16,16,device="CLANG").realize()
|
||||
# create a Tensor on CPU
|
||||
a = Tensor.rand(16,16,device="CPU").realize()
|
||||
b = Tensor.rand(16,16).realize()
|
||||
out = add(a,b)
|
||||
np.testing.assert_allclose(out.flatten().tolist(), [x+y for x,y in zip(a.flatten().tolist(), b.flatten().tolist())])
|
||||
@@ -529,12 +529,12 @@ class TestJitPrune(unittest.TestCase):
|
||||
w2_prune = TinyJit(w2, prune=True)
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16, device="CLANG").realize()
|
||||
a = Tensor.rand(16, device="CPU").realize()
|
||||
out = w2_noprune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
for _ in range(3):
|
||||
a = Tensor.rand(16, device="CLANG").realize()
|
||||
a = Tensor.rand(16, device="CPU").realize()
|
||||
out = w2_prune(a)
|
||||
np.testing.assert_allclose(out.tolist(), [x*2+y for x,y in zip(weights.tolist(), a.tolist())])
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from tinygrad import Device
|
||||
|
||||
class TestKernelCache(unittest.TestCase):
|
||||
def test_kernel_cache_in_action(self):
|
||||
if Device.DEFAULT not in ["CLANG"]:
|
||||
if Device.DEFAULT not in ["CPU"]:
|
||||
self.skipTest("No custom kernel cache is implemented")
|
||||
|
||||
unique_const = 0.6765677269
|
||||
@@ -16,14 +16,14 @@ class TestKernelCache(unittest.TestCase):
|
||||
|
||||
a1 = Tensor.rand(4,4).realize()
|
||||
b1 = Tensor.rand(4,4).realize()
|
||||
orig_compile_func = Device['CLANG'].compiler
|
||||
Device['CLANG'].compiler = None # making it not callable
|
||||
orig_compile_func = Device['CPU'].compiler
|
||||
Device['CPU'].compiler = None # making it not callable
|
||||
|
||||
try:
|
||||
x1 = a1 + b1 + unique_const
|
||||
x1.realize() # Same kernel should be from cache.
|
||||
finally:
|
||||
Device['CLANG'].compiler = orig_compile_func
|
||||
Device['CPU'].compiler = orig_compile_func
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -64,7 +64,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
|
||||
class TestLinearizer(unittest.TestCase):
|
||||
def test_arg_dedup(self):
|
||||
# NOTE: this realize exists because Tensor.numpy calls .contiguous() internally
|
||||
# without contiguous folding, rand.to("CLANG") and rand.contiguous().to("CLANG") are different UOps.
|
||||
# without contiguous folding, rand.to("CPU") and rand.contiguous().to("CPU") are different UOps.
|
||||
# this test asserts they are the identical Buffer
|
||||
# having different buffers is fine for correctness, because the outputs match.
|
||||
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
|
||||
@@ -983,8 +983,8 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
# NOTE: can reenable, it does work. it just makes BEAM slow
|
||||
@unittest.expectedFailure
|
||||
@unittest.skipUnless(Device.DEFAULT == "CLANG", "test only for CLANG")
|
||||
def test_upcast_with_locals_clang(self):
|
||||
@unittest.skipUnless(Device.DEFAULT == "CPU", "test only for CPU")
|
||||
def test_upcast_with_locals_cpu(self):
|
||||
out = Tensor.ones(64,64).contiguous() @ Tensor.ones(64,64).contiguous()
|
||||
k = Kernel(out.schedule()[-1].ast)
|
||||
k.apply_opt(Opt(OptOps.LOCAL, axis=0, arg=4))
|
||||
@@ -1136,7 +1136,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert u.src[-1].src[0].op != Ops.ASSIGN
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
|
||||
def test_tensor_cores_unroll_casted_phi(self):
|
||||
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
|
||||
x, y = Tensor.rand(128, 128, dtype=tc.dtype_in), Tensor.rand(128, 128, dtype=tc.dtype_in)
|
||||
@@ -1148,7 +1148,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert u.src[-1].src[0].op != Ops.ASSIGN
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"}, "CLANG does not support using a different type for accumulation")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"}, "CPU does not support using a different type for accumulation")
|
||||
def test_tensor_cores_unroll_casted_phi_with_children(self):
|
||||
# all ASSIGN children are outside the loop
|
||||
tc = [tc for tc in Device[Device.DEFAULT].renderer.tensor_cores if tc.dtype_in != tc.dtype_out][0]
|
||||
@@ -1445,7 +1445,7 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
assert TestFloat4.count_float4(k) == (2, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim(self):
|
||||
a = Tensor.rand(2, 8).realize()
|
||||
b = Tensor.rand(2, 8).realize()
|
||||
@@ -1462,7 +1462,7 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
assert TestFloat4.count_float4(k) == (4, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.rand(2, size).realize()
|
||||
@@ -1487,7 +1487,7 @@ class TestFloat4(unittest.TestCase):
|
||||
for i in range(len(sizes)):
|
||||
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_unaligned_load(self):
|
||||
a = Tensor.rand(9).realize().shrink(((1, 9),))
|
||||
b = Tensor.rand(9).realize().shrink(((1, 9),))
|
||||
@@ -1500,7 +1500,7 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
assert TestFloat4.count_float4(k) == (0, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load(self):
|
||||
a = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
@@ -1517,7 +1517,7 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
assert TestFloat4.count_float4(k) == (0, 2)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CLANG", "LLVM"} and AMX, "Only CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "Only CPU with AMX upcasts float up to size 16")
|
||||
def test_float4_multidim_unaligned_load_amx(self):
|
||||
def kernel_for_shape(size, shift):
|
||||
a = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
|
||||
@@ -498,7 +498,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
opts = [Opt(op=OptOps.PADTO, axis=0, arg=32)]
|
||||
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
|
||||
|
||||
#@unittest.skipIf(Device.DEFAULT in ("LLVM", "METAL", "CLANG"), "flaky")
|
||||
#@unittest.skipIf(Device.DEFAULT in ("LLVM", "METAL", "CPU"), "flaky")
|
||||
@unittest.skip("flaky everywhere")
|
||||
def test_failure_22(self):
|
||||
ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestPickle(unittest.TestCase):
|
||||
|
||||
def test_pickle_realized_tensor_alt(self):
|
||||
print("** init")
|
||||
t = Tensor.rand(10, 10).to("CLANG").realize()
|
||||
t = Tensor.rand(10, 10).to("CPU").realize()
|
||||
st = pickle.dumps(t)
|
||||
t_values = t.numpy()
|
||||
del t # free buffers
|
||||
@@ -50,7 +50,7 @@ class TestPickle(unittest.TestCase):
|
||||
|
||||
def test_pickle_realized_tensor_alt2(self):
|
||||
print("** init")
|
||||
t = Tensor.rand(10, 10).to("CLANG").realize()
|
||||
t = Tensor.rand(10, 10).to("CPU").realize()
|
||||
tensor_uop = t.lazydata
|
||||
assert tensor_uop.is_realized, f"expected {tensor_uop} to be realized"
|
||||
t_values = t.numpy()
|
||||
@@ -93,7 +93,7 @@ class TestPickle(unittest.TestCase):
|
||||
np.testing.assert_equal(vt2.numpy(), 20)
|
||||
|
||||
def test_pickle_buffer_view(self):
|
||||
t = Tensor.arange(10, device="CLANG").contiguous().realize()
|
||||
t = Tensor.arange(10, device="CPU").contiguous().realize()
|
||||
vt = t[3:5].contiguous().realize()
|
||||
assert hasattr(vt.lazydata.buffer, 'base')
|
||||
ref_value = vt.tolist()
|
||||
|
||||
@@ -239,7 +239,7 @@ class TestRandomness(unittest.TestCase):
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"),
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
|
||||
self.assertTrue(Tensor.randint(1, device="CPU").device=="CPU")
|
||||
# check types of args
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||
store = UOp.store(a.index(idx), alu)
|
||||
sink = UOp(Ops.SINK, dtypes.void, (store,))
|
||||
uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer))
|
||||
# CLANG doesn't use the max function
|
||||
# CPU doesn't use the max function
|
||||
ret = _test_uop_result([Tensor([1])], uops)[0]
|
||||
self.assertEqual(ret[0], 1)
|
||||
|
||||
|
||||
@@ -1732,16 +1732,16 @@ class TestIndexing(unittest.TestCase):
|
||||
self.assertIs(sched[1].ast.op, Ops.BUFFER_VIEW)
|
||||
np.testing.assert_equal(a.numpy(), [[4, 5]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from ext device")
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from ext device")
|
||||
def test_arange_shrink_copy(self):
|
||||
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG")
|
||||
a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CPU")
|
||||
sched = self.check_schedule(a, 1)
|
||||
self.assertIs(sched[-1].ast.op, Ops.COPY)
|
||||
np.testing.assert_equal(a.numpy(), [[4, 5]])
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from ext device")
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from ext device")
|
||||
def test_arange_expand_copy(self):
|
||||
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CLANG")
|
||||
a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).contiguous().to("CPU")
|
||||
sched = self.check_schedule(a, 1)
|
||||
self.assertIs(sched[1].ast.op, Ops.COPY)
|
||||
self.assertIs(sched[0].ast.src[0].src[2].op, Ops.ADD)
|
||||
@@ -2279,23 +2279,23 @@ class TestConst(unittest.TestCase):
|
||||
run_schedule(sched, var_vals)
|
||||
self.assertEqual(a.tolist(), 3)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", "tests copy from another device to clang")
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "tests copy from another device to cpu")
|
||||
class TestCopyFolding(unittest.TestCase):
|
||||
def test_const_copy_is_free(self):
|
||||
b = Tensor(1).to("CLANG")
|
||||
b = Tensor(1).to("CPU")
|
||||
check_schedule(b, 0, filter_sink=False)
|
||||
assert b.item() == 1
|
||||
|
||||
def test_late_const_copy_folding(self):
|
||||
a = Tensor.arange(3).realize()
|
||||
zeros = Tensor.zeros(3).realize()
|
||||
b = (a*zeros).to("CLANG")
|
||||
b = (a*zeros).to("CPU")
|
||||
run_schedule(check_schedule(b, 0, filter_sink=False))
|
||||
self.assertListEqual(b.tolist(), [0, 0, 0])
|
||||
|
||||
def test_alu_after_copy(self):
|
||||
a = Tensor.ones((4,)).to("CLANG").lazydata
|
||||
b = Tensor.empty(4, device="CLANG").lazydata
|
||||
a = Tensor.ones((4,)).to("CPU").lazydata
|
||||
b = Tensor.empty(4, device="CPU").lazydata
|
||||
add = a+b
|
||||
add = schedule_graph_rewrite(add)
|
||||
assert all_same([x.device for x in add.src]), f"ALU has different devices! {[x.device for x in add.src]}"
|
||||
@@ -2348,13 +2348,13 @@ class TestCopyFolding(unittest.TestCase):
|
||||
def test_permute_on_disk(self):
|
||||
with open(temp('dt_arange_4_permute'), "wb") as f: f.write(Tensor.arange(4).realize().lazydata.base.buffer.as_buffer())
|
||||
a = Tensor.empty(4, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_4_permute')}")
|
||||
b = a.reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b = a.reshape(2, 2).permute(1, 0).to("CPU")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
def test_permute_after_shrink(self):
|
||||
a = Tensor.arange(5)
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
@@ -2364,7 +2364,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
def test_permute_after_shrink_on_disk(self):
|
||||
with open(temp('dt_arange_5_permute'), "wb") as f: f.write(Tensor.arange(5).realize().lazydata.base.buffer.as_buffer())
|
||||
a = Tensor.empty(5, dtype=dtypes.int32, device=f"disk:{temp('dt_arange_5_permute')}")
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CLANG")
|
||||
b = a.shrink(((0, 4),)).reshape(2, 2).permute(1, 0).to("CPU")
|
||||
b.realize()
|
||||
self.assertListEqual(b.tolist(), [[0, 2], [1, 3]])
|
||||
|
||||
|
||||
@@ -247,7 +247,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
assert a.shape == b.shape, f"shape mismatch {a.shape} != {b.shape}"
|
||||
|
||||
def test_rand_like_device(self):
|
||||
a = Tensor.ones(3, 3, device="CLANG")
|
||||
a = Tensor.ones(3, 3, device="CPU")
|
||||
b = Tensor.rand_like(a)
|
||||
self.assertEqual(b.device, a.device)
|
||||
|
||||
@@ -326,7 +326,7 @@ class TestTinygrad(unittest.TestCase):
|
||||
def test_tensor_from_blob(self):
|
||||
x = memoryview(bytearray(16)).cast('I')
|
||||
|
||||
t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CLANG")
|
||||
t = Tensor.from_blob(mv_address(x), (4,), dtype=dtypes.int, device="CPU")
|
||||
z = (t+1)
|
||||
np.testing.assert_equal(z.numpy(), [1, 1, 1, 1])
|
||||
|
||||
@@ -695,7 +695,7 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
class TestTensorCreationDevice(unittest.TestCase):
|
||||
# test auxiliary tensors are created on the same device
|
||||
def test_one_hot(self):
|
||||
y = Tensor([1, 2, 3]).to("CLANG")
|
||||
y = Tensor([1, 2, 3]).to("CPU")
|
||||
x = y.one_hot(10)
|
||||
x.realize()
|
||||
|
||||
|
||||
@@ -661,7 +661,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
sink = float4_rewrite(sink.sink())
|
||||
assert len([x for x in sink.toposort if x.op is Ops.LOAD]) == 1
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT in {"CLANG"} and AMX, "CLANG with AMX upcasts float up to size 16")
|
||||
@unittest.skipIf(Device.DEFAULT in {"CPU"} and AMX, "CPU with AMX upcasts float up to size 16")
|
||||
def test_two_load_fold(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||
load = [UOp(Ops.LOAD, dtypes.float, (buf.index(UOp.const(dtypes.int, i)),)) for i in range(8)]
|
||||
|
||||
@@ -106,11 +106,11 @@ class TestUOps(unittest.TestCase):
|
||||
self._equal(f([a,b,c], op, dts), fxn(a,b,c))
|
||||
|
||||
class TestFloatUOps(TestUOps):
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop')
|
||||
def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a))
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop')
|
||||
def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop')
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", 'not supported as uop')
|
||||
def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a))
|
||||
def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf'))
|
||||
def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
|
||||
|
||||
@@ -65,12 +65,12 @@ class TestMemoryCount(unittest.TestCase):
|
||||
_, mem = get_stats(a.assign(a+a))
|
||||
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CLANG", "test copy to CLANG from other device")
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
|
||||
def test_copyout(self):
|
||||
a = Tensor.empty(32, dtype=dtypes.uint8).to("CLANG")
|
||||
a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")
|
||||
_, mem = get_stats(a)
|
||||
self.assertEqual(mem, 32*1)
|
||||
a = Tensor.empty(32, dtype=dtypes.uint32).to("CLANG")
|
||||
a = Tensor.empty(32, dtype=dtypes.uint32).to("CPU")
|
||||
_, mem = get_stats(a)
|
||||
self.assertEqual(mem, 32*4)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ def time_tensor_numpy(out:Tensor):
|
||||
|
||||
N = 4096
|
||||
class TestZeroCopy(unittest.TestCase):
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CLANG", "LLVM", "METAL"}, "device isn't zero copy")
|
||||
@unittest.skipIf(Device.DEFAULT not in {"CPU", "LLVM", "METAL"}, "device isn't zero copy")
|
||||
def test_zero_copy_from_default_to_cpu(self):
|
||||
demo = Tensor.rand(1).realize()
|
||||
t1 = time_tensor_numpy(demo)
|
||||
|
||||
@@ -7,7 +7,7 @@ import numpy as np
|
||||
|
||||
class TestF16Decompression(unittest.TestCase):
|
||||
def test_u32_to_f16(self):
|
||||
a = Tensor.randn(50, dtype=dtypes.float16, device=None if is_dtype_supported(dtypes.float16) else "CLANG:0")
|
||||
a = Tensor.randn(50, dtype=dtypes.float16, device=None if is_dtype_supported(dtypes.float16) else "CPU")
|
||||
f16_as_u32 = a.bitcast(dtypes.uint32) if is_dtype_supported(dtypes.float16) else a.bitcast(dtypes.uint32).to(Device.DEFAULT)
|
||||
f16 = u32_to_f16(f16_as_u32)
|
||||
ref = a.numpy()
|
||||
|
||||
@@ -211,7 +211,7 @@ class TestSafetensors(unittest.TestCase):
|
||||
def helper_test_disk_tensor(fn, data, np_fxn, tinygrad_fxn=None):
|
||||
if tinygrad_fxn is None: tinygrad_fxn = np_fxn
|
||||
pathlib.Path(temp(fn)).unlink(missing_ok=True)
|
||||
tinygrad_tensor = Tensor(data, device="CLANG").to(f"disk:{temp(fn)}")
|
||||
tinygrad_tensor = Tensor(data, device="CPU").to(f"disk:{temp(fn)}")
|
||||
numpy_arr = np.array(data)
|
||||
tinygrad_fxn(tinygrad_tensor)
|
||||
np_fxn(numpy_arr)
|
||||
@@ -251,7 +251,7 @@ class TestDiskTensor(unittest.TestCase):
|
||||
def test_write_ones(self):
|
||||
pathlib.Path(temp("dt_write_ones")).unlink(missing_ok=True)
|
||||
|
||||
out = Tensor.ones(10, 10, device="CLANG").contiguous()
|
||||
out = Tensor.ones(10, 10, device="CPU").contiguous()
|
||||
outdisk = out.to(f"disk:{temp('dt_write_ones')}")
|
||||
print(outdisk)
|
||||
outdisk.realize()
|
||||
@@ -289,13 +289,13 @@ class TestDiskTensor(unittest.TestCase):
|
||||
def test_bitcast(self):
|
||||
with open(temp('dt_bitcast'), "wb") as f: f.write(bytes(range(10,20)))
|
||||
t = Tensor.empty(5, dtype=dtypes.int16, device=f"disk:{temp('dt_bitcast')}")
|
||||
ret = t.to("CLANG").bitcast(dtypes.uint16) + 1
|
||||
ret = t.to("CPU").bitcast(dtypes.uint16) + 1
|
||||
assert ret.tolist() == [2827, 3341, 3855, 4369, 4883]
|
||||
|
||||
def test_bitcast_view(self):
|
||||
with open(temp('dt_bitcast_view'), "wb") as f: f.write(bytes(range(10, 24)))
|
||||
t = Tensor.empty(3, dtype=dtypes.uint, device=f"disk:{temp('dt_bitcast_view')}").shrink([(0, 2)])
|
||||
ret = t.bitcast(dtypes.uint16).to("CLANG") + 1
|
||||
ret = t.bitcast(dtypes.uint16).to("CPU") + 1
|
||||
assert ret.tolist() == [2827, 3341, 3855, 4369]
|
||||
|
||||
@unittest.skipIf(OSX, "new LLVM has an issue on OSX")
|
||||
@@ -363,10 +363,10 @@ class TestPathTensor(unittest.TestCase):
|
||||
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_with_device(self):
|
||||
t = Tensor(self.test_file, device="CLANG")
|
||||
t = Tensor(self.test_file, device="CPU")
|
||||
self.assertEqual(t.shape, (100,))
|
||||
self.assertEqual(t.dtype, dtypes.uint8)
|
||||
self.assertEqual(t.device, "CLANG")
|
||||
self.assertEqual(t.device, "CPU")
|
||||
np.testing.assert_array_equal(t.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
def test_path_tensor_empty_file(self):
|
||||
@@ -391,8 +391,8 @@ class TestPathTensor(unittest.TestCase):
|
||||
|
||||
def test_path_tensor_copy_to_device(self):
|
||||
t = Tensor(self.test_file)
|
||||
t_cpu = t.to("CLANG")
|
||||
self.assertEqual(t_cpu.device, "CLANG")
|
||||
t_cpu = t.to("CPU")
|
||||
self.assertEqual(t_cpu.device, "CPU")
|
||||
np.testing.assert_array_equal(t_cpu.numpy(), np.frombuffer(self.test_data, dtype=np.uint8))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import unittest, subprocess, platform
|
||||
from tinygrad.runtime.ops_clang import ClangJITCompiler
|
||||
from tinygrad.runtime.ops_cpu import ClangJITCompiler
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
|
||||
class TestElfLoader(unittest.TestCase):
|
||||
|
||||
@@ -91,7 +91,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||
|
||||
def test_const_view_always_valid(self):
|
||||
buf = UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CLANG"),), ShapeTracker.from_shape(())),))
|
||||
a = UOp.const(dtypes.int, 0).replace(src=(UOp(Ops.VIEW, dtypes.void, (UOp(Ops.DEVICE, arg="CPU"),), ShapeTracker.from_shape(())),))
|
||||
st = UOp.store(buf, ShapeTracker.from_shape(()).to_uop(), a.cast(dtypes.float))
|
||||
helper_test_verify_ast(st)
|
||||
|
||||
|
||||
@@ -374,7 +374,7 @@ class Kernel:
|
||||
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")
|
||||
|
||||
if opt.op is OptOps.LOCAL: # cyan
|
||||
# NOTE: LLVM/CLANG can use locals too, but they are treated the same as globals (still helpful for L1 cache)
|
||||
# NOTE: LLVM/CPU can use locals too, but they are treated the same as globals (still helpful for L1 cache)
|
||||
# it's disabled for now since it makes BEAM slow for little gain
|
||||
check(self.opts.has_local, "target does not support local")
|
||||
check(axis < self.global_dims, "local is for globals")
|
||||
|
||||
@@ -10,7 +10,7 @@ from tinygrad.renderer import Renderer
|
||||
|
||||
# **************** Device ****************
|
||||
|
||||
ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CLANG", "LLVM", "DSP", "WEBGPU"]
|
||||
ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "LLVM", "DSP", "WEBGPU"]
|
||||
class _Device:
|
||||
def __init__(self) -> None:
|
||||
self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
|
||||
|
||||
@@ -195,8 +195,8 @@ def torch_load(t:Tensor) -> dict[str, Tensor]:
|
||||
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
|
||||
intermediate_shape = tuple([shape_strides[x][0] for x in argsort(permute_indexes)])
|
||||
assert tuple([shape_strides[i][1] for i in argsort(permute_indexes)]) == strides_for_shape(intermediate_shape), "nonpermutable strides"
|
||||
if DEBUG >= 3: print(f"WARNING: this torch load is slow. CLANG to permute {intermediate_shape} with {permute_indexes}")
|
||||
assert storage[1] != dtypes.bfloat16, "can't CLANG permute BF16"
|
||||
if DEBUG >= 3: print(f"WARNING: this torch load is slow. to permute {intermediate_shape} with {permute_indexes}")
|
||||
assert storage[1] != dtypes.bfloat16, "can't permute BF16"
|
||||
# TODO: find a nice way to support all shapetracker on disktensors
|
||||
ret = ret.to(None).reshape(intermediate_shape).permute(permute_indexes)
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ class GroupOp:
|
||||
All = set(Ops)
|
||||
|
||||
# some BUFFER ops can be processed with only a view
|
||||
view_supported_devices = {"LLVM", "CLANG", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
||||
view_supported_devices = {"LLVM", "CPU", "CUDA", "NV", "AMD", "METAL", "QCOM", "DSP", "DISK"}
|
||||
|
||||
# https://en.wikipedia.org/wiki/Identity_element
|
||||
def identity_element(op:Ops, dt:DType) -> ConstType: return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt)
|
||||
|
||||
@@ -18,7 +18,7 @@ base_rewrite = PatternMatcher([
|
||||
lambda ctx,x: f"for ({ctx.render_dtype(x.dtype)} {ctx[x]} = {ctx[x.src[0]]}; {ctx[x]} < {ctx[x.src[1]]}; {ctx[x]}++) {{"),
|
||||
(UPat(Ops.VECTORIZE, name="x"),
|
||||
lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \
|
||||
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CLANG', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
|
||||
(f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device in {'CPU', 'DSP'} else f"({','.join([ctx[y] for y in x.src])})")),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x:
|
||||
f"__builtin_convertvector({ctx[x.src[0]]}, {ctx.render_dtype(x.dtype)})" if x.dtype.count > 1 and not isinstance(x.dtype, PtrDType) else None),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"),
|
||||
@@ -52,7 +52,7 @@ base_rewrite = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op](
|
||||
*([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)),
|
||||
(UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CLANG', 'DSP'} else \
|
||||
(f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device in {'CPU', 'DSP'} else \
|
||||
f".{'xyzwabcd'[x.arg[0]]}")),
|
||||
# custom passes through with format
|
||||
(UPat(Ops.CUSTOM, name="x"), lambda ctx,x: x.arg.format(*[ctx[y] for y in x.src])),
|
||||
@@ -175,7 +175,7 @@ class CStyleLanguage(Renderer):
|
||||
return self.render_kernel(name, kernel, list(bufs.values()), uops)
|
||||
|
||||
class ClangRenderer(CStyleLanguage):
|
||||
device = "CLANG"
|
||||
device = "CPU"
|
||||
float4 = "(float4)"
|
||||
has_local = False
|
||||
global_max = None
|
||||
|
||||
@@ -20,3 +20,5 @@ class ClangJITCompiler(Compiler):
|
||||
|
||||
class ClangDevice(Compiled):
|
||||
def __init__(self, device:str): super().__init__(device, MallocAllocator, ClangRenderer(), ClangJITCompiler(), CPUProgram)
|
||||
|
||||
CPUDevice = ClangDevice
|
||||
@@ -17,7 +17,7 @@ class LLVMCompiler(Compiler):
|
||||
|
||||
triple = {'AArch64': b'aarch64', 'X86': b'x86_64'}[host_arch] + b'-none-unknown-elf'
|
||||
target = expect(llvm.LLVMGetTargetFromTriple(triple, ctypes.pointer(tgt:=llvm.LLVMTargetRef()), err:=cerr()), err, tgt)
|
||||
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_clang.py, see comments there for why it's needed on arm osx
|
||||
# +reserve-x18 here does the same thing as -ffixed-x18 in ops_cpu.py, see comments there for why it's needed on arm osx
|
||||
cpu, feats = ctypes.string_at(llvm.LLVMGetHostCPUName()), (b'+reserve-x18,' if OSX else b'') + ctypes.string_at(llvm.LLVMGetHostCPUFeatures())
|
||||
if DEBUG >= 2: print(f"LLVM init for {cpu!r} with {feats!r}")
|
||||
self.target_machine = llvm.LLVMCreateTargetMachine(target, triple, cpu, feats,
|
||||
|
||||
@@ -173,7 +173,7 @@ class PythonProgram:
|
||||
# C, D (8 elements on 8 threads)
|
||||
def c_map(lane, elem): return (lane, elem)
|
||||
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
|
||||
elif arg[4] == "CLANG":
|
||||
elif arg[4] == "CPU":
|
||||
def elem(x, col, row, _): return x[col+row][0] # k is always 0
|
||||
def c_map(_, elem): return (elem%16, elem//16)
|
||||
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
|
||||
@@ -194,7 +194,7 @@ class PythonRenderer(Renderer):
|
||||
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
|
||||
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
|
||||
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
|
||||
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores
|
||||
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CPU", ClangRenderer.tensor_cores
|
||||
|
||||
def render(self, uops:list[UOp]) -> str:
|
||||
lops = [(u.op, u.dtype, [uops.index(v) for v in u.src], u.arg) for u in uops]
|
||||
|
||||
@@ -274,7 +274,7 @@ class Tensor(SimpleMathTrait):
|
||||
def assign(self, x) -> Tensor:
|
||||
# TODO: this is a hack for writing to DISK. remove with working assign
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"):
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device="CLANG", dtype=self.dtype)
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
|
||||
self.contiguous().realize().lazydata.base.realized.ensure_allocated().copyin(x._data())
|
||||
return self
|
||||
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
|
||||
@@ -297,11 +297,11 @@ class Tensor(SimpleMathTrait):
|
||||
def _data(self) -> memoryview:
|
||||
if 0 in self.shape: return memoryview(bytearray(0))
|
||||
# NOTE: this realizes on the object from as_buffer being a Python object
|
||||
cpu = self.cast(self.dtype.base).contiguous().to("CLANG").realize()
|
||||
cpu = self.cast(self.dtype.base).contiguous().to("CPU").realize()
|
||||
buf = cast(UOp, cpu.lazydata).base.realized
|
||||
assert buf is not None, f"{cast(UOp, cpu.lazydata).base} was not realized"
|
||||
if self.device != "CLANG": buf.options = BufferSpec(nolru=True)
|
||||
return buf.as_buffer(allow_zero_copy=True if self.device != "CLANG" else False)
|
||||
if self.device != "CPU": buf.options = BufferSpec(nolru=True)
|
||||
return buf.as_buffer(allow_zero_copy=True if self.device != "CPU" else False)
|
||||
|
||||
def data(self) -> memoryview:
|
||||
"""
|
||||
@@ -520,8 +520,8 @@ class Tensor(SimpleMathTrait):
|
||||
if (numel := prod(shape)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
|
||||
num = ceildiv(numel * dtype.itemsize, 4)
|
||||
|
||||
# when using MOCKGPU and NV generate rand on CLANG
|
||||
if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
|
||||
# when using MOCKGPU and NV generate rand on CPU
|
||||
if getenv("MOCKGPU") and device.startswith("NV"): device = "CPU"
|
||||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
|
||||
Reference in New Issue
Block a user