mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Merge branch 'master' into retinanet_mlperf
This commit is contained in:
70
.github/workflows/benchmark.yml
vendored
70
.github/workflows/benchmark.yml
vendored
@@ -47,55 +47,55 @@ jobs:
|
||||
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
run: python3.11 test/external/process_replay/reset.py
|
||||
- name: Run Stable Diffusion
|
||||
run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
run: JIT=2 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run Stable Diffusion with fp16
|
||||
run: JIT=2 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
|
||||
run: JIT=2 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
|
||||
- name: Run SDXL
|
||||
run: JIT=2 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
|
||||
run: JIT=2 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
|
||||
- name: Run model inference benchmark
|
||||
run: METAL=1 python3 test/external/external_model_benchmark.py
|
||||
run: METAL=1 python3.11 test/external/external_model_benchmark.py
|
||||
- name: Test speed vs torch
|
||||
run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
|
||||
run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt
|
||||
- name: Test tensor cores
|
||||
run: METAL=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
- name: Run Tensor Core GEMM
|
||||
run: |
|
||||
DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||
DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt
|
||||
DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
|
||||
DEBUG=2 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
|
||||
- name: Fuzz Padded Tensor Core GEMM
|
||||
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
|
||||
- name: Run LLaMA
|
||||
run: |
|
||||
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
|
||||
JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
|
||||
JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
|
||||
- name: Run LLaMA with BEAM
|
||||
run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
|
||||
run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
|
||||
- name: Run quantized LLaMA
|
||||
run: |
|
||||
python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
|
||||
python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
|
||||
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
|
||||
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
|
||||
- name: Run LLaMA 7B on 4 (virtual) GPUs
|
||||
run: python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
|
||||
run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
|
||||
- name: Run GPT2
|
||||
run: |
|
||||
JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
|
||||
JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
|
||||
JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
|
||||
JIT=1 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
|
||||
- name: Run GPT2 w HALF
|
||||
run: HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
|
||||
run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
|
||||
- name: Run GPT2 w HALF/BEAM
|
||||
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
- name: Run 10 CIFAR training steps
|
||||
run: JIT=2 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
|
||||
run: JIT=2 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
|
||||
- name: Run 10 CIFAR training steps w HALF
|
||||
run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
|
||||
run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
|
||||
#- name: Run 10 CIFAR training steps w BF16
|
||||
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
|
||||
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
|
||||
- name: Run 10 CIFAR training steps w winograd
|
||||
run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
|
||||
run: JIT=2 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: Speed (Mac)
|
||||
@@ -123,7 +123,7 @@ jobs:
|
||||
train_cifar_bf16.txt
|
||||
train_cifar_wino.txt
|
||||
- name: Run process replay tests
|
||||
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
|
||||
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3.11 process_replay.py
|
||||
|
||||
testnvidiabenchmark:
|
||||
name: tinybox green Benchmark
|
||||
@@ -157,7 +157,7 @@ jobs:
|
||||
- name: Run model inference benchmark
|
||||
run: NV=1 RUN_PROCESS_REPLAY=0 NOCLANG=1 python3 test/external/external_model_benchmark.py
|
||||
- name: Test speed vs torch
|
||||
run: NV=1 RUN_PROCESS_REPLAY=0 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
|
||||
run: NV=1 RUN_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
|
||||
- name: Test tensor cores
|
||||
run: |
|
||||
NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
@@ -170,8 +170,10 @@ jobs:
|
||||
run: NV=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
|
||||
- name: Run Tensor Core GEMM (NV)
|
||||
run: NV=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
|
||||
- name: Run Tensor Core GEMM (NV) with BEAM
|
||||
run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
|
||||
# - name: Run Tensor Core GEMM (NV) with BEAM
|
||||
# run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
|
||||
- name: Test speed vs theoretical
|
||||
run: NV=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py
|
||||
- name: Run Stable Diffusion
|
||||
run: NV=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
|
||||
- name: Run SDXL
|
||||
@@ -328,6 +330,10 @@ jobs:
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: setup perflevel
|
||||
run: |
|
||||
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
|
||||
rocm-smi
|
||||
- name: Show off tinybox
|
||||
run: /opt/rocm/bin/rocm-bandwidth-test
|
||||
# TODO: unstable on AMD
|
||||
@@ -343,6 +349,8 @@ jobs:
|
||||
AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
|
||||
- name: Run Tensor Core GEMM (AMD)
|
||||
run: AMD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
|
||||
- name: Test speed vs theoretical
|
||||
run: AMD=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py
|
||||
# TODO: AMD compiler bug causes this to fail
|
||||
#- name: Fuzz Padded Tensor Core GEMM
|
||||
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
|
||||
@@ -431,6 +439,10 @@ jobs:
|
||||
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
|
||||
- name: reset process replay
|
||||
run: test/external/process_replay/reset.py
|
||||
- name: setup perflevel
|
||||
run: |
|
||||
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
|
||||
rocm-smi
|
||||
- name: Train MNIST
|
||||
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
|
||||
- name: Run 10 CIFAR training steps
|
||||
|
||||
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -209,15 +209,15 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Python 3.8
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
python-version: "3.10"
|
||||
- name: Cache python packages
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
|
||||
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8
|
||||
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.10/site-packages
|
||||
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.10
|
||||
- name: Install dependencies
|
||||
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
- name: Lint bad-indentation and trailing-whitespace with pylint
|
||||
|
||||
@@ -46,4 +46,5 @@ FLOAT16 | [1] | use float16 for images instead of float32
|
||||
PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend.
|
||||
PROFILE | [1] | enable output of [perfetto](https://ui.perfetto.dev/) compatible profile. This feature is supported in NV and AMD backends.
|
||||
VISIBLE_DEVICES | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0).
|
||||
JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
|
||||
JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
|
||||
VIZ | [1] | 0=disabled, 1=[viz enabled](../tinygrad/viz/README)
|
||||
@@ -5,6 +5,7 @@
|
||||
::: tinygrad.Tensor.ones
|
||||
::: tinygrad.Tensor.full
|
||||
::: tinygrad.Tensor.arange
|
||||
::: tinygrad.Tensor.linspace
|
||||
::: tinygrad.Tensor.eye
|
||||
::: tinygrad.Tensor.full_like
|
||||
::: tinygrad.Tensor.zeros_like
|
||||
|
||||
@@ -178,7 +178,7 @@ class TrackedMemoryView:
|
||||
def __len__(self): return len(self.mv)
|
||||
def __repr__(self): return repr(self.mv)
|
||||
|
||||
def _memoryview(mem):
|
||||
def _memoryview(cls, mem):
|
||||
if isinstance(mem, int) or isinstance(mem, ctypes.Array):
|
||||
addr = ctypes.addressof(mem) if isinstance(mem, ctypes.Array) else mem
|
||||
for d in drivers:
|
||||
@@ -196,7 +196,7 @@ install_hook(libc.lseek64, _lseek64)
|
||||
install_hook(libc.stat64, _stat64)
|
||||
install_hook(libc.fstat64, _fstat64)
|
||||
install_hook(libc.getdents64, _getdents64)
|
||||
builtins.memoryview = _memoryview # type: ignore
|
||||
builtins.memoryview = type("memoryview", (), {'__new__': _memoryview}) # type: ignore
|
||||
|
||||
# rewrite autogen's libc mmaps functions.
|
||||
import tinygrad.runtime.autogen.libc as autogen_libc
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
|
||||
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from typing import Optional, Union, List, Any, Tuple
|
||||
import math
|
||||
|
||||
@@ -9,7 +9,8 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
|
||||
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
|
||||
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
|
||||
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
|
||||
return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
|
||||
|
||||
class ResBlock:
|
||||
def __init__(self, channels:int, emb_channels:int, out_channels:int):
|
||||
@@ -222,16 +223,17 @@ class UNetModel:
|
||||
]
|
||||
|
||||
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
|
||||
t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)
|
||||
t_emb = timestep_embedding(tms, self.model_ch)
|
||||
emb = t_emb.sequential(self.time_embed)
|
||||
|
||||
if y is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + y.sequential(self.label_emb[0])
|
||||
|
||||
emb = emb.cast(dtypes.float16)
|
||||
ctx = ctx.cast(dtypes.float16)
|
||||
x = x .cast(dtypes.float16)
|
||||
if is_dtype_supported(dtypes.float16):
|
||||
emb = emb.cast(dtypes.float16)
|
||||
ctx = ctx.cast(dtypes.float16)
|
||||
x = x .cast(dtypes.float16)
|
||||
|
||||
def run(x:Tensor, bb) -> Tensor:
|
||||
if isinstance(bb, ResBlock): x = bb(x, emb)
|
||||
|
||||
@@ -7,6 +7,7 @@ from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from tinygrad.helpers import getenv, DEBUG, CI, OSX
|
||||
from tinygrad.dtype import ConstType, DType
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
|
||||
try:
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
@@ -29,14 +30,6 @@ def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Uni
|
||||
cache_misses = info.misses
|
||||
return ret
|
||||
|
||||
# copied from helpers.py
|
||||
def is_dtype_supported(dtype, device: str = Device.DEFAULT):
|
||||
if dtype == dtypes.bfloat16: return False
|
||||
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
|
||||
if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"})
|
||||
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
|
||||
return True
|
||||
|
||||
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
|
||||
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
|
||||
# TODO: use dtypes.float16 for FLOAT16
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# type: ignore
|
||||
import ctypes, ctypes.util, struct, fcntl, re
|
||||
from hexdump import hexdump
|
||||
from copy import deepcopy
|
||||
import pathlib, sys
|
||||
from tinygrad.helpers import to_mv, getenv
|
||||
from tinygrad.runtime.autogen import adreno
|
||||
@@ -91,27 +92,36 @@ def parse_cmd_buf(dat):
|
||||
num_unit = vals[0]>>22
|
||||
if IOCTL > 0: print(f"{num_unit=} {state_block=} {state_src=} {state_type=} {dst_off=}")
|
||||
|
||||
if state_block == SB6_CS_SHADER and IOCTL > 2:
|
||||
if "LOAD_FRAGS" not in CAPTURED_STATE: CAPTURED_STATE['LOAD_FRAGS'] = []
|
||||
CAPTURED_STATE['LOAD_FRAGS'].append((state_block, state_type, num_unit, dst_off))
|
||||
|
||||
if state_block == SB6_CS_SHADER:
|
||||
from extra.disassemblers.adreno import disasm_raw
|
||||
if state_type == ST6_SHADER: disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128))
|
||||
if state_type == ST6_CONSTANTS: hexdump(get_mem(((vals[2] << 32) | vals[1]), min(0x180, num_unit*4)))
|
||||
if state_type == ST6_SHADER and IOCTL > 2:
|
||||
disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128))
|
||||
if state_type == ST6_CONSTANTS:
|
||||
x = get_mem(((vals[2] << 32) | vals[1]), num_unit*4)
|
||||
CAPTURED_STATE['constants'] = x[:]
|
||||
if IOCTL > 2:
|
||||
print('constants')
|
||||
hexdump(x)
|
||||
if state_type == ST6_IBO:
|
||||
ibos_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4)
|
||||
CAPTURED_STATE['ibos'] = ibos_bytes[:]
|
||||
if IOCTL > 0:
|
||||
if IOCTL > 1:
|
||||
print('texture ibos')
|
||||
hexdump(ibos_bytes)
|
||||
elif state_block == SB6_CS_TEX and IOCTL > 2:
|
||||
elif state_block == SB6_CS_TEX:
|
||||
if state_type == ST6_SHADER:
|
||||
samplers_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 4 * 4)
|
||||
CAPTURED_STATE['samplers'] = ibos_bytes[:]
|
||||
if IOCTL > 0:
|
||||
CAPTURED_STATE['samplers'] = samplers_bytes[:]
|
||||
if IOCTL > 1:
|
||||
print('texture samplers')
|
||||
hexdump(samplers_bytes)
|
||||
if state_type == ST6_CONSTANTS:
|
||||
descriptors_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4)
|
||||
CAPTURED_STATE['descriptors'] = ibos_bytes[:]
|
||||
if IOCTL > 0:
|
||||
descriptors_bytes = get_mem((vals[2] << 32) | vals[1], 1600)
|
||||
CAPTURED_STATE['descriptors'] = descriptors_bytes[:]
|
||||
if IOCTL > 1:
|
||||
print('texture descriptors')
|
||||
hexdump(descriptors_bytes)
|
||||
|
||||
@@ -200,10 +210,19 @@ install_hook(libc.ioctl, ioctl)
|
||||
def before_launch():
|
||||
global CAPTURED_STATE
|
||||
CAPTURED_STATE.clear()
|
||||
def collect_last_launch_state(): return CAPTURED_STATE
|
||||
def collect_last_launch_state():
|
||||
global CAPTURED_STATE
|
||||
return deepcopy(CAPTURED_STATE)
|
||||
def compare_launch_state(state, good_state):
|
||||
cmp = [
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, 0xffffffff),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NTEX__MASK),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NSAMP__MASK),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NIBO__MASK),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_ENABLED),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_TEX),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_SAMP),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_IBO),
|
||||
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_UBO),
|
||||
|
||||
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_HALFREGFOOTPRINT__MASK),
|
||||
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_FULLREGFOOTPRINT__MASK),
|
||||
@@ -213,20 +232,53 @@ def compare_launch_state(state, good_state):
|
||||
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_EARLYPREAMBLE),
|
||||
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_MERGEDREGS),
|
||||
|
||||
(adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_MEMSIZEPERITEM__MASK),
|
||||
(adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_HWSTACKSIZEPERTHREAD__MASK),
|
||||
|
||||
(adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK5),
|
||||
(adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK6),
|
||||
|
||||
(adreno.REG_A6XX_SP_CS_BRANCH_COND, 0xffffffff),
|
||||
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_KERNELDIM__MASK),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEX__MASK),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEY__MASK),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEZ__MASK),
|
||||
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_1, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_2, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_3, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_4, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_5, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_6, 0xffffffff),
|
||||
|
||||
(adreno.REG_A6XX_HLSQ_CS_CNTL_0, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_CNTL_1, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_X, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Y, 0xffffffff),
|
||||
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Z, 0xffffffff),
|
||||
]
|
||||
|
||||
for x,m in cmp:
|
||||
print(f"Field {REGS[x]}, mask: 0x{m:X} cmp: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}")
|
||||
if state.get(x, 0) & m != good_state.get(x, 0) & m:
|
||||
return False, f"Field {REGS[x]}, mask: {x:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}"
|
||||
return False, f"Field {REGS[x]}, mask: 0x{m:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}"
|
||||
|
||||
for n in ['ibos', 'samplers', 'descriptors']:
|
||||
for n in ['descriptors', 'ibos']:
|
||||
if n not in good_state: continue
|
||||
mv1, mv2 = state.get(n), good_state.get(n)
|
||||
if len(mv1) != len(mv2): return False, f"{n}: len mismatch"
|
||||
|
||||
if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}"
|
||||
mv1 = memoryview(bytearray(mv1)).cast('I')
|
||||
mv2 = memoryview(bytearray(mv2)).cast('I')
|
||||
for i in range(len(mv2)):
|
||||
if i % 8 == 5 or i % 8 == 4: continue # addresses
|
||||
if mv1[i]!=mv2[i]: return False, f"{n}: content mismatch {i} {mv1[i]} {mv2[i]}"
|
||||
|
||||
for n in ['samplers']:
|
||||
if n not in good_state: continue
|
||||
mv1, mv2 = state.get(n), good_state.get(n)
|
||||
if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}"
|
||||
if any(mv1[i]!=mv2[i] for i in range(len(mv1))): return False, f"{n}: content mismatch"
|
||||
|
||||
return True, "PASS"
|
||||
|
||||
2
setup.py
2
setup.py
@@ -22,7 +22,7 @@ setup(name='tinygrad',
|
||||
"License :: OSI Approved :: MIT License"
|
||||
],
|
||||
install_requires=[],
|
||||
python_requires='>=3.8',
|
||||
python_requires='>=3.10',
|
||||
extras_require={
|
||||
'llvm': ["llvmlite"],
|
||||
'arm': ["unicorn"],
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
FROM ubuntu:20.04
|
||||
FROM ubuntu:22.04
|
||||
|
||||
# Install python3.8, and pip3
|
||||
# Install python3.10, and pip3
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.8 \
|
||||
python3.10 \
|
||||
python3-pip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
62
test/external/external_test_hcq_fuzz_failures.py
vendored
Normal file
62
test/external/external_test_hcq_fuzz_failures.py
vendored
Normal file
File diff suppressed because one or more lines are too long
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -30,7 +30,7 @@ from tinygrad.device import is_dtype_supported
|
||||
|
||||
def on_linearizer_will_run(): pass
|
||||
def on_linearizer_did_run(): pass
|
||||
def compare_states(x, y): return True
|
||||
def compare_states(x, y): return (True, "")
|
||||
|
||||
if getenv("VALIDATE_HCQ"):
|
||||
if Device.DEFAULT == "NV":
|
||||
|
||||
51
test/external/speed_v_theoretical.py
vendored
Normal file
51
test/external/speed_v_theoretical.py
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
import unittest, time
|
||||
from tinygrad import Tensor, TinyJit, Device
|
||||
from tinygrad.helpers import Context, DEBUG
|
||||
|
||||
class TestKernelSpeed(unittest.TestCase):
|
||||
def _test_matmul(self, M, N=None, K=None, nv=None, amd=None):
|
||||
# (MxK) @ (KxN)
|
||||
@TinyJit
|
||||
def f(a, b): return (a @ b).realize()
|
||||
|
||||
if N is None: N = M
|
||||
if K is None: K = M
|
||||
tms = []
|
||||
with Context(BEAM=3):
|
||||
for _ in range(10):
|
||||
with Context(BEAM=0, DEBUG=0):
|
||||
a = Tensor.rand(M, K, dtype="half").realize()
|
||||
b = Tensor.rand(K, N, dtype="half").realize()
|
||||
Device.default.synchronize()
|
||||
st = time.perf_counter()
|
||||
_c = f(a, b)
|
||||
Device.default.synchronize()
|
||||
tms.append(time.perf_counter() - st)
|
||||
|
||||
ops = 2 * M * N * K
|
||||
tm = min(tms)
|
||||
tflops = ops / tm / 1e12
|
||||
|
||||
if DEBUG >= 1:
|
||||
print(f"{tm=}")
|
||||
print(f"{tflops=}")
|
||||
|
||||
if Device.DEFAULT == "NV":
|
||||
if DEBUG >=1: print(f"target: {nv}")
|
||||
self.assertGreater(tflops, nv)
|
||||
if Device.DEFAULT == "AMD":
|
||||
if DEBUG >=1: print(f'target: {amd}')
|
||||
self.assertGreater(tflops, amd)
|
||||
|
||||
# TODO: smaller ones has other overhead in synchronize
|
||||
# TODO: AMD number can be better (perf level?)
|
||||
def test_gemm_1024(self): self._test_matmul(1024, nv=8, amd=7)
|
||||
def test_gemm_2048(self): self._test_matmul(2048, nv=50, amd=30)
|
||||
def test_gemm_4096(self): self._test_matmul(4096, nv=95, amd=65)
|
||||
def test_gemm_8192(self): self._test_matmul(8192, nv=130, amd=70)
|
||||
|
||||
# TODO: add gemv, which is memory bounded
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
assert len(sched) == 1
|
||||
|
||||
lin = Kernel(sched[0].ast)
|
||||
assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
|
||||
assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
|
||||
|
||||
a = Tensor.empty((4,4))
|
||||
b = Tensor.empty((4,4))
|
||||
|
||||
@@ -94,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
if_uops = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertIn(len(if_uops), {1,3})
|
||||
self.assertIn(len(if_uops), {1,2,3})
|
||||
conditions = if_uops[0].src[0].sparents
|
||||
self.assertLessEqual(len(conditions), 9)
|
||||
|
||||
|
||||
@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
assert k is not None
|
||||
ifs = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertEqual(len(ifs), 4)
|
||||
self.assertEqual(len(ifs), 3)
|
||||
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
|
||||
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)
|
||||
|
||||
|
||||
@@ -233,6 +233,20 @@ class TestOps(unittest.TestCase):
|
||||
def test_arange_4096(self):
|
||||
helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True)
|
||||
|
||||
def test_linspace(self):
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype="int32"), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, 10, 20, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 20, dtype="int32"), forward_only=True)
|
||||
helper_test_op([], lambda: torch.linspace(5, -5, 20, dtype=torch.int32), lambda: Tensor.linspace(5, -5, 20, dtype="int32"), forward_only=True)
|
||||
self.helper_test_exception([], lambda: torch.linspace(5, 10, 3, dtype=torch.bool), lambda: Tensor.linspace(5, 10, 3, dtype="bool"),
|
||||
expected=(RuntimeError, ValueError))
|
||||
self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError))
|
||||
|
||||
def test_sum_fake(self):
|
||||
helper_test_op([(256, 1)], lambda x: x.sum(axis=1))
|
||||
|
||||
@@ -1483,6 +1497,11 @@ class TestOps(unittest.TestCase):
|
||||
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
|
||||
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
|
||||
|
||||
def test_simple_conv2d_bias(self):
|
||||
helper_test_op([(1,4,9,9), (4,4,3,3), (4,)],
|
||||
lambda x,w,b: torch.nn.functional.conv2d(x,w,b).relu(),
|
||||
lambda x,w,b: Tensor.conv2d(x,w,b).relu(), grad_rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(IMAGE>0, "no conv3d on images")
|
||||
def test_simple_conv3d(self):
|
||||
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
|
||||
from tinygrad.helpers import getenv, CI
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.engine.realize import lower_schedule, CompiledRunner
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
@@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
counts = Tensor.arange(20, dtype=dtypes.uint32)
|
||||
counts0, counts1 = counts.chunk(2)
|
||||
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
|
||||
r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy()
|
||||
|
||||
np.testing.assert_allclose(jr, r)
|
||||
|
||||
def test_threefry_doesnt_use_long(self):
|
||||
for ei in lower_schedule(Tensor.rand(20).schedule()):
|
||||
if isinstance(ei.prg, CompiledRunner):
|
||||
for u in ei.prg.p.uops:
|
||||
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
|
||||
|
||||
def test_threefry_against_reference_full(self):
|
||||
Tensor.manual_seed(1337)
|
||||
|
||||
@@ -227,11 +234,15 @@ class TestRandomness(unittest.TestCase):
|
||||
|
||||
def test_randint(self):
|
||||
self.assertFalse(normal_test(Tensor.randint))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5),
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"),
|
||||
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
|
||||
self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
|
||||
# check types of args
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float")
|
||||
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
|
||||
|
||||
def test_normal(self):
|
||||
|
||||
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
|
||||
from test.helpers import ast_const, timeit
|
||||
@@ -1626,13 +1626,14 @@ class TestIndexing(unittest.TestCase):
|
||||
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
|
||||
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
|
||||
|
||||
def test_recursive_st_fixup(self):
|
||||
def test_recursive_swizzle(self):
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
for _ in range(24): a = a + a
|
||||
ast = a.schedule()[0].ast
|
||||
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
|
||||
swizzle = ast.src[0].src[2].reshape((4, 1))
|
||||
new_uop = swizzle_rewrite(swizzle)
|
||||
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
|
||||
self.assertLess(et, 1e3)
|
||||
self.assertEqual(swizzle_cnt(new_uop), 0)
|
||||
|
||||
def test_strongly_connected_DAG(self):
|
||||
val = 1.0
|
||||
@@ -1864,5 +1865,47 @@ class TestSwizzle(unittest.TestCase):
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_swizzle_failure_permute(self):
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(20, ('METAL', 65, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.ADD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 2925, dtypes.float)), src=()),
|
||||
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
|
||||
x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12,
|
||||
UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()),
|
||||
x15,)),)),
|
||||
x6,)),)),
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
|
||||
x12,
|
||||
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
|
||||
x15,)),
|
||||
UOp(Ops.MUL, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 2925, dtypes.float)), src=()),
|
||||
x10,)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(4, ('METAL', 2925, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
|
||||
ret = swizzle_rewrite(sink)
|
||||
self.assertEqual(swizzle_cnt(ret), 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -38,6 +38,7 @@ def helper_test_speed(f1, *args):
|
||||
|
||||
# operation cache defeats
|
||||
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]
|
||||
args = [(x-1).realize() if isinstance(x, Tensor) else (None if x is None else (x-1)) for x in args]
|
||||
|
||||
# force syncing
|
||||
[x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
|
||||
@@ -91,8 +92,9 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
|
||||
desc = "faster" if et_torch > et_tinygrad else "slower"
|
||||
flops = save_ops*1e-6
|
||||
mem = save_mem*1e-6
|
||||
print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
|
||||
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
|
||||
print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:9.2f} GFLOPS {mem/et_torch:7.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:9.2f} GFLOPS {mem/et_tinygrad:7.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
|
||||
atol, rtol = (1e-2, 1e-2) if torch_dt == torch.float16 else (1e-3, 1e-3)
|
||||
np.testing.assert_allclose(val_tinygrad, val_torch, atol=atol, rtol=rtol)
|
||||
|
||||
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase):
|
||||
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
|
||||
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
|
||||
|
||||
# *** randomness ***
|
||||
|
||||
def test_random(self):
|
||||
out = Tensor.rand(10)
|
||||
for x in out.tolist():
|
||||
self.assertGreaterEqual(x, 0.0)
|
||||
self.assertLessEqual(x, 1.0)
|
||||
|
||||
# *** JIT (for Python speed) ***
|
||||
|
||||
def test_jit(self):
|
||||
|
||||
13
test/testextra/test_mockgpu.py
Normal file
13
test/testextra/test_mockgpu.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from tinygrad.helpers import getenv
|
||||
import unittest, importlib
|
||||
|
||||
@unittest.skipUnless(getenv("MOCKGPU"), 'Testing mockgpu')
|
||||
class TestMockGPU(unittest.TestCase):
|
||||
# https://github.com/tinygrad/tinygrad/pull/7627
|
||||
def test_import_typing_extensions(self):
|
||||
import extra.mockgpu.mockgpu # noqa: F401 # pylint: disable=unused-import
|
||||
import typing_extensions
|
||||
importlib.reload(typing_extensions) # pytest imports typing_extension before mockgpu
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -274,8 +274,17 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# tensor core cleanups
|
||||
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
|
||||
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
|
||||
# threefry
|
||||
# threefry + remove longs
|
||||
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
|
||||
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
|
||||
((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
|
||||
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
|
||||
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
|
||||
# hacks for threefry long removal when padded (TODO: genericize)
|
||||
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
|
||||
lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
|
||||
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
|
||||
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
|
||||
# arange loop folding
|
||||
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
|
||||
# indexing, with cast or where
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections import defaultdict, deque
|
||||
from typing import Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps
|
||||
from typing import Set, Tuple, List, Dict, DefaultDict
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.engine.lazy import LazyBuffer
|
||||
@@ -37,13 +38,15 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
|
||||
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
|
||||
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
|
||||
|
||||
def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
|
||||
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]:
|
||||
def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
|
||||
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]:
|
||||
"""search the graph for all the LazyBuffers that need to realize"""
|
||||
# get all the realizes from big graph
|
||||
realizes: Dict[LazyBuffer, None] = {}
|
||||
assigns: Set[UOp] = set()
|
||||
for r in allbufs:
|
||||
if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None
|
||||
if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None
|
||||
if r.op is Ops.ASSIGN: assigns.add(ubuf)
|
||||
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
|
||||
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
reduce_of_const: List[LazyBuffer] = []
|
||||
@@ -62,7 +65,7 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La
|
||||
parents = deque((r, *group))
|
||||
while parents and not forced_realize:
|
||||
if (p:=parents.pop().base).is_realized() or p in realizes:
|
||||
if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
|
||||
if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False
|
||||
continue
|
||||
parents.extend(p.srcs)
|
||||
if forced_realize or not group:
|
||||
@@ -92,19 +95,19 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La
|
||||
top_reduce = reduceop.base.srcs[0].base
|
||||
if len(children[top_reduce]) == 1:
|
||||
del realizes[top_reduce]
|
||||
if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
for r in reduce_of_const:
|
||||
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
|
||||
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
|
||||
if any(tr.forced_realize for tr in group): continue
|
||||
kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
|
||||
if len(kernel_children) == 0: continue
|
||||
for tr in group:
|
||||
del realizes[tr]
|
||||
if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
|
||||
|
||||
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
|
||||
for buf in realizes:
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
|
||||
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer])
|
||||
ubuf_realizes[ubuf] = ubuf
|
||||
return list(output_groups.values())
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import sys, atexit, functools, itertools
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
|
||||
from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
|
||||
from tinygrad.dtype import ImageDType, dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
@@ -91,12 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
|
||||
|
||||
# ** helpers for doing movementops on uops
|
||||
|
||||
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
|
||||
if (n:=cache.get(u)) is not None: return n
|
||||
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
|
||||
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
|
||||
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
|
||||
return ret
|
||||
def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
|
||||
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
|
||||
|
||||
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
|
||||
@@ -105,9 +101,8 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
|
||||
|
||||
# ** movementops rewrite rules
|
||||
|
||||
def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
|
||||
if (st:=unwrap(view.st)).contiguous: return None
|
||||
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg)
|
||||
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
|
||||
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
|
||||
prshape = prod(rshape)
|
||||
strides = strides_for_shape(rshape)
|
||||
nv: List[View] = []
|
||||
@@ -118,10 +113,10 @@ def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
|
||||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
|
||||
assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
|
||||
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
|
||||
output_shape = swizzle_st.reduce(root.axis_arg)
|
||||
@@ -134,9 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
|
||||
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
|
||||
new_shape, new_input_shape = swizzle_shapes[0]
|
||||
fixup_cache: Dict[UOp, UOp] = {}
|
||||
new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src]
|
||||
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
|
||||
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
@@ -149,8 +142,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
|
||||
# push VIEW to loads
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# VIEW before elementwise ops
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
|
||||
# early merge VIEW buffer ops
|
||||
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
@@ -160,10 +152,10 @@ view_right = merge_views+PatternMatcher([
|
||||
# ASSIGN can override st
|
||||
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
|
||||
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
|
||||
# VIEW on a reduce creates a new VIEW
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||
# non contiguous VIEW on a reduce creates a new VIEW
|
||||
(UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
|
||||
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
|
||||
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
|
||||
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
|
||||
@@ -283,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
|
||||
# get realizes
|
||||
realizes: Dict[UOp, UOp] = {}
|
||||
graph_rewrite(big_graph, do_realize, realizes)
|
||||
store_groups = get_realizes(outs, children, allbufs, double_reduces, realizes, ctx)
|
||||
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops)
|
||||
# split realizes into small graphs
|
||||
graph_rewrite(big_graph, break_sched, realizes)
|
||||
sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]
|
||||
|
||||
@@ -2,9 +2,7 @@ from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||
import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
|
||||
from typing_extensions import TypeGuard
|
||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
@@ -320,10 +318,7 @@ class trange(tqdm):
|
||||
|
||||
def _reconstruct_code(*args): return types.CodeType(*args)
|
||||
def _serialize_code(code:types.CodeType):
|
||||
# NOTE: this works in Python 3.8 and up
|
||||
if sys.version_info >= (3, 10): args = inspect.signature(types.CodeType).parameters
|
||||
else: args = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring',
|
||||
'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
|
||||
args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up
|
||||
return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
|
||||
copyreg.pickle(types.CodeType, _serialize_code)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
from weakref import WeakValueDictionary
|
||||
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
|
||||
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
|
||||
if TYPE_CHECKING:
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
@@ -140,7 +140,7 @@ class Ops(FastEnum):
|
||||
|
||||
# BinaryOps
|
||||
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
|
||||
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
|
||||
|
||||
# TernaryOps
|
||||
WHERE = auto(); MULACC = auto() # noqa: E702
|
||||
@@ -168,7 +168,8 @@ class Ops(FastEnum):
|
||||
|
||||
class GroupOp:
|
||||
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
|
||||
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
|
||||
Ops.SUB, Ops.FDIV}
|
||||
Ternary = {Ops.WHERE, Ops.MULACC}
|
||||
ALU = set.union(Unary, Binary, Ternary)
|
||||
|
||||
@@ -310,7 +311,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ret
|
||||
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
|
||||
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
|
||||
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
|
||||
def broadcast(self, count:int):
|
||||
assert self.dtype.count == 1
|
||||
@@ -346,6 +346,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
|
||||
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
|
||||
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
@staticmethod
|
||||
@@ -379,7 +384,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is Ops.CONST: return self.arg
|
||||
if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
|
||||
if self.op is Ops.VCONST: return math.gcd(*self.arg)
|
||||
if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||
if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
|
||||
return 1
|
||||
@@ -630,7 +635,6 @@ class PatternMatcher:
|
||||
for p,fxn in self.patterns:
|
||||
assert p.op is not None
|
||||
tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
|
||||
tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends
|
||||
real_fxn = types.FunctionType(*tuple_fxn)
|
||||
for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
|
||||
|
||||
@@ -901,7 +905,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
|
||||
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
|
||||
p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1)
|
||||
if np and (d:=functools.reduce(math.gcd, [u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
|
||||
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
|
||||
return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
|
||||
return None
|
||||
|
||||
@@ -1050,6 +1054,8 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# group like
|
||||
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# ** combine terms **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
|
||||
@@ -1,50 +1,77 @@
|
||||
from typing import Dict, Callable, List, Optional
|
||||
from llvmlite import ir
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
|
||||
from typing import List, Dict, cast
|
||||
import math, struct
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
|
||||
def ldt(dt:DType):
|
||||
if isinstance(dt, PtrDType): return ldt(dt.base) + "*"
|
||||
return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
|
||||
dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
|
||||
dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt]
|
||||
|
||||
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
|
||||
|
||||
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
|
||||
dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
|
||||
dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
|
||||
|
||||
def cast(bb, val, input_type, output_type, bitcast=False):
|
||||
if input_type == output_type: return val
|
||||
llvm_type = dtype_to_llvm_dtype[output_type]
|
||||
if bitcast: return bb[-1].bitcast(val, llvm_type)
|
||||
|
||||
if input_type == dtypes.bfloat16:
|
||||
val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
|
||||
input_type = dtypes.float32
|
||||
if output_type == dtypes.bfloat16:
|
||||
val = cast(bb, val, input_type, dtypes.float32)
|
||||
return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
|
||||
def lconst(x, dtype:DType):
|
||||
if dtype in dtypes.floats:
|
||||
if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
return truncate[dtype](x)
|
||||
return int(x)
|
||||
|
||||
def lcast(input_type:DType, output_type:DType):
|
||||
if dtypes.is_float(input_type):
|
||||
if dtypes.is_float(output_type):
|
||||
return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
|
||||
if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
|
||||
if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
|
||||
|
||||
if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
|
||||
if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
|
||||
if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
|
||||
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
|
||||
if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
|
||||
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
|
||||
if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
|
||||
|
||||
if dtypes.is_float(output_type): return 'uitofp'
|
||||
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
|
||||
if dtypes.is_int(input_type):
|
||||
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
|
||||
if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
|
||||
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
|
||||
if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
|
||||
|
||||
if dtypes.is_float(output_type): return 'sitofp'
|
||||
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
|
||||
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
||||
|
||||
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
|
||||
# llvm ops, lop[<dtype>][<op>]
|
||||
unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
|
||||
Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
|
||||
signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
|
||||
flags = " nsz arcp contract afn"
|
||||
float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
|
||||
lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
|
||||
|
||||
llvm_rewrite = PatternMatcher([
|
||||
# memory load/store
|
||||
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
|
||||
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
|
||||
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
|
||||
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
|
||||
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
|
||||
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
|
||||
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
|
||||
|
||||
# unary/binary/ternary ops
|
||||
(UPat(Ops.SQRT, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
|
||||
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
|
||||
(UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
|
||||
(UPat(Ops.WHERE, name="x"), lambda ctx,x:
|
||||
f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
|
||||
|
||||
# range
|
||||
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
|
||||
f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
|
||||
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
|
||||
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
|
||||
f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
|
||||
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
|
||||
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
|
||||
|
||||
# if
|
||||
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
|
||||
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
|
||||
])
|
||||
|
||||
class LLVMRenderer(Renderer):
|
||||
device = "LLVM"
|
||||
@@ -52,101 +79,64 @@ class LLVMRenderer(Renderer):
|
||||
has_local = False
|
||||
has_shared = False
|
||||
global_max = None
|
||||
code_for_op: Dict[Ops, Callable] = {
|
||||
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
|
||||
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
|
||||
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
|
||||
BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
|
||||
BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
|
||||
BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
|
||||
BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
|
||||
BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501
|
||||
TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
|
||||
|
||||
def render(self, name:str, uops:List[UOp]) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
extra_matcher = PatternMatcher([
|
||||
# rewrite RECIP with FDIV
|
||||
(UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
|
||||
# rewrite cast to bool to CMPNE 0
|
||||
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
|
||||
# *** also in cstyle ***
|
||||
# gate any stores that aren't gated with ifs
|
||||
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
|
||||
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
|
||||
# rewrite MAX to CMPLT + WHERE
|
||||
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
|
||||
])
|
||||
|
||||
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
|
||||
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
|
||||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
def render(self, name: str, uops: List[UOp]) -> str:
|
||||
r: Dict[UOp, str] = {}
|
||||
args: List[str] = []
|
||||
kernel: List[str] = []
|
||||
end_lines: Dict[str, None] = {}
|
||||
vc = -1
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
|
||||
for a in func.args:
|
||||
if a.type.is_pointer: a.add_attribute("noalias")
|
||||
|
||||
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
|
||||
loop_blocks: List = []
|
||||
reduce_phis: List = []
|
||||
lvars: Dict[Optional[UOp], ir.Instruction] = {}
|
||||
|
||||
for bufname,dtype in buf_to_dtype.items():
|
||||
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
|
||||
# prealloc all assigns
|
||||
acc_to_assign: Dict[UOp, UOp] = {}
|
||||
for u in uops:
|
||||
if u.op is Ops.ASSIGN:
|
||||
vc += 1
|
||||
r[u] = r[u.src[1]] = f"%assign{vc}"
|
||||
assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
|
||||
acc_to_assign[u.src[0]] = u.src[1]
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
if uop is Ops.INDEX:
|
||||
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
|
||||
elif uop is Ops.STORE:
|
||||
if len(src) > 2:
|
||||
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
|
||||
else:
|
||||
bb[-1].store(lvars[src[1]], lvars[src[0]])
|
||||
elif uop is Ops.ENDRANGE:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
|
||||
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
|
||||
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
|
||||
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
|
||||
# hack for defining sqrt function (TODO: can we get a transcendental for this?)
|
||||
if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
|
||||
|
||||
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
|
||||
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
|
||||
args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
|
||||
elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
|
||||
elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
|
||||
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
|
||||
elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
|
||||
else:
|
||||
if uop is Ops.RANGE:
|
||||
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
|
||||
bb[-2].branch(bb[-1].block)
|
||||
# if it's an assign target, it's already preallocated
|
||||
if u not in r:
|
||||
vc += 1
|
||||
r[u] = f"%v{vc}"
|
||||
|
||||
phis = []
|
||||
for rp in reduce_phis:
|
||||
incoming = lvars[rp]
|
||||
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
|
||||
lvars[rp].add_incoming(incoming, bb[-2].block)
|
||||
phis.append((rp, lvars[rp]))
|
||||
# do the rendering of the llvm ir code
|
||||
if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
|
||||
kernel.append(cast(str, l))
|
||||
|
||||
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
|
||||
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is Ops.DEFINE_ACC:
|
||||
lvars[u] = const(src[0].arg, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is Ops.LOAD:
|
||||
if len(src) > 1:
|
||||
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
|
||||
with then:
|
||||
val1 = bb[-1].load(lvars[src[0]])
|
||||
then_blk = bb[-1].block
|
||||
with otherwise: otherwise_blk = bb[-1].block
|
||||
val = bb[-1].phi(val1.type)
|
||||
val.add_incoming(val1, then_blk)
|
||||
val.add_incoming(lvars[src[1]], otherwise_blk)
|
||||
else:
|
||||
val = bb[-1].load(lvars[src[0]])
|
||||
lvars[u] = val
|
||||
elif uop is Ops.ASSIGN:
|
||||
lvars[u] = lvars[src[1]]
|
||||
# ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC
|
||||
backward = src[0]
|
||||
while backward.op is Ops.ASSIGN: backward = backward.src[0]
|
||||
lvars[backward] = lvars[u]
|
||||
elif uop in GroupOp.ALU:
|
||||
lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
|
||||
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
|
||||
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
|
||||
elif uop is Ops.CONST: lvars[u] = const(args, dtype)
|
||||
else: raise RuntimeError(f"failed to render {uop}")
|
||||
# generate the phi nodes for the assigns
|
||||
if u.op is Ops.RANGE:
|
||||
for x in acc_to_assign:
|
||||
if u in x.src: # if this range is relevent for this acc
|
||||
vc += 1
|
||||
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
|
||||
r[x] = f"%acc{vc}"
|
||||
|
||||
bb[-1].ret_void()
|
||||
return str(module)
|
||||
# output the function
|
||||
return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())
|
||||
|
||||
@@ -5,14 +5,70 @@
|
||||
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Tuple, Optional, Dict, Any, DefaultDict
|
||||
from typing import Tuple, Optional, Dict, Any, DefaultDict, List
|
||||
from collections import defaultdict
|
||||
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
|
||||
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
|
||||
|
||||
# ***** API *****
|
||||
|
||||
class CloudRequest: pass
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyOut(CloudRequest): buffer_num: int
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProgramExec(CloudRequest):
|
||||
name: str; datahash: str; bufs: Tuple[int, ...]; vals: Tuple[int, ...] # noqa: E702
|
||||
global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702
|
||||
|
||||
# for safe deserialization
|
||||
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
|
||||
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
|
||||
ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
|
||||
ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
|
||||
def safe_eval(node): return eval_fxns[node.__class__](node)
|
||||
|
||||
class BatchRequest:
|
||||
def __init__(self):
|
||||
self._q: List[CloudRequest] = []
|
||||
self._h: Dict[str, bytes] = {}
|
||||
def h(self, d:bytes) -> str:
|
||||
binhash = hashlib.sha256(d).digest()
|
||||
self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack("<Q", len(d))+d
|
||||
return datahash
|
||||
def q(self, x:CloudRequest): self._q.append(x)
|
||||
def serialize(self) -> bytes:
|
||||
self.h(repr(self._q).encode())
|
||||
return b''.join(self._h.values())
|
||||
def deserialize(self, dat:bytes) -> BatchRequest:
|
||||
ptr = 0
|
||||
while ptr < len(dat):
|
||||
datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
|
||||
self._h[datahash] = dat[ptr+0x28:ptr+0x28+datalen]
|
||||
ptr += 0x28+datalen
|
||||
self._q = safe_eval(ast.parse(self._h[datahash], mode="eval").body)
|
||||
return self
|
||||
|
||||
# ***** backend *****
|
||||
|
||||
@@ -21,7 +77,6 @@ class CloudSession:
|
||||
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
|
||||
# TODO: the buffer should track this internally
|
||||
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
|
||||
buffer_num = 0
|
||||
|
||||
class CloudHandler(BaseHTTPRequestHandler):
|
||||
protocol_version = 'HTTP/1.1'
|
||||
@@ -32,66 +87,47 @@ class CloudHandler(BaseHTTPRequestHandler):
|
||||
super().setup()
|
||||
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
|
||||
|
||||
def get_data(self):
|
||||
content_len = self.headers.get('Content-Length')
|
||||
assert content_len is not None
|
||||
return self.rfile.read(int(content_len))
|
||||
def get_json(self): return json.loads(self.get_data())
|
||||
|
||||
def _fail(self):
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
return 0
|
||||
|
||||
def _do(self, method):
|
||||
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
|
||||
ret = b""
|
||||
if self.path == "/renderer" and method == "GET":
|
||||
ret, status_code = b"", 200
|
||||
if self.path == "/batch" and method == "POST":
|
||||
# TODO: streaming deserialize?
|
||||
req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
|
||||
# the cmds are always last (currently in datahash)
|
||||
for c in req._q:
|
||||
if DEBUG >= 1: print(c)
|
||||
match c:
|
||||
case BufferAlloc():
|
||||
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
|
||||
session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
|
||||
case BufferFree():
|
||||
buf,sz,buffer_options = session.buffers[c.buffer_num]
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
|
||||
del session.buffers[c.buffer_num]
|
||||
case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
|
||||
case CopyOut():
|
||||
buf,sz,_ = session.buffers[c.buffer_num]
|
||||
Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
|
||||
case ProgramAlloc():
|
||||
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
|
||||
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
|
||||
case ProgramFree(): del session.programs[(c.name, c.datahash)]
|
||||
case ProgramExec():
|
||||
bufs = [session.buffers[x][0] for x in c.bufs]
|
||||
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
|
||||
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif self.path == "/renderer" and method == "GET":
|
||||
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
|
||||
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
|
||||
elif self.path.startswith("/alloc") and method == "POST":
|
||||
size = int(self.path.split("=")[-1])
|
||||
buffer_options: Optional[BufferOptions] = None
|
||||
if 'image' in self.path:
|
||||
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
|
||||
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
|
||||
session.buffer_num += 1
|
||||
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
|
||||
ret = str(session.buffer_num).encode()
|
||||
elif self.path.startswith("/buffer"):
|
||||
key = int(self.path.split("/")[-1])
|
||||
buf,sz,buffer_options = session.buffers[key]
|
||||
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
|
||||
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
|
||||
elif method == "DELETE":
|
||||
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
|
||||
del session.buffers[key]
|
||||
else: return self._fail()
|
||||
elif self.path.startswith("/program"):
|
||||
name, hsh = self.path.split("/")[-2:]
|
||||
if method == "PUT":
|
||||
src = self.get_data()
|
||||
assert hashlib.sha256(src).hexdigest() == hsh
|
||||
lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode())
|
||||
session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib)
|
||||
elif method == "POST":
|
||||
j = self.get_json()
|
||||
bufs = [session.buffers[x][0] for x in j['bufs']]
|
||||
del j['bufs']
|
||||
r = session.programs[(name, hsh)](*bufs, **j)
|
||||
if r is not None: ret = str(r).encode()
|
||||
elif method == "DELETE": del session.programs[(name, hsh)]
|
||||
else: return self._fail()
|
||||
else: return self._fail()
|
||||
self.send_response(200)
|
||||
else: status_code = 404
|
||||
self.send_response(status_code)
|
||||
self.send_header('Content-Length', str(len(ret)))
|
||||
self.end_headers()
|
||||
return self.wfile.write(ret)
|
||||
|
||||
def do_GET(self): return self._do("GET")
|
||||
def do_POST(self): return self._do("POST")
|
||||
def do_PUT(self): return self._do("PUT")
|
||||
def do_DELETE(self): return self._do("DELETE")
|
||||
|
||||
def cloud_server(port:int):
|
||||
multiprocessing.current_process().name = "MainProcess"
|
||||
@@ -106,44 +142,46 @@ class CloudAllocator(Allocator):
|
||||
def __init__(self, device:CloudDevice):
|
||||
self.device = device
|
||||
super().__init__()
|
||||
def _alloc(self, size:int, options) -> int:
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
|
||||
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
|
||||
def _free(self, opaque, options):
|
||||
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
|
||||
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
|
||||
def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src))
|
||||
# TODO: ideally we shouldn't have to deal with images here
|
||||
def _alloc(self, size:int, options:BufferOptions) -> int:
|
||||
self.device.buffer_num += 1
|
||||
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
|
||||
return self.device.buffer_num
|
||||
# TODO: options should not be here in any Allocator
|
||||
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
|
||||
def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
|
||||
def copyout(self, dest:memoryview, src:int):
|
||||
resp = self.device.send("GET", f"buffer/{src}")
|
||||
self.device.req.q(CopyOut(src))
|
||||
resp = self.device.batch_submit()
|
||||
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
|
||||
dest[:] = resp
|
||||
|
||||
class CloudProgram:
|
||||
def __init__(self, device:CloudDevice, name:str, lib:bytes):
|
||||
self.device = device
|
||||
self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}"
|
||||
self.device.send("PUT", "program/"+self.prgid, lib)
|
||||
self.device, self.name = device, name
|
||||
self.datahash = self.device.req.h(lib)
|
||||
self.device.req.q(ProgramAlloc(self.name, self.datahash))
|
||||
super().__init__()
|
||||
def __del__(self): self.device.send("DELETE", "program/"+self.prgid)
|
||||
def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash))
|
||||
|
||||
def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
|
||||
args = {"bufs": bufs, "vals": vals, "wait": wait}
|
||||
if global_size is not None: args["global_size"] = global_size
|
||||
if local_size is not None: args["local_size"] = local_size
|
||||
ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode())
|
||||
if wait: return float(ret)
|
||||
self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
|
||||
if wait: return float(self.device.batch_submit())
|
||||
|
||||
class CloudDevice(Compiled):
|
||||
def __init__(self, device:str):
|
||||
if (host:=getenv("HOST", "")) != "":
|
||||
self.host = host
|
||||
if (host:=getenv("HOST", "")) != "": self.host = host
|
||||
else:
|
||||
p = multiprocessing.Process(target=cloud_server, args=(6667,))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.host = "127.0.0.1:6667"
|
||||
self.cookie = binascii.hexlify(os.urandom(0x10)).decode()
|
||||
|
||||
# state for the connection
|
||||
self.session = binascii.hexlify(os.urandom(0x10)).decode()
|
||||
self.buffer_num = 0
|
||||
self.req: BatchRequest = BatchRequest()
|
||||
|
||||
if DEBUG >= 1: print(f"cloud with host {self.host}")
|
||||
while 1:
|
||||
try:
|
||||
@@ -155,13 +193,26 @@ class CloudDevice(Compiled):
|
||||
time.sleep(0.1)
|
||||
if DEBUG >= 1: print(f"remote has device {clouddev}")
|
||||
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
|
||||
assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}"
|
||||
renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2])
|
||||
super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self))
|
||||
if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
|
||||
renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
|
||||
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
|
||||
super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
|
||||
|
||||
def __del__(self):
|
||||
# TODO: this is never being called
|
||||
# TODO: should close the whole session
|
||||
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()
|
||||
|
||||
def batch_submit(self):
|
||||
data = self.req.serialize()
|
||||
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
|
||||
ret = self.send("POST", "batch", data)
|
||||
self.req = BatchRequest()
|
||||
return ret
|
||||
|
||||
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
|
||||
# TODO: retry logic
|
||||
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"})
|
||||
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
|
||||
response = self.conn.getresponse()
|
||||
assert response.status == 200, f"failed on {method} {path}"
|
||||
return response.read()
|
||||
|
||||
@@ -2,18 +2,24 @@ from __future__ import annotations
|
||||
import ctypes, functools
|
||||
from typing import Tuple
|
||||
from tinygrad.device import Compiled, Compiler, MallocAllocator
|
||||
from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump
|
||||
from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump, getenv
|
||||
from tinygrad.renderer.llvmir import LLVMRenderer
|
||||
import llvmlite.binding as llvm
|
||||
|
||||
class LLVMCompiler(Compiler):
|
||||
def __init__(self, device:LLVMDevice):
|
||||
def __init__(self, device:LLVMDevice, opt:bool=False):
|
||||
self.device = device
|
||||
super().__init__("compile_llvm")
|
||||
self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
|
||||
self.device.target_machine.add_analysis_passes(self.optimizer)
|
||||
if opt:
|
||||
with llvm.create_pass_manager_builder() as builder:
|
||||
builder.opt_level = 3; builder.size_level = 0; builder.loop_vectorize = True; builder.slp_vectorize = True # noqa: E702
|
||||
builder.populate(self.optimizer)
|
||||
super().__init__("compile_llvm_opt" if opt else "compile_llvm")
|
||||
def compile(self, src:str) -> bytes:
|
||||
mod = llvm.parse_assembly(src)
|
||||
mod.verify()
|
||||
self.device.optimizer.run(mod)
|
||||
self.optimizer.run(mod)
|
||||
if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod))
|
||||
return self.device.target_machine.emit_object(mod)
|
||||
|
||||
@@ -23,6 +29,7 @@ class LLVMProgram:
|
||||
self.name, self.lib = name, lib
|
||||
device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
|
||||
self.fxn = device.engine.get_function_address(name)
|
||||
assert self.fxn != 0, "LLVM failed to get function address"
|
||||
|
||||
def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
|
||||
if not hasattr(self, 'cfunc'):
|
||||
@@ -35,12 +42,9 @@ class LLVMDevice(Compiled):
|
||||
llvm.initialize_native_target()
|
||||
llvm.initialize_native_asmprinter()
|
||||
llvm.initialize_native_asmparser()
|
||||
self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
|
||||
# this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
|
||||
self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
|
||||
self.target_machine.add_analysis_passes(self.optimizer)
|
||||
self.target_machine.set_asm_verbosity(True)
|
||||
backing_mod = llvm.parse_assembly(str())
|
||||
backing_mod.triple = llvm.get_process_triple()
|
||||
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
|
||||
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
|
||||
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self, getenv("LLVMOPT")), functools.partial(LLVMProgram, self))
|
||||
|
||||
@@ -155,7 +155,7 @@ class QCOMComputeQueue(HWComputeQueue):
|
||||
|
||||
if args_state.prg.tex_cnt > 0:
|
||||
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt),
|
||||
state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)),
|
||||
*data64_le(args_state.ptr + args_state.prg.tex_off))
|
||||
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
|
||||
|
||||
@@ -247,14 +247,15 @@ class QCOMProgram(HCQProgram):
|
||||
self.buf_info, self.consts_info = [], []
|
||||
|
||||
# Collect sampler info.
|
||||
self.samp_cnt = _read_lib(image_desc_off + 0xdc)
|
||||
self.samp_cnt = samp_cnt_in_file = _read_lib(image_desc_off + 0xdc)
|
||||
assert self.samp_cnt <= 1, "Up to one sampler supported"
|
||||
if self.samp_cnt:
|
||||
self.samp_cnt += 1
|
||||
self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode),
|
||||
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0]
|
||||
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0, 0, 0, 0, 0]
|
||||
|
||||
# Collect kernel arguments (buffers) info.
|
||||
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt
|
||||
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samp_cnt_in_file
|
||||
while bdoff + 32 <= len(self.lib):
|
||||
length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32])
|
||||
if length == 0: break
|
||||
@@ -263,7 +264,7 @@ class QCOMProgram(HCQProgram):
|
||||
|
||||
# Setting correct offsets to textures/ibos.
|
||||
self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info)
|
||||
self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt
|
||||
self.ibo_off, self.tex_off, self.samp_off = 2048, 2048 + 0x40 * self.ibo_cnt, 2048 + 0x40 * self.tex_cnt + 0x40 * self.ibo_cnt
|
||||
cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off
|
||||
for x in self.buf_info:
|
||||
if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40
|
||||
@@ -307,10 +308,10 @@ class QCOMAllocator(HCQAllocator):
|
||||
texture.pitch, texture.real_stride = pitch, real_stride
|
||||
|
||||
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
|
||||
texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
|
||||
texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
|
||||
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
|
||||
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
|
||||
texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
|
||||
texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
|
||||
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
|
||||
|
||||
return texture
|
||||
|
||||
@@ -464,9 +464,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
|
||||
|
||||
@staticmethod
|
||||
def _threefry_random_bits(key, counts0, counts1):
|
||||
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
|
||||
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
|
||||
x = F.Threefry.apply(*x._broadcasted(key))
|
||||
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
|
||||
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
|
||||
return counts0.cat(counts1)
|
||||
|
||||
@@ -494,9 +494,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
|
||||
# generate per device seeds and rng counter if we haven't seen this device yet
|
||||
if device not in Tensor._device_seeds:
|
||||
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
|
||||
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
|
||||
device=device, dtype=dtypes.uint64, requires_grad=False)
|
||||
Tensor._device_seeds[device] = Tensor(
|
||||
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
|
||||
device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
|
||||
had_counter = False
|
||||
else: had_counter = True
|
||||
@@ -612,6 +612,26 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
|
||||
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
|
||||
"""
|
||||
Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.
|
||||
|
||||
You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
|
||||
Additionally, all other keyword arguments are passed to the constructor of the tensor.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.linspace(0, 10, 5).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(Tensor.linspace(-1, 1, 5).numpy())
|
||||
```
|
||||
"""
|
||||
if steps < 0: raise ValueError("number of steps must be non-negative")
|
||||
if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
|
||||
if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
|
||||
return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
|
||||
|
||||
@staticmethod
|
||||
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
|
||||
"""
|
||||
@@ -734,7 +754,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
|
||||
```
|
||||
"""
|
||||
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
|
||||
dtype = kwargs.pop("dtype", dtypes.int32)
|
||||
dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
|
||||
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
|
||||
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user