mirror of https://github.com/tinygrad/tinygrad.git
ops_gpu -> ops_cl (#12103)
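This commit renames tinygrad's OpenCL backend: the runtime module tinygrad/runtime/ops_gpu.py becomes tinygrad/runtime/ops_cl.py, the device name "GPU" becomes "CL", and the `GPU=1` environment variable becomes `CL=1`. A minimal before/after sketch of the user-facing change (assuming a checkout that includes this commit and a working OpenCL driver):

```python
# after this commit, the OpenCL backend is selected as "CL" (it was "GPU")
from tinygrad import Tensor, Device
from tinygrad.runtime.ops_cl import CLDevice  # was: tinygrad.runtime.ops_gpu

t = Tensor([1.0, 2.0, 3.0], device="CL")      # was: device="GPU"
print(t.device)      # -> CL
print(Device["CL"])  # opens the renamed backend directly
```

Scripts use the new env var accordingly, e.g. `CL=1 python3 test/test_tiny.py` where `GPU=1` was used before.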
.github/workflows/test.yml (54 lines changed)

@@ -330,8 +330,8 @@ jobs:
     - name: Fuzz Test shape ops
       run: python test/external/fuzz_shape_ops.py
 
-  testgpuimage:
-    name: 'GPU IMAGE Tests'
+  testopenclimage:
+    name: 'CL IMAGE Tests'
     runs-on: ubuntu-22.04
     timeout-minutes: 10
     env:
@@ -345,15 +345,15 @@ jobs:
         key: gpu-image
         deps: testing_minimal
         opencl: 'true'
-    - name: Test GPU IMAGE=2 ops + training
+    - name: Test CL IMAGE=2 ops + training
       run: |
-        GPU=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py --durations=20
-        GPU=1 IMAGE=2 python test/models/test_end2end.py TestEnd2End.test_linear_mnist
+        CL=1 IMAGE=2 python -m pytest -n=auto test/test_ops.py --durations=20
+        CL=1 IMAGE=2 python test/models/test_end2end.py TestEnd2End.test_linear_mnist
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
 
   testgpumisc:
-    name: 'GPU Misc tests'
+    name: 'CL Misc tests'
     runs-on: ubuntu-22.04
     timeout-minutes: 10
     env:
@@ -368,11 +368,11 @@ jobs:
         deps: testing_minimal
         opencl: 'true'
     - name: Generate Dataset
-      run: GPU=1 extra/optimization/generate_dataset.sh
+      run: CL=1 extra/optimization/generate_dataset.sh
     - name: Run Kernel Count Test
-      run: GPU=1 python -m pytest -n=auto test/external/external_test_opt.py
+      run: CL=1 python -m pytest -n=auto test/external/external_test_opt.py
     - name: Run fused optimizer tests
-      run: GPU=1 FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py
+      run: CL=1 FUSE_OPTIM=1 python -m pytest -n=auto test/models/test_mnist.py
     - name: Upload artifact
       uses: actions/upload-artifact@v4
       with:
@@ -397,17 +397,17 @@ jobs:
         llvm: 'true'
     - name: Test openpilot model kernel count and gate usage
       run: |
-        ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2175 ALLOWED_GATED_READ_IMAGE=16 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
+        ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2175 ALLOWED_GATED_READ_IMAGE=16 FLOAT16=0 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
     - name: Test openpilot alt model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
+      run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
     - name: Test openpilot fastvits model correctness (float32)
-      run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
+      run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     # - name: Test openpilot simple_plan vision model correctness (float32)
-    # run: FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
+    # run: FLOAT16=0 DEBUGCL=1 CL=1 IMAGE=2 python examples/openpilot/compile3.py https://gitlab.com/commaai/openpilot-lfs.git/gitlab-lfs/objects/35ff4f4577002f2685e50c8346addae33fe8da27a41dd4d6a0f14d1f4b1af81b
     - name: Test openpilot LLVM compile
       run: CPU=1 CPU_LLVM=1 LLVMOPT=1 JIT=2 BEAM=0 IMAGE=0 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/9118973ed03c1ae1d40cf69a29507ec2cc78efd7/selfdrive/modeld/models/supercombo.onnx
     - name: Test openpilot compile4
-      run: NOLOCALS=1 GPU=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
+      run: NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 DEBUG=2 python3 examples/openpilot/compile4.py
     - name: Run process replay tests
       uses: ./.github/actions/process-replay
 
@@ -459,16 +459,16 @@ jobs:
         pydeps: "tensorflow==2.15.1 tensorflow_addons"
         python-version: '3.11'
         opencl: 'true'
-    - name: Test ONNX (GPU)
-      run: GPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+    - name: Test ONNX (CL)
+      run: CL=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     #- name: Test Optimization Helpers
     # run: DEBUG=1 python3 extra/optimization/test_helpers.py
     #- name: Test Action Space
-    # run: DEBUG=1 GPU=1 python3 extra/optimization/get_action_space.py
+    # run: DEBUG=1 CL=1 python3 extra/optimization/get_action_space.py
     - name: Test Beam Search
-      run: GPU=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
+      run: CL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
     - name: Test MLPerf stuff
-      run: GPU=1 python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
+      run: CL=1 python -m pytest -n=auto test/external/external_test_optim.py test/external/external_test_losses.py test/external/external_test_metrics.py test/external/external_test_datasets.py --durations=20
     - name: Test llama 3 training
       run: MAX_BUFFER_SIZE=0 DEV=NULL SAMPLES=300 BS=8 SEQLEN=512 GRADIENT_ACC_STEPS=8 FAKEDATA=1 DEFAULT_FLOAT=bfloat16 OPTIM_DTYPE=bfloat16 LLAMA3_SIZE=1B MODEL=llama3 python3 examples/mlperf/model_train.py
     - name: Run process replay tests
@@ -506,8 +506,8 @@ jobs:
         llvm: 'true'
     - name: Test models (llvm)
       run: CPU=1 CPU_LLVM=1 python -m pytest -n=auto test/models --durations=20
-    - name: Test models (gpu)
-      run: GPU=1 python -m pytest -n=auto test/models --durations=20
+    - name: Test models (opencl)
+      run: CL=1 python -m pytest -n=auto test/models --durations=20
     - name: Test models (cpu)
       run: CPU=1 CPU_LLVM=0 python -m pytest -n=auto test/models --durations=20
     - name: Run process replay tests
@@ -709,7 +709,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        backend: [llvm, cpu, gpu]
+        backend: [llvm, cpu, opencl]
 
     name: Linux (${{ matrix.backend }})
     runs-on: ubuntu-22.04
@@ -725,13 +725,13 @@ jobs:
       with:
         key: ${{ matrix.backend }}-minimal
         deps: testing_minimal
-        opencl: ${{ matrix.backend == 'gpu' && 'true' }}
+        opencl: ${{ matrix.backend == 'opencl' && 'true' }}
         llvm: ${{ matrix.backend == 'llvm' && 'true' }}
     - name: Set env
-      run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'gpu' && 'GPU=1' }}" >> $GITHUB_ENV
+      run: printf "${{ matrix.backend == 'llvm' && 'CPU=1\nCPU_LLVM=1' || matrix.backend == 'cpu' && 'CPU=1\nCPU_LLVM=0\nCPU_COUNT=2' || matrix.backend == 'opencl' && 'CL=1' }}" >> $GITHUB_ENV
     - name: Check Device.DEFAULT and print some source
       run: |
-        python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','GPU'], Device.DEFAULT"
+        python3 -c "from tinygrad import Device; assert Device.DEFAULT in ['CPU','CL'], Device.DEFAULT"
         DEBUG=5 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
     - name: Run pytest (${{ matrix.backend }})
       run: python -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit --durations=20
@@ -772,7 +772,7 @@ jobs:
 
           start_server "remote-server-amd-1" "AMD" 6667
           start_server "remote-server-amd-2" "AMD" 6668
-          start_server "remote-server-gpu" "GPU" 7667
+          start_server "remote-server-gpu" "CL" 7667
           start_server "remote-server-cpu" "CPU" 8667
     - name: Check Device.DEFAULT and print some source
       env:
@@ -786,7 +786,7 @@ jobs:
         HOST: 127.0.0.1:6667*6,127.0.0.1:6668*6
       run: |
         python3 -m pytest test/test_tiny.py test/test_jit.py test/test_subbuffer.py test/test_graph.py test/test_multitensor.py test/test_remote.py test/test_tensor_variable.py --durations 20
-    - name: Run REMOTE=1 Test (GPU)
+    - name: Run REMOTE=1 Test (CL)
       env:
         HOST: 127.0.0.1:7667*6
       run: |

@@ -79,7 +79,7 @@ See [examples/beautiful_mnist.py](examples/beautiful_mnist.py) for the full vers
 
 tinygrad already supports numerous accelerators, including:
 
-- [x] [GPU (OpenCL)](tinygrad/runtime/ops_gpu.py)
+- [x] [OpenCL](tinygrad/runtime/ops_cl.py)
 - [x] [CPU](tinygrad/runtime/ops_cpu.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)

@@ -3,7 +3,7 @@
 This is a list of environment variable that control the runtime behavior of tinygrad and its examples.
 Most of these are self-explanatory, and are usually used to set an option at runtime.
 
-Example: `GPU=1 DEBUG=4 python3 -m pytest`
+Example: `CL=1 DEBUG=4 python3 -m pytest`
 
 However you can also decorate a function to set a value only inside that function.
 
@@ -31,7 +31,7 @@ These control the behavior of core tinygrad even when used as a library.
 Variable | Possible Value(s) | Description
 ---|---|---
 DEBUG | [1-7] | enable debugging output (operations, timings, speed, generated code and more)
-GPU | [1] | enable the GPU (OpenCL) backend
+CL | [1] | enable OpenCL backend
 CUDA | [1] | enable CUDA backend
 AMD | [1] | enable AMD backend
 NV | [1] | enable NV backend

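The doc's note about decorating a function refers to tinygrad's Context helper, which also works as a context manager (this diff itself uses `with Context(IMAGE=2):` in the tests). A small sketch with illustrative values:

```python
from tinygrad import Tensor
from tinygrad.helpers import Context

@Context(DEBUG=2)          # DEBUG=2 applies only while fn runs
def fn(): return (Tensor([1.0]) + 1).realize()

with Context(DEBUG=4):     # or scope a value to a block
  fn()
```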
@@ -9,7 +9,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
 | [QCOM](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_qcom.py) | Provides acceleration for QCOM GPUs | 6xx series GPUs |
 | [METAL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_metal.py) | Utilizes Metal for acceleration on Apple devices | M1+ Macs; Metal 3.0+ for `bfloat` support |
 | [CUDA](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cuda.py) | Utilizes CUDA for acceleration on NVIDIA GPUs | NVIDIA GPU with CUDA support |
-| [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device |
+| [OpenCL](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cl.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device |
 | [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
 | [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable |
 | [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0). |

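Only the OpenCL row's name changes in that table; runtime selection itself is untouched. A quick way to confirm which runtime tinygrad picked (a sketch; output depends on the machine):

```python
from tinygrad import Device, Tensor

print(Device.DEFAULT)            # e.g. "CL" where OpenCL is the default pick
t = Tensor.ones(4, device="CL")  # or place a tensor on the renamed backend explicitly
print(t.device)
```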
@@ -6,7 +6,7 @@ from tinygrad.schedule.kernelize import get_kernelize_map
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import run_schedule
 
-# NOLOCALS=1 GPU=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py
+# NOLOCALS=1 CL=1 IMAGE=2 FLOAT16=1 VIZ=1 DEBUG=2 python3 examples/openpilot/compile4.py
 
 OPENPILOT_MODEL = sys.argv[1] if len(sys.argv) > 1 else "https://github.com/commaai/openpilot/raw/v0.9.7/selfdrive/modeld/models/supercombo.onnx"
 OUTPUT = sys.argv[2] if len(sys.argv) > 2 else "/tmp/openpilot.pkl"

@@ -1,7 +1,7 @@
 # copying the kernels from https://github.com/microsoft/ArchProbe into Python
 import numpy as np
 import pickle
-from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer
+from tinygrad.runtime.ops_cl import CLProgram, CLBuffer
 from tinygrad import dtypes
 from tqdm import trange, tqdm
 from matplotlib import pyplot as plt

@@ -4,7 +4,7 @@ from tinygrad import dtypes
 from tinygrad.codegen.assembly import AssemblyCodegen, Register
 from tinygrad.codegen.opt.kernel import Ops
 from tinygrad.uop.ops import BinaryOps, UnaryOps, TernaryOps
-from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
+from tinygrad.runtime.ops_cl import ROCM_LLVM_PATH
 
 # ugh, is this really needed?
 from extra.helpers import enable_early_exec

@@ -5,7 +5,7 @@ from tinygrad.helpers import colored
 from extra.helpers import enable_early_exec
 early_exec = enable_early_exec()
 
-from tinygrad.runtime.ops_gpu import CLProgram, CLBuffer, ROCM_LLVM_PATH
+from tinygrad.runtime.ops_cl import CLProgram, CLBuffer, ROCM_LLVM_PATH
 
 ENABLE_NON_ASM = False
 

@@ -10,7 +10,7 @@ from tinygrad.uop.ops import Ops
 import json
 from collections import OrderedDict
 
-EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "GPU"]
+EXPORT_SUPPORTED_DEVICE = ["WEBGPU", "CPU", "CUDA", "CL"]
 
 def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
   functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import numpy as np
-from tinygrad.runtime.ops_gpu import CLProgram, CLCompiler
+from tinygrad.runtime.ops_cl import CLProgram, CLCompiler
 from tinygrad import Device, dtypes
 from tinygrad.device import Buffer
 from hexdump import hexdump
@@ -11,7 +11,7 @@ from hexdump import hexdump
 # https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_split_matrix_multiply_accumulate.html
 # https://hc34.hotchips.org/assets/program/conference/day1/GPU%20HPC/Intel_s%20Ponte%20Vecchio%20GPU%20-%20Architecture%20Systems%20and%20Software%20FINAL.pdf
 
-device = Device["GPU"]
+device = Device["CL"]
 
 # NOTE: only the subgroup type 8 ones work
 prog = CLProgram(device, "test", CLCompiler(device, "test").compile(f"""
@@ -26,9 +26,9 @@ __kernel void test(__global float* data0, const __global int* data1, const __glo
 """))
 #with open("/tmp/test.elf", "wb") as f: f.write(prog.lib)
 
-a = Buffer("GPU", 8, dtypes.float32).allocate()
-b = Buffer("GPU", 0x10, dtypes.float16).allocate()
-c = Buffer("GPU", 8*0x10, dtypes.float16).allocate()
+a = Buffer("CL", 8, dtypes.float32).allocate()
+b = Buffer("CL", 0x10, dtypes.float16).allocate()
+c = Buffer("CL", 8*0x10, dtypes.float16).allocate()
 
 row = np.array([1,2,3,4,5,6,7,8,1,2,3,4,5,6,7,8], np.float16)
 mat = np.random.random((8, 0x10)).astype(np.float16)

@@ -7,7 +7,7 @@ rm $LOGOPS
 test/external/process_replay/reset.py
 
 CI=1 python3 -m pytest -n=auto test/test_ops.py test/test_nn.py test/test_winograd.py test/models/test_real_world.py --durations=20
-GPU=1 python3 -m pytest test/test_tiny.py
+CL=1 python3 -m pytest test/test_tiny.py
 
 # extract, sort and uniq
 extra/optimization/extract_dataset.py

@@ -1,6 +1,6 @@
 import ctypes, array
 from hexdump import hexdump
-from tinygrad.runtime.ops_gpu import GPUDevice
+from tinygrad.runtime.ops_cl import CLDevice
 from tinygrad.helpers import getenv, to_mv, mv_address
 from tinygrad.dtype import dtypes
 from tinygrad import Tensor, TinyJit
@@ -8,7 +8,7 @@ from tinygrad.runtime.autogen import opencl as cl
 if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import
 
 # create raw opencl buffer.
-gdev = GPUDevice()
+gdev = CLDevice()
 cl_buf = cl.clCreateBuffer(gdev.context, cl.CL_MEM_READ_WRITE, 0x100, None, status := ctypes.c_int32())
 assert status.value == 0
 

@@ -4,13 +4,13 @@ import struct
 import json
 import traceback
 import numpy as np
-from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
+from tinygrad.runtime.ops_cl import CLProgram, compile_gpu
 from tinygrad.device import Device
 from tinygrad.helpers import DEBUG, getenv
 from collections import defaultdict
 import pyopencl as cl
-from tinygrad.runtime.ops_gpu import OSX_TIMING_RATIO
-CL = Device["GPU"]
+from tinygrad.runtime.ops_cl import OSX_TIMING_RATIO
+CL = Device["CL"]
 
 DEBUGCL = getenv("DEBUGCL", 0)
 FLOAT16 = getenv("FLOAT16", 0)
@@ -110,7 +110,7 @@ class Thneed:
     prgs = {}
     for o in jdat['binaries']:
       nptr = ptr + o['length']
-      prgs[o['name']] = CLProgram(Device["GPU"], o['name'], weights[ptr:nptr])
+      prgs[o['name']] = CLProgram(Device["CL"], o['name'], weights[ptr:nptr])
       ptr = nptr
 
     # populate the cl_cache
@@ -267,7 +267,7 @@ class Thneed:
     for prg, args in self.cl_cache:
      events.append(prg.clprg(CL.queue, *args))
     mt = time.monotonic()
-    Device["GPU"].synchronize()
+    Device["CL"].synchronize()
     et = time.monotonic() - st
     print(f"submit in {(mt-st)*1000.0:.2f} ms, total runtime is {et*1000.0:.2f} ms")

@@ -3,9 +3,9 @@ from tinygrad import Device
 from tinygrad.device import Buffer
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import CI
-from tinygrad.runtime.ops_gpu import CLDevice, CLAllocator, CLCompiler, CLProgram
+from tinygrad.runtime.ops_cl import CLDevice, CLAllocator, CLCompiler, CLProgram
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Runs only on OpenCL (GPU)")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Runs only on OpenCL")
 class TestCLError(unittest.TestCase):
   @unittest.skipIf(CI, "dangerous for CI, it allocates tons of memory")
   def test_oom(self):
@@ -24,7 +24,7 @@ class TestCLError(unittest.TestCase):
   def test_unaligned_copy(self):
     data = list(range(65))
     unaligned = memoryview(bytearray(data))[1:]
-    buffer = Buffer("GPU", 64, dtypes.uint8).allocate()
+    buffer = Buffer("CL", 64, dtypes.uint8).allocate()
     buffer.copyin(unaligned)
     result = memoryview(bytearray(len(data) - 1))
     buffer.copyout(result)

@@ -1,7 +1,7 @@
 import random, os
 from tinygrad.helpers import Timing
 from tinygrad.runtime.ops_hip import compile_hip, HIPDevice
-from tinygrad.runtime.ops_gpu import compile_cl, CLDevice
+from tinygrad.runtime.ops_cl import compile_cl, CLDevice
 
 # OMP_NUM_THREADS=1 strace -tt -f -e trace=file python3 test/external/external_benchmark_hip_compile.py
 # AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1 python3 test/external/external_benchmark_hip_compile.py

test/external/external_cl_half_max.py (2 lines changed)

@@ -1,4 +1,4 @@
-from tinygrad.runtime.ops_gpu import CLDevice, CLProgram, compile_cl
+from tinygrad.runtime.ops_cl import CLDevice, CLProgram, compile_cl
 
 if __name__ == "__main__":
   dev = CLDevice()

test/external/external_gpu_fail_osx.py (2 lines changed)

@@ -1,5 +1,5 @@
 # ugh, OS X OpenCL doesn't support half
-from tinygrad.runtime.ops_gpu import CLDevice, CLProgram, CLCompiler
+from tinygrad.runtime.ops_cl import CLDevice, CLProgram, CLCompiler
 
 src = """#pragma OPENCL EXTENSION cl_khr_fp16 : enable
 __kernel void max_half(__global half* data0, const __global half* data1) {

test/external/external_multi_gpu.py (2 lines changed)

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # cd extra/disassemblers/ && git clone --recursive github.com:geohot/cuda_ioctl_sniffer.git
-# LD_PRELOAD=$PWD/extra/disassemblers/cuda_ioctl_sniffer/out/sniff.so GPU=1 python3 test/external/external_multi_gpu.py
+# LD_PRELOAD=$PWD/extra/disassemblers/cuda_ioctl_sniffer/out/sniff.so CL=1 python3 test/external/external_multi_gpu.py
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import colored, Timing, getenv

test/external/external_osx_profiling.py (2 lines changed)

@@ -1,4 +1,4 @@
-from tinygrad.runtime.ops_gpu import CLProgram, CL, CLBuffer
+from tinygrad.runtime.ops_cl import CLProgram, CL, CLBuffer
 from tinygrad import dtypes
 import time
 

File diff suppressed because one or more lines are too long
test/external/external_test_image.py (2 lines changed)

@@ -4,7 +4,7 @@ import unittest
 import numpy as np
 if 'IMAGE' not in os.environ:
   os.environ['IMAGE'] = '2'
-  os.environ['GPU'] = '1'
+  os.environ['CL'] = '1'
   os.environ['OPT'] = '2'
 from tinygrad.tensor import Tensor
 from tinygrad.nn import Conv2d

test/external/external_test_onnx_backend.py (4 lines changed)

@@ -193,12 +193,12 @@ backend_test.exclude('test_adam_cpu')
 backend_test.exclude('test_gradient_of_add_and_mul_cpu')
 backend_test.exclude('test_gradient_of_add_cpu')
 
-if Device.DEFAULT in ['GPU', 'METAL']:
+if Device.DEFAULT in ['CL', 'METAL']:
   backend_test.exclude('test_resize_upsample_sizes_nearest_axes_2_3_cpu')
   backend_test.exclude('test_resize_upsample_sizes_nearest_axes_3_2_cpu')
   backend_test.exclude('test_resize_upsample_sizes_nearest_cpu')
 
-if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "GPU"):
+if Device.DEFAULT == "METAL" or (OSX and Device.DEFAULT == "CL"):
   # numerical inaccuracy
   backend_test.exclude('test_mish_cpu')
   backend_test.exclude('test_mish_expanded_cpu')

test/external/external_test_opt.py (10 lines changed)

@@ -34,7 +34,7 @@ from extra.models.efficientnet import EfficientNet
 from extra.models.resnet import ResNet18
 from extra.models.vit import ViT
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Not Implemented")
 class TestInferenceMinKernels(unittest.TestCase):
   def setUp(self):
     self.training_old = Tensor.training
@@ -90,7 +90,7 @@ class TestInferenceMinKernels(unittest.TestCase):
     with CLCache(100):
       model(inp, 0).realize()
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Not Implemented")
 class TestOptBinOp(unittest.TestCase):
   def _test_no_binop_rerun(self, f1, f2=None, allowed=1):
     a = Tensor.randn(16, 16)
@@ -117,7 +117,7 @@ class TestOptBinOp(unittest.TestCase):
   #def test_no_binop_rerun_reduce(self): return self._test_no_binop_rerun(lambda a,b: (a*b).sum(), lambda a,b: (a*b).reshape(16, 16, 1).sum())
   #def test_no_binop_rerun_reduce_alt(self): return self._test_no_binop_rerun(lambda a,b: a.sum(1)+b[0], lambda a,b: a.sum(1).reshape(1,16)+b[0])
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Not Implemented")
 class TestOptReduceLoop(unittest.TestCase):
   def test_loop_left(self):
     a = Tensor.randn(16, 16)
@@ -139,7 +139,7 @@ class TestOptReduceLoop(unittest.TestCase):
       c.realize()
     assert cache.count == 2, "loop right fusion broken"
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
   @unittest.skip("this no longer happens, use realize")
   def test_unrealized_child(self):
@@ -152,7 +152,7 @@ class TestOptWChild(unittest.TestCase):
       d.realize()
     assert cache.count == 2, "don't fuse if you have children"
 
-@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
+@unittest.skipUnless(Device.DEFAULT == "CL", "Not Implemented")
 class TestOpt(unittest.TestCase):
   def test_muladd(self):
     a,b,c = [Tensor.randn(2,2).realize() for _ in range(3)]

test/external/fuzz_linearizer.py (8 lines changed)

@@ -16,7 +16,7 @@ if os.getenv("VALIDATE_HCQ", 0) != 0:
   try:
     import extra.qcom_gpu_driver.opencl_ioctl
     from tinygrad import Device
-    _, _ = Device["QCOM"], Device["GPU"]
+    _, _ = Device["QCOM"], Device["CL"]
   except Exception: pass
 
 from tinygrad import Tensor, Device, dtypes
@@ -42,9 +42,9 @@ if getenv("VALIDATE_HCQ"):
     on_linearizer_did_run = extra.nv_gpu_driver.nv_ioctl.collect_last_launch_state
     compare_states = extra.nv_gpu_driver.nv_ioctl.compare_launch_state
   elif Device.DEFAULT == "QCOM":
-    print("VALIDATE_HCQ: Comparing QCOM to GPU")
+    print("VALIDATE_HCQ: Comparing QCOM to CL")
     import extra.qcom_gpu_driver.opencl_ioctl
-    validate_device = Device["GPU"]
+    validate_device = Device["CL"]
     on_linearizer_will_run = extra.qcom_gpu_driver.opencl_ioctl.before_launch
     on_linearizer_did_run = extra.qcom_gpu_driver.opencl_ioctl.collect_last_launch_state
     compare_states = extra.qcom_gpu_driver.opencl_ioctl.compare_launch_state
@@ -302,7 +302,7 @@ if __name__ == "__main__":
   for i, ast in enumerate(ast_strs[:getenv("FUZZ_N", len(ast_strs))]):
     if (nth := getenv("FUZZ_NTH", -1)) != -1 and i != nth: continue
     if getenv("FUZZ_IMAGEONLY") and "dtypes.image" not in ast: continue
-    if "dtypes.image" in ast and Device.DEFAULT not in {"GPU", "QCOM"}: continue # IMAGE is only for GPU
+    if "dtypes.image" in ast and Device.DEFAULT not in {"CL", "QCOM"}: continue # IMAGE is only for CL
     if ast in seen_ast_strs: continue
     seen_ast_strs.add(ast)

@@ -57,8 +57,8 @@ def eval_uop(uop:UOp, inputs:list[tuple[DType, list[Any]]]|None=None):
   return out_buf.cast(uop.dtype.fmt).tolist()[0]
 
 def not_support_multi_device():
-  # GPU and CUDA don't support multi device if in CI
-  return CI and REAL_DEV in ("GPU", "CUDA")
+  # CL and CUDA don't support multi device if in CI
+  return CI and REAL_DEV in ("CL", "CUDA")
 
 # NOTE: This will open REMOTE if it's the default device
 REAL_DEV = (Device.DEFAULT if Device.DEFAULT != "REMOTE" else Device['REMOTE'].properties.real_device)

@@ -114,7 +114,7 @@ class TestRealWorld(unittest.TestCase):
 
     helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93)
 
-  @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU"}, "slow")
+  @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
   def test_train_cifar(self):
     with Tensor.train():
       model = SpeedyResNet(Tensor.ones((12,3,2,2)))

@@ -27,7 +27,7 @@ def train_one_step(model,X,Y):
   print("done in %.2f ms" % (et*1000.))
 
 def check_gc():
-  if Device.DEFAULT == "GPU":
+  if Device.DEFAULT == "CL":
     from extra.introspection import print_objects
     assert print_objects() == 0
 

@@ -93,7 +93,7 @@ class TestKernelOpts(unittest.TestCase):
     a = Tensor.rand(8, N, 8, N)
     r = a.sum(axis=(1,3))
     helper_linearizer_opt(r, [
-      # openCL / GPU=1 is 256 max threads
+      # openCL / CL=1 is 256 max threads
       [Opt(OptOps.GROUPTOP, 0, 2)], [Opt(OptOps.GROUPTOP, 0, 32)],
       [Opt(OptOps.GROUPTOP, 1, 2)], [Opt(OptOps.GROUPTOP, 1, 32)], # Checking how it works with 1 grouped_reduce.
       [Opt(OptOps.GROUPTOP, 0, 2), Opt(OptOps.GROUPTOP, 1, 2)],

@@ -77,9 +77,9 @@ class TestCopySpeed(unittest.TestCase):
     np.testing.assert_equal(t.numpy(), x.numpy())
 
   @unittest.skipIf(CI, "CI doesn't have 6 GPUs")
-  @unittest.skipIf(Device.DEFAULT != "GPU", "only test this on GPU")
+  @unittest.skipIf(Device.DEFAULT != "CL", "only test this on CL")
   def testCopyCPUto6GPUs(self):
-    from tinygrad.runtime.ops_gpu import CLDevice
+    from tinygrad.runtime.ops_cl import CLDevice
     if len(CLDevice.device_ids) != 6: raise unittest.SkipTest("computer doesn't have 6 GPUs")
     t = Tensor.ones(N, N, device="CPU").contiguous().realize()
     print(f"buffer: {t.nbytes()*1e-9:.2f} GB")
@@ -87,8 +87,8 @@ class TestCopySpeed(unittest.TestCase):
     with Timing("sync: ", on_exit=lambda ns: f" @ {t.nbytes()/ns:.2f} GB/s ({t.nbytes()*6/ns:.2f} GB/s total)"):
       with Timing("queue: "):
         for g in range(6):
-          t.to(f"gpu:{g}").realize()
-      Device["gpu"].synchronize()
+          t.to(f"CL:{g}").realize()
+      Device["CL"].synchronize()
 
 if __name__ == '__main__':
   unittest.main()

@@ -424,7 +424,7 @@ class TestDtypeUsage(unittest.TestCase):
 class TestOpsBFloat16(unittest.TestCase):
   def test_cast(self):
     # TODO: helper_test_op breaks in unrelated part
-    # TODO: wrong output with GPU=1 on mac
+    # TODO: wrong output with CL=1 on mac
     data = [60000.0, 70000.0, 80000.0]
     np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
 
@@ -7,7 +7,7 @@ from tinygrad.engine.realize import lower_schedule
 from tinygrad.helpers import prod, unwrap
 from test.helpers import REAL_DEV
 
-IMAGE_SUPPORTED_DEVICES = ("QCOM", "GPU")
+IMAGE_SUPPORTED_DEVICES = ("QCOM", "CL")
 
 @unittest.skipUnless(REAL_DEV in IMAGE_SUPPORTED_DEVICES, "Images not supported")
 class TestImageCopy(unittest.TestCase):

@@ -1304,7 +1304,7 @@ class TestOps(unittest.TestCase):
                                                                          np.arange(64,128,dtype=np.float32).reshape(8,8)])
   def test_small_gemm_eye(self):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
-  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "GPU", "CUDA"] or (Device.DEFAULT == "CPU" and CPU_LLVM) or IMAGE
+  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "CL", "CUDA"] or (Device.DEFAULT == "CPU" and CPU_LLVM) or IMAGE
                    or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
   def test_gemm_fp16(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)

@@ -13,7 +13,7 @@ class TestOpts(unittest.TestCase):
     out = (a+b).contiguous(arg=opts)
     s = out.schedule()
     self.assertEqual(s[-1].ast.arg.opts_to_apply, opts)
-    if Device.DEFAULT in {"CPU", "GPU", "METAL"} and not CPU_LLVM:
+    if Device.DEFAULT in {"CPU", "CL", "METAL"} and not CPU_LLVM:
       prg = get_program(s[-1].ast)
       self.assertIn('float4', prg.src)
 

@@ -1654,7 +1654,7 @@ class TestSchedule(unittest.TestCase):
     constv = Tensor.empty(2, 2).uop.const_like(10).contiguous()
     check_schedule(constv, 1)
 
-  @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
+  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
   def test_image_matmul(self):
     with Context(IMAGE=2):
       x = Tensor.randn((9, 9)).realize()

@@ -137,7 +137,7 @@ class TestTiny(unittest.TestCase):
 
   # *** image ***
 
-  @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
+  @unittest.skipIf(Device.DEFAULT != "CL", "image only supported on CL")
   def test_image(self):
     with Context(IMAGE=2): self.test_gemm(N=4, out_dtype=dtypes.imagef((4, 1, 4)))
 
@@ -513,7 +513,7 @@ class TestUOpStr(unittest.TestCase):
     assert str(eval(str(vec))) == str(vec)
 
   def test_device_arg(self):
-    device = UOp(Ops.DEVICE, arg="GPU")
+    device = UOp(Ops.DEVICE, arg="CL")
     assert str(eval(str(device))) == str(device)
 
   def test_reduceop_arg(self):

@@ -9,12 +9,12 @@ class TestDevice(unittest.TestCase):
     self.assertEqual(Device.canonicalize(None), Device.DEFAULT)
     self.assertEqual(Device.canonicalize("CPU"), "CPU")
     self.assertEqual(Device.canonicalize("cpu"), "CPU")
-    self.assertEqual(Device.canonicalize("GPU"), "GPU")
-    self.assertEqual(Device.canonicalize("GPU:0"), "GPU")
-    self.assertEqual(Device.canonicalize("gpu:0"), "GPU")
-    self.assertEqual(Device.canonicalize("GPU:1"), "GPU:1")
-    self.assertEqual(Device.canonicalize("gpu:1"), "GPU:1")
-    self.assertEqual(Device.canonicalize("GPU:2"), "GPU:2")
+    self.assertEqual(Device.canonicalize("CL"), "CL")
+    self.assertEqual(Device.canonicalize("CL:0"), "CL")
+    self.assertEqual(Device.canonicalize("cl:0"), "CL")
+    self.assertEqual(Device.canonicalize("CL:1"), "CL:1")
+    self.assertEqual(Device.canonicalize("cl:1"), "CL:1")
+    self.assertEqual(Device.canonicalize("CL:2"), "CL:2")
     self.assertEqual(Device.canonicalize("disk:/dev/shm/test"), "DISK:/dev/shm/test")
     self.assertEqual(Device.canonicalize("disk:000.txt"), "DISK:000.txt")
 
@@ -181,7 +181,7 @@ class TestIndexing(unittest.TestCase):
     # self.assertRaises(TypeError, delitem)
 
   # TODO: LLVM is quite fast, why are other compiled backends slow?
-  @unittest.skipIf(CI and Device.DEFAULT in ["CPU", "GPU", "METAL", "NV", "AMD"], "slow")
+  @unittest.skipIf(CI and Device.DEFAULT in ["CPU", "CL", "METAL", "NV", "AMD"], "slow")
   def test_advancedindex(self):
     # integer array indexing
 
@@ -359,7 +359,7 @@ class TestImageSimplification(unittest.TestCase):
     self.check(load, None, "((gidx*3)+-1438)", "0")
 
   def test_simplify2(self):
-    # from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
+    # from CL=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
     lidx = Special("lidx", 4)
     valid = (lidx<3) & (lidx<1).ne(True)
     idx = ((lidx+1)%2, (lidx+1)//2-1)

@@ -11,7 +11,7 @@ from tinygrad.renderer import Renderer
 
 # **************** Device ****************
 
-ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "GPU", "CPU", "DSP", "WEBGPU"]
+ALL_DEVICES = ["METAL", "AMD", "NV", "CUDA", "QCOM", "CL", "CPU", "DSP", "WEBGPU"]
 class _Device:
   def __init__(self) -> None:
     self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
@@ -336,11 +336,11 @@ def is_dtype_supported(dtype:DType, device:str|None=None) -> bool:
   # CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
   # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
   if dtype == dtypes.half:
-    if device == "GPU": return not CI and not OSX
+    if device == "CL": return not CI and not OSX
     if device in ["CUDA", "NV"]: return not CI
     if device == "CPU" and CPU_LLVM: return OSX
     if device == "PYTHON": return sys.version_info >= (3, 12)
-  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
+  if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "CL")
   return True
 
 if PROFILE:

@@ -242,7 +242,7 @@ class ClangRenderer(CStyleLanguage):
     return defines + "\n" + self._render_body(function_name, kernel, bufs, uops, prefix) + "\n" + self._render_entry(function_name, bufs)
 
 class OpenCLRenderer(CStyleLanguage):
-  device = "GPU"
+  device = "CL"
 
   # language options
   kernel_typedef = "__kernel void"
@@ -271,7 +271,7 @@ class OpenCLRenderer(CStyleLanguage):
     return super().render_kernel(function_name, kernel, bufs, uops, prefix)
 
 class IntelRenderer(OpenCLRenderer):
-  device, suffix, kernel_typedef = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
+  device, suffix, kernel_typedef = "CL", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel void"
   tensor_cores = tc.intel
 
   string_rewrite = PatternMatcher([

@@ -7,7 +7,7 @@ from tinygrad.device import BufferSpec
 from tinygrad.runtime.support.hcq import HCQBuffer, HWQueue, HCQProgram, HCQCompiled, HCQAllocatorBase, HCQSignal, HCQArgsState, BumpAllocator
 from tinygrad.runtime.support.hcq import FileIOInterface, MMIOInterface
 from tinygrad.runtime.autogen import kgsl, adreno
-from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
+from tinygrad.runtime.ops_cl import CLCompiler, CLDevice
 from tinygrad.renderer.cstyle import QCOMRenderer
 from tinygrad.helpers import getenv, mv_address, to_mv, round_up, data64_le, prod, fromimport
 if getenv("IOCTL"): import extra.qcom_gpu_driver.opencl_ioctl # noqa: F401 # pylint: disable=unused-import