Merge branch 'master' into retinanet_mlperf

This commit is contained in:
Francis Lata
2024-11-12 02:45:23 -08:00
32 changed files with 699 additions and 358 deletions

View File

@@ -47,55 +47,55 @@ jobs:
echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
run: python3.11 test/external/process_replay/reset.py
- name: Run Stable Diffusion
run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
run: JIT=2 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run Stable Diffusion with fp16
run: JIT=2 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
run: JIT=2 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt
- name: Run SDXL
run: JIT=2 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
run: JIT=2 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
- name: Run model inference benchmark
run: METAL=1 python3 test/external/external_model_benchmark.py
run: METAL=1 python3.11 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores
run: METAL=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM
run: |
DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt
DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt
DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt
DEBUG=2 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt
- name: Fuzz Padded Tensor Core GEMM
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py
- name: Run LLaMA
run: |
JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
- name: Run LLaMA with BEAM
run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
- name: Run quantized LLaMA
run: |
python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt
- name: Run LLaMA 7B on 4 (virtual) GPUs
run: python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt
- name: Run GPT2
run: |
JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt
JIT=1 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt
- name: Run GPT2 w HALF
run: HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
run: JIT=2 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt
run: JIT=2 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt
- name: Run 10 CIFAR training steps w HALF
run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt
run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt
#- name: Run 10 CIFAR training steps w BF16
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
# run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
- name: Run 10 CIFAR training steps w winograd
run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
run: JIT=2 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (Mac)
@@ -123,7 +123,7 @@ jobs:
train_cifar_bf16.txt
train_cifar_wino.txt
- name: Run process replay tests
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3.11 process_replay.py
testnvidiabenchmark:
name: tinybox green Benchmark
@@ -157,7 +157,7 @@ jobs:
- name: Run model inference benchmark
run: NV=1 RUN_PROCESS_REPLAY=0 NOCLANG=1 python3 test/external/external_model_benchmark.py
- name: Test speed vs torch
run: NV=1 RUN_PROCESS_REPLAY=0 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
run: NV=1 RUN_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt
- name: Test tensor cores
run: |
NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
@@ -170,8 +170,10 @@ jobs:
run: NV=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt
- name: Run Tensor Core GEMM (NV)
run: NV=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt
- name: Run Tensor Core GEMM (NV) with BEAM
run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
# - name: Run Tensor Core GEMM (NV) with BEAM
# run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py
- name: Test speed vs theoretical
run: NV=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py
- name: Run Stable Diffusion
run: NV=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
@@ -328,6 +330,10 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: setup perflevel
run: |
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
rocm-smi
- name: Show off tinybox
run: /opt/rocm/bin/rocm-bandwidth-test
# TODO: unstable on AMD
@@ -343,6 +349,8 @@ jobs:
AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Run Tensor Core GEMM (AMD)
run: AMD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt
- name: Test speed vs theoretical
run: AMD=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py
# TODO: AMD compiler bug causes this to fail
#- name: Fuzz Padded Tensor Core GEMM
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
@@ -431,6 +439,10 @@ jobs:
rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal
- name: reset process replay
run: test/external/process_replay/reset.py
- name: setup perflevel
run: |
examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh
rocm-smi
- name: Train MNIST
run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps

View File

@@ -209,15 +209,15 @@ jobs:
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: "3.10"
- name: Cache python packages
uses: actions/cache@v4
with:
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8
path: ${{ env.Python3_ROOT_DIR }}/lib/python3.10/site-packages
key: linting-packages-${{ hashFiles('**/setup.py') }}-3.10
- name: Install dependencies
run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
- name: Lint bad-indentation and trailing-whitespace with pylint

View File

@@ -46,4 +46,5 @@ FLOAT16 | [1] | use float16 for images instead of float32
PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend.
PROFILE | [1] | enable output of [perfetto](https://ui.perfetto.dev/) compatible profile. This feature is supported in NV and AMD backends.
VISIBLE_DEVICES | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0).
JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
VIZ | [1] | 0=disabled, 1=[viz enabled](../tinygrad/viz/README)

View File

@@ -5,6 +5,7 @@
::: tinygrad.Tensor.ones
::: tinygrad.Tensor.full
::: tinygrad.Tensor.arange
::: tinygrad.Tensor.linspace
::: tinygrad.Tensor.eye
::: tinygrad.Tensor.full_like
::: tinygrad.Tensor.zeros_like

View File

@@ -178,7 +178,7 @@ class TrackedMemoryView:
def __len__(self): return len(self.mv)
def __repr__(self): return repr(self.mv)
def _memoryview(mem):
def _memoryview(cls, mem):
if isinstance(mem, int) or isinstance(mem, ctypes.Array):
addr = ctypes.addressof(mem) if isinstance(mem, ctypes.Array) else mem
for d in drivers:
@@ -196,7 +196,7 @@ install_hook(libc.lseek64, _lseek64)
install_hook(libc.stat64, _stat64)
install_hook(libc.fstat64, _fstat64)
install_hook(libc.getdents64, _getdents64)
builtins.memoryview = _memoryview # type: ignore
builtins.memoryview = type("memoryview", (), {'__new__': _memoryview}) # type: ignore
# rewrite autogen's libc mmaps functions.
import tinygrad.runtime.autogen.libc as autogen_libc

View File

@@ -1,6 +1,6 @@
from tinygrad import Tensor, dtypes
from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm
from tinygrad.device import is_dtype_supported
from typing import Optional, Union, List, Any, Tuple
import math
@@ -9,7 +9,8 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000):
half = dim // 2
freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp()
args = timesteps.unsqueeze(1) * freqs.unsqueeze(0)
return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16)
out = Tensor.cat(args.cos(), args.sin(), dim=-1)
return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out
class ResBlock:
def __init__(self, channels:int, emb_channels:int, out_channels:int):
@@ -222,16 +223,17 @@ class UNetModel:
]
def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor:
t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16)
t_emb = timestep_embedding(tms, self.model_ch)
emb = t_emb.sequential(self.time_embed)
if y is not None:
assert y.shape[0] == x.shape[0]
emb = emb + y.sequential(self.label_emb[0])
emb = emb.cast(dtypes.float16)
ctx = ctx.cast(dtypes.float16)
x = x .cast(dtypes.float16)
if is_dtype_supported(dtypes.float16):
emb = emb.cast(dtypes.float16)
ctx = ctx.cast(dtypes.float16)
x = x .cast(dtypes.float16)
def run(x:Tensor, bb) -> Tensor:
if isinstance(bb, ResBlock): x = bb(x, emb)

View File

@@ -7,6 +7,7 @@ from tinygrad import Tensor, dtypes, Device
from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import getenv, DEBUG, CI, OSX
from tinygrad.dtype import ConstType, DType
from tinygrad.device import is_dtype_supported
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
@@ -29,14 +30,6 @@ def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Uni
cache_misses = info.misses
return ret
# copied from helpers.py
def is_dtype_supported(dtype, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16: return False
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"})
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
# src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22
# TODO: use dtypes.float16 for FLOAT16

View File

@@ -1,6 +1,7 @@
# type: ignore
import ctypes, ctypes.util, struct, fcntl, re
from hexdump import hexdump
from copy import deepcopy
import pathlib, sys
from tinygrad.helpers import to_mv, getenv
from tinygrad.runtime.autogen import adreno
@@ -91,27 +92,36 @@ def parse_cmd_buf(dat):
num_unit = vals[0]>>22
if IOCTL > 0: print(f"{num_unit=} {state_block=} {state_src=} {state_type=} {dst_off=}")
if state_block == SB6_CS_SHADER and IOCTL > 2:
if "LOAD_FRAGS" not in CAPTURED_STATE: CAPTURED_STATE['LOAD_FRAGS'] = []
CAPTURED_STATE['LOAD_FRAGS'].append((state_block, state_type, num_unit, dst_off))
if state_block == SB6_CS_SHADER:
from extra.disassemblers.adreno import disasm_raw
if state_type == ST6_SHADER: disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128))
if state_type == ST6_CONSTANTS: hexdump(get_mem(((vals[2] << 32) | vals[1]), min(0x180, num_unit*4)))
if state_type == ST6_SHADER and IOCTL > 2:
disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128))
if state_type == ST6_CONSTANTS:
x = get_mem(((vals[2] << 32) | vals[1]), num_unit*4)
CAPTURED_STATE['constants'] = x[:]
if IOCTL > 2:
print('constants')
hexdump(x)
if state_type == ST6_IBO:
ibos_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4)
CAPTURED_STATE['ibos'] = ibos_bytes[:]
if IOCTL > 0:
if IOCTL > 1:
print('texture ibos')
hexdump(ibos_bytes)
elif state_block == SB6_CS_TEX and IOCTL > 2:
elif state_block == SB6_CS_TEX:
if state_type == ST6_SHADER:
samplers_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 4 * 4)
CAPTURED_STATE['samplers'] = ibos_bytes[:]
if IOCTL > 0:
CAPTURED_STATE['samplers'] = samplers_bytes[:]
if IOCTL > 1:
print('texture samplers')
hexdump(samplers_bytes)
if state_type == ST6_CONSTANTS:
descriptors_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4)
CAPTURED_STATE['descriptors'] = ibos_bytes[:]
if IOCTL > 0:
descriptors_bytes = get_mem((vals[2] << 32) | vals[1], 1600)
CAPTURED_STATE['descriptors'] = descriptors_bytes[:]
if IOCTL > 1:
print('texture descriptors')
hexdump(descriptors_bytes)
@@ -200,10 +210,19 @@ install_hook(libc.ioctl, ioctl)
def before_launch():
global CAPTURED_STATE
CAPTURED_STATE.clear()
def collect_last_launch_state(): return CAPTURED_STATE
def collect_last_launch_state():
global CAPTURED_STATE
return deepcopy(CAPTURED_STATE)
def compare_launch_state(state, good_state):
cmp = [
(adreno.REG_A6XX_SP_CS_CONFIG, 0xffffffff),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NTEX__MASK),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NSAMP__MASK),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NIBO__MASK),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_ENABLED),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_TEX),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_SAMP),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_IBO),
(adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_UBO),
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_HALFREGFOOTPRINT__MASK),
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_FULLREGFOOTPRINT__MASK),
@@ -213,20 +232,53 @@ def compare_launch_state(state, good_state):
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_EARLYPREAMBLE),
(adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_MERGEDREGS),
(adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_MEMSIZEPERITEM__MASK),
(adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_HWSTACKSIZEPERTHREAD__MASK),
(adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK5),
(adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK6),
(adreno.REG_A6XX_SP_CS_BRANCH_COND, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_KERNELDIM__MASK),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEX__MASK),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEY__MASK),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEZ__MASK),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_1, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_2, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_3, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_4, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_5, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_NDRANGE_6, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_CNTL_0, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_CNTL_1, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_X, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Y, 0xffffffff),
(adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Z, 0xffffffff),
]
for x,m in cmp:
print(f"Field {REGS[x]}, mask: 0x{m:X} cmp: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}")
if state.get(x, 0) & m != good_state.get(x, 0) & m:
return False, f"Field {REGS[x]}, mask: {x:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}"
return False, f"Field {REGS[x]}, mask: 0x{m:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}"
for n in ['ibos', 'samplers', 'descriptors']:
for n in ['descriptors', 'ibos']:
if n not in good_state: continue
mv1, mv2 = state.get(n), good_state.get(n)
if len(mv1) != len(mv2): return False, f"{n}: len mismatch"
if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}"
mv1 = memoryview(bytearray(mv1)).cast('I')
mv2 = memoryview(bytearray(mv2)).cast('I')
for i in range(len(mv2)):
if i % 8 == 5 or i % 8 == 4: continue # addresses
if mv1[i]!=mv2[i]: return False, f"{n}: content mismatch {i} {mv1[i]} {mv2[i]}"
for n in ['samplers']:
if n not in good_state: continue
mv1, mv2 = state.get(n), good_state.get(n)
if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}"
if any(mv1[i]!=mv2[i] for i in range(len(mv1))): return False, f"{n}: content mismatch"
return True, "PASS"

View File

@@ -22,7 +22,7 @@ setup(name='tinygrad',
"License :: OSI Approved :: MIT License"
],
install_requires=[],
python_requires='>=3.8',
python_requires='>=3.10',
extras_require={
'llvm': ["llvmlite"],
'arm': ["unicorn"],

View File

@@ -1,8 +1,8 @@
FROM ubuntu:20.04
FROM ubuntu:22.04
# Install python3.8, and pip3
# Install python3.10, and pip3
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.8 \
python3.10 \
python3-pip \
&& rm -rf /var/lib/apt/lists/*

File diff suppressed because one or more lines are too long

View File

@@ -30,7 +30,7 @@ from tinygrad.device import is_dtype_supported
def on_linearizer_will_run(): pass
def on_linearizer_did_run(): pass
def compare_states(x, y): return True
def compare_states(x, y): return (True, "")
if getenv("VALIDATE_HCQ"):
if Device.DEFAULT == "NV":

51
test/external/speed_v_theoretical.py vendored Normal file
View File

@@ -0,0 +1,51 @@
import unittest, time
from tinygrad import Tensor, TinyJit, Device
from tinygrad.helpers import Context, DEBUG
class TestKernelSpeed(unittest.TestCase):
  """Benchmarks JIT-compiled half-precision GEMMs and asserts a minimum TFLOPS floor per backend."""

  def _test_matmul(self, M, N=None, K=None, nv=None, amd=None):
    # Computes (MxK) @ (KxN); N and K default to M (square matmul).
    # nv/amd are the minimum TFLOPS targets for the NV and AMD backends.
    @TinyJit
    def run_mm(lhs, rhs): return (lhs @ rhs).realize()

    N = M if N is None else N
    K = M if K is None else K

    times = []
    # BEAM=3 so kernel search applies to the jitted matmul; the input
    # generation runs under BEAM=0, DEBUG=0 so it is neither searched nor logged.
    with Context(BEAM=3):
      for _ in range(10):
        with Context(BEAM=0, DEBUG=0):
          lhs = Tensor.rand(M, K, dtype="half").realize()
          rhs = Tensor.rand(K, N, dtype="half").realize()
        # synchronize before/after so the wall-clock window covers only the matmul
        Device.default.synchronize()
        t0 = time.perf_counter()
        _out = run_mm(lhs, rhs)
        Device.default.synchronize()
        times.append(time.perf_counter() - t0)

    # 2*M*N*K FLOPs for a matmul; best-of-10 timing to reduce launch noise
    ops = 2 * M * N * K
    tm = min(times)
    tflops = ops / tm / 1e12
    if DEBUG >= 1:
      print(f"{tm=}")
      print(f"{tflops=}")

    # only assert on backends that have a target defined for this benchmark
    for dev, target in (("NV", nv), ("AMD", amd)):
      if Device.DEFAULT == dev:
        if DEBUG >= 1: print(f"target: {target}")
        self.assertGreater(tflops, target)

  # TODO: smaller ones has other overhead in synchronize
  # TODO: AMD number can be better (perf level?)
  def test_gemm_1024(self): self._test_matmul(1024, nv=8, amd=7)
  def test_gemm_2048(self): self._test_matmul(2048, nv=50, amd=30)
  def test_gemm_4096(self): self._test_matmul(4096, nv=95, amd=65)
  def test_gemm_8192(self): self._test_matmul(8192, nv=130, amd=70)

  # TODO: add gemv, which is memory bounded

if __name__ == '__main__':
  unittest.main()

View File

@@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase):
assert len(sched) == 1
lin = Kernel(sched[0].ast)
assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg
a = Tensor.empty((4,4))
b = Tensor.empty((4,4))

View File

@@ -94,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is Ops.IF]
self.assertIn(len(if_uops), {1,3})
self.assertIn(len(if_uops), {1,2,3})
conditions = if_uops[0].src[0].sparents
self.assertLessEqual(len(conditions), 9)

View File

@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
assert k is not None
ifs = [u for u in k.uops if u.op is Ops.IF]
self.assertEqual(len(ifs), 4)
self.assertEqual(len(ifs), 3)
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)

View File

@@ -233,6 +233,20 @@ class TestOps(unittest.TestCase):
def test_arange_4096(self):
helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True)
def test_linspace(self):
helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True)
helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True)
helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype="int32"), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, 10, 20, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 20, dtype="int32"), forward_only=True)
helper_test_op([], lambda: torch.linspace(5, -5, 20, dtype=torch.int32), lambda: Tensor.linspace(5, -5, 20, dtype="int32"), forward_only=True)
self.helper_test_exception([], lambda: torch.linspace(5, 10, 3, dtype=torch.bool), lambda: Tensor.linspace(5, 10, 3, dtype="bool"),
expected=(RuntimeError, ValueError))
self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError))
def test_sum_fake(self):
helper_test_op([(256, 1)], lambda x: x.sum(axis=1))
@@ -1483,6 +1497,11 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),
lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5)
def test_simple_conv2d_bias(self):
helper_test_op([(1,4,9,9), (4,4,3,3), (4,)],
lambda x,w,b: torch.nn.functional.conv2d(x,w,b).relu(),
lambda x,w,b: Tensor.conv2d(x,w,b).relu(), grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv3d on images")
def test_simple_conv3d(self):
helper_test_op([(1,4,9,9,9), (4,4,3,3,3)],

View File

@@ -6,6 +6,7 @@ import torch
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
from tinygrad.helpers import getenv, CI
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from hypothesis import given, settings, strategies as strat
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase):
counts = Tensor.arange(20, dtype=dtypes.uint32)
counts0, counts1 = counts.chunk(2)
r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy()
r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy()
np.testing.assert_allclose(jr, r)
def test_threefry_doesnt_use_long(self):
for ei in lower_schedule(Tensor.rand(20).schedule()):
if isinstance(ei.prg, CompiledRunner):
for u in ei.prg.p.uops:
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)
@@ -227,11 +234,15 @@ class TestRandomness(unittest.TestCase):
def test_randint(self):
self.assertFalse(normal_test(Tensor.randint))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5),
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"),
numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG")
# check types of args
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3)
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5)
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float")
with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32)
def test_normal(self):

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, timeit
@@ -1626,13 +1626,14 @@ class TestIndexing(unittest.TestCase):
ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy()
np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6)
def test_recursive_st_fixup(self):
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {})
swizzle = ast.src[0].src[2].reshape((4, 1))
new_uop = swizzle_rewrite(swizzle)
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertLess(et, 1e3)
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_strongly_connected_DAG(self):
val = 1.0
@@ -1864,5 +1865,47 @@ class TestSwizzle(unittest.TestCase):
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
@unittest.expectedFailure
def test_swizzle_failure_permute(self):
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(20, ('METAL', 65, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.ADD, dtypes.float, arg=None, src=(
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 2925, dtypes.float)), src=()),
x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=(
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(Ops.CONST, dtypes.float, arg=1.0, src=()),
x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)),
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()),
x15,)),)),
x6,)),)),
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=(
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.WHERE, dtypes.float, arg=None, src=(
x12,
UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()),
x15,)),
UOp(Ops.MUL, dtypes.float, arg=None, src=(
UOp(Ops.PRELOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 2925, dtypes.float)), src=()),
x10,)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(4, ('METAL', 2925, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),))
ret = swizzle_rewrite(sink)
self.assertEqual(swizzle_cnt(ret), 0)
if __name__ == '__main__':
unittest.main(verbosity=2)

View File

@@ -38,6 +38,7 @@ def helper_test_speed(f1, *args):
# operation cache defeats
args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args]
args = [(x-1).realize() if isinstance(x, Tensor) else (None if x is None else (x-1)) for x in args]
# force syncing
[x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None]
@@ -91,8 +92,9 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args):
desc = "faster" if et_torch > et_tinygrad else "slower"
flops = save_ops*1e-6
mem = save_mem*1e-6
print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3)
print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:9.2f} GFLOPS {mem/et_torch:7.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:9.2f} GFLOPS {mem/et_tinygrad:7.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501
atol, rtol = (1e-2, 1e-2) if torch_dt == torch.float16 else (1e-3, 1e-3)
np.testing.assert_allclose(val_tinygrad, val_torch, atol=atol, rtol=rtol)
def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x):
torch.manual_seed(0)

View File

@@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase):
self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N))
if IMAGE < 2: self.assertEqual(out.dtype, out_dtype)
# *** randomness ***
def test_random(self):
out = Tensor.rand(10)
for x in out.tolist():
self.assertGreaterEqual(x, 0.0)
self.assertLessEqual(x, 1.0)
# *** JIT (for Python speed) ***
def test_jit(self):

View File

@@ -0,0 +1,13 @@
from tinygrad.helpers import getenv
import unittest, importlib
@unittest.skipUnless(getenv("MOCKGPU"), 'Testing mockgpu')
class TestMockGPU(unittest.TestCase):
  # regression test for https://github.com/tinygrad/tinygrad/pull/7627
  def test_import_typing_extensions(self):
    """typing_extensions must still be reloadable after mockgpu has been imported."""
    import extra.mockgpu.mockgpu # noqa: F401 # pylint: disable=unused-import
    import typing_extensions
    importlib.reload(typing_extensions) # pytest imports typing_extension before mockgpu
if __name__ == '__main__':
  unittest.main()

View File

@@ -274,8 +274,17 @@ sym = symbolic_flat+PatternMatcher([
# tensor core cleanups
(UPat.var("add") + UPat(Ops.WMMA, name="wmma"),
lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)),
# threefry
# threefry + remove longs
(UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32),
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize)
((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation
(((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y),
(((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x),
# hacks for threefry long removal when padded (TODO: genericize)
(UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)),
lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)),
((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32),
lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))),
# arange loop folding
(acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse),
# indexing, with cast or where

View File

@@ -1,6 +1,7 @@
from collections import defaultdict, deque
from typing import Tuple, List, Dict, DefaultDict
from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps
from typing import Set, Tuple, List, Dict, DefaultDict
from tinygrad.device import Buffer
from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.lazy import LazyBuffer
@@ -37,13 +38,15 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={})
return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants])
def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]:
def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None],
double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]:
"""search the graph for all the LazyBuffers that need to realize"""
# get all the realizes from big graph
realizes: Dict[LazyBuffer, None] = {}
assigns: Set[UOp] = set()
for r in allbufs:
if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None
if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None
if r.op is Ops.ASSIGN: assigns.add(ubuf)
# find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
reduce_of_const: List[LazyBuffer] = []
@@ -62,7 +65,7 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La
parents = deque((r, *group))
while parents and not forced_realize:
if (p:=parents.pop().base).is_realized() or p in realizes:
if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False
if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False
continue
parents.extend(p.srcs)
if forced_realize or not group:
@@ -92,19 +95,19 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La
top_reduce = reduceop.base.srcs[0].base
if len(children[top_reduce]) == 1:
del realizes[top_reduce]
if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
for r in reduce_of_const:
group = {tr:None for tr,rop in reduce_for_op.items() if rop is r}
if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue
if any(tr.forced_realize for tr in group): continue
kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}}
if len(kernel_children) == 0: continue
for tr in group:
del realizes[tr]
if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf]
output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list)
for buf in realizes:
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer])
output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer])
ubuf_realizes[ubuf] = ubuf
return list(output_groups.values())

View File

@@ -1,9 +1,9 @@
import sys, atexit, functools, itertools
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast
from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint
from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View, strides_for_shape
@@ -91,12 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce
# ** helpers for doing movementops on uops
def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp:
if (n:=cache.get(u)) is not None: return n
if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg))
if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u
cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src))
return ret
def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left)
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
@@ -105,9 +101,8 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr
# ** movementops rewrite rules
def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
if (st:=unwrap(view.st)).contiguous: return None
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg)
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg)
prshape = prod(rshape)
strides = strides_for_shape(rshape)
nv: List[View] = []
@@ -118,10 +113,10 @@ def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]:
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, r.axis_arg)
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp:
swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st)
assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
output_shape = swizzle_st.reduce(root.axis_arg)
@@ -134,9 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
new_shape, new_input_shape = swizzle_shapes[0]
fixup_cache: Dict[UOp, UOp] = {}
new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src]
ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg)
ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src))
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
@@ -149,8 +142,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),),
# push VIEW to loads
view_left = merge_views+PatternMatcher([
# VIEW before elementwise ops
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))),
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))),
# early merge VIEW buffer ops
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])
@@ -160,10 +152,10 @@ view_right = merge_views+PatternMatcher([
# ASSIGN can override st
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))),
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
# VIEW on a reduce creates a new VIEW
(UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# non contiguous VIEW on a reduce creates a new VIEW
(UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
# push a VIEW down to STORE, through a reduce (ONLY reshapes)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
(UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce),
# push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes)
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise),
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
@@ -283,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
# get realizes
realizes: Dict[UOp, UOp] = {}
graph_rewrite(big_graph, do_realize, realizes)
store_groups = get_realizes(outs, children, allbufs, double_reduces, realizes, ctx)
store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops)
# split realizes into small graphs
graph_rewrite(big_graph, break_sched, realizes)
sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups]

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard
T = TypeVar("T")
U = TypeVar("U")
@@ -320,10 +318,7 @@ class trange(tqdm):
def _reconstruct_code(*args): return types.CodeType(*args)
def _serialize_code(code:types.CodeType):
# NOTE: this works in Python 3.8 and up
if sys.version_info >= (3, 10): args = inspect.signature(types.CodeType).parameters
else: args = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring',
'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars']
args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up
return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args)
copyreg.pickle(types.CodeType, _serialize_code)

View File

@@ -6,7 +6,7 @@ from dataclasses import dataclass, field
from collections import defaultdict
from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T
if TYPE_CHECKING:
from tinygrad.shape.shapetracker import ShapeTracker
@@ -140,7 +140,7 @@ class Ops(FastEnum):
# BinaryOps
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
@@ -168,7 +168,8 @@ class Ops(FastEnum):
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
Ops.SUB, Ops.FDIV}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
@@ -310,7 +311,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return ret
def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs)
def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx))
def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st)
def const_like(self, b:ConstLike): return UOp.const(self.dtype, b)
def broadcast(self, count:int):
assert self.dtype.count == 1
@@ -346,6 +346,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis))
def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x))
# *** uop movement ops ***
def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st)
def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg))
# *** uop Variable stuff ***
@staticmethod
@@ -379,7 +384,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is Ops.CONST: return self.arg
if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg)
if self.op is Ops.VCONST: return math.gcd(*self.arg)
if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1
return 1
@@ -630,7 +635,6 @@ class PatternMatcher:
for p,fxn in self.patterns:
assert p.op is not None
tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn)
tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends
real_fxn = types.FunctionType(*tuple_fxn)
for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters))
@@ -901,7 +905,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]:
def lt_folding(x:UOp, c:int) -> Optional[UOp]:
p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1)
if np and (d:=functools.reduce(math.gcd, [u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d:
return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d)
return None
@@ -1050,6 +1054,8 @@ symbolic = symbolic_simple+PatternMatcher([
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
# group like
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
# ** boolean algebra **
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
# ** combine terms **
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)

View File

@@ -1,50 +1,77 @@
from typing import Dict, Callable, List, Optional
from llvmlite import ir
from tinygrad.dtype import DType, PtrDType, dtypes
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp
from typing import List, Dict, cast
import math, struct
from tinygrad.renderer import Renderer
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp
from tinygrad.dtype import dtypes, DType, PtrDType, truncate
MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc
def ldt(dt:DType) -> str:
  """Return the LLVM IR type string for a tinygrad DType; pointer dtypes get a '*' suffix."""
  if isinstance(dt, PtrDType):
    return ldt(dt.base) + "*"
  scalar_names = {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64",
                  dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64",
                  dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}
  return scalar_names[dt]
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16),
dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64),
dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() }
def cast(bb, val, input_type, output_type, bitcast=False):
if input_type == output_type: return val
llvm_type = dtype_to_llvm_dtype[output_type]
if bitcast: return bb[-1].bitcast(val, llvm_type)
if input_type == dtypes.bfloat16:
val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType())
input_type = dtypes.float32
if output_type == dtypes.bfloat16:
val = cast(bb, val, input_type, dtypes.float32)
return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16))
def lconst(x, dtype:DType):
  """Render the python constant x as an LLVM IR literal for dtype.

  Non-finite floats (inf/nan) are emitted as a raw 64-bit hex pattern in big-endian
  byte order; finite floats go through the dtype's truncate table; everything else
  is emitted as a plain int.
  """
  if dtype not in dtypes.floats:
    return int(x)
  if math.isinf(x) or math.isnan(x):
    # struct.pack gives little-endian bytes; reverse them so the hex literal is big-endian
    return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(reversed(struct.pack("d", x)))
  return truncate[dtype](x)
def lcast(input_type:DType, output_type:DType):
if dtypes.is_float(input_type):
if dtypes.is_float(output_type):
return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type)
if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type)
if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0))
if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc'
if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi'
if dtypes.is_unsigned(input_type) or input_type == dtypes.bool:
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType())
if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type])
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type)
if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0))
if dtypes.is_float(output_type): return 'uitofp'
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext'
if dtypes.is_int(input_type):
if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType())
if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type)
if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type)
if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0))
if dtypes.is_float(output_type): return 'sitofp'
if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext'
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args)
# llvm ops, lop[<dtype>][<op>]: maps a dtype to a table of Ops -> LLVM instruction mnemonic
unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem",
Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", }
# signed ints reuse the unsigned table, overriding compare/divide/remainder with signed forms
signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"}
# fast-math flags for float ops (note: nnan/ninf/reassoc are deliberately not included)
flags = " nsz arcp contract afn"
float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags}
lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}}
# UOp -> LLVM IR text emission rules. ctx maps a UOp to its SSA value name ("%vN");
# ctx[x][1:] strips the leading '%' to build basic-block label names from it.
llvm_rewrite = PatternMatcher([
# memory load/store
(UPat(Ops.INDEX, name="x"), lambda ctx,x:
f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"),
# gated load: branch around the load when mask is false, merge alt via phi
(UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask:
f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n"
f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n"
f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n"
f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"),
(UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"),
(UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"),
# unary/binary/ternary ops
(UPat(Ops.SQRT, name="x"), lambda ctx,x:
f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"),
(UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"),
(UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"),
(UPat(Ops.WHERE, name="x"), lambda ctx,x:
f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"),
# range: loop counter is a phi between the initial value and the latch increment
(UPat(Ops.RANGE, name="x"), lambda ctx,x:
f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n"
f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n"
f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"),
(UPat(Ops.ENDRANGE, name="x"), lambda ctx,x:
f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n"
f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n"
f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"),
# if
(UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"),
(UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"),
])
class LLVMRenderer(Renderer):
device = "LLVM"
@@ -52,101 +79,64 @@ class LLVMRenderer(Renderer):
has_local = False
has_shared = False
global_max = None
code_for_op: Dict[Ops, Callable] = {
UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS),
UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS),
BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501
BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501
BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y),
BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501
BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501
BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501
BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501
BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501
BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501
TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)}
def render(self, name:str, uops:List[UOp]) -> str:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
extra_matcher = PatternMatcher([
# rewrite RECIP with FDIV
(UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))),
# rewrite cast to bool to CMPNE 0
(UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)),
# *** also in cstyle ***
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"),
lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))),
# rewrite MAX to CMPLT + WHERE
(UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])),
])
# extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order)
buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}}
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
def render(self, name: str, uops: List[UOp]) -> str:
r: Dict[UOp, str] = {}
args: List[str] = []
kernel: List[str] = []
end_lines: Dict[str, None] = {}
vc = -1
# create llvm function
func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name)
for a in func.args:
if a.type.is_pointer: a.add_attribute("noalias")
bb = [ir.IRBuilder(func.append_basic_block("entry"))]
loop_blocks: List = []
reduce_phis: List = []
lvars: Dict[Optional[UOp], ir.Instruction] = {}
for bufname,dtype in buf_to_dtype.items():
if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32))
# prealloc all assigns
acc_to_assign: Dict[UOp, UOp] = {}
for u in uops:
if u.op is Ops.ASSIGN:
vc += 1
r[u] = r[u.src[1]] = f"%assign{vc}"
assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice"
acc_to_assign[u.src[0]] = u.src[1]
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is Ops.INDEX:
lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True)
elif uop is Ops.STORE:
if len(src) > 2:
with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]])
else:
bb[-1].store(lvars[src[1]], lvars[src[0]])
elif uop is Ops.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
lvars[src[0]].add_incoming(idx_p1, bb[-1].block)
for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block)
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}")))
bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block)
# hack for defining sqrt function (TODO: can we get a transcendental for this?)
if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None
if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR):
r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}"
args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}")
elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass
elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to
elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype)
elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop
else:
if uop is Ops.RANGE:
bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}")))
bb[-2].branch(bb[-1].block)
# if it's an assign target, it's already preallocated
if u not in r:
vc += 1
r[u] = f"%v{vc}"
phis = []
for rp in reduce_phis:
incoming = lvars[rp]
lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype])
lvars[rp].add_incoming(incoming, bb[-2].block)
phis.append((rp, lvars[rp]))
# do the rendering of the llvm ir code
if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}")
kernel.append(cast(str, l))
lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}")
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is Ops.DEFINE_ACC:
lvars[u] = const(src[0].arg, dtype)
reduce_phis.append(u)
elif uop is Ops.LOAD:
if len(src) > 1:
with bb[-1].if_else(lvars[src[2]]) as (then, otherwise):
with then:
val1 = bb[-1].load(lvars[src[0]])
then_blk = bb[-1].block
with otherwise: otherwise_blk = bb[-1].block
val = bb[-1].phi(val1.type)
val.add_incoming(val1, then_blk)
val.add_incoming(lvars[src[1]], otherwise_blk)
else:
val = bb[-1].load(lvars[src[0]])
lvars[u] = val
elif uop is Ops.ASSIGN:
lvars[u] = lvars[src[1]]
# ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC
backward = src[0]
while backward.op is Ops.ASSIGN: backward = backward.src[0]
lvars[backward] = lvars[u]
elif uop in GroupOp.ALU:
lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype)
elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST)
elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]]
elif uop is Ops.CONST: lvars[u] = const(args, dtype)
else: raise RuntimeError(f"failed to render {uop}")
# generate the phi nodes for the assigns
if u.op is Ops.RANGE:
for x in acc_to_assign:
if u in x.src: # if this range is relevent for this acc
vc += 1
kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]")
r[x] = f"%acc{vc}"
bb[-1].ret_void()
return str(module)
# output the function
return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys())

View File

@@ -5,14 +5,70 @@
# it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC
from __future__ import annotations
from typing import Tuple, Optional, Dict, Any, DefaultDict
from typing import Tuple, Optional, Dict, Any, DefaultDict, List
from collections import defaultdict
import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib
from http.server import HTTPServer, BaseHTTPRequestHandler
from tinygrad.renderer import Renderer
from tinygrad.dtype import dtypes
from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing
from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions
# ***** API *****
# Each request is a frozen dataclass so its repr() round-trips losslessly through safe_eval below.
class CloudRequest: pass
@dataclass(frozen=True)
class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702
@dataclass(frozen=True)
class BufferFree(CloudRequest): buffer_num: int # noqa: E702
@dataclass(frozen=True)
class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702
@dataclass(frozen=True)
class CopyOut(CloudRequest): buffer_num: int
@dataclass(frozen=True)
class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702
@dataclass(frozen=True)
class ProgramExec(CloudRequest):
  name: str; datahash: str; bufs: Tuple[int, ...]; vals: Tuple[int, ...] # noqa: E702
  global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702

# for safe deserialization
# Only constructors named here may appear in a deserialized request list; any other name raises KeyError.
whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]}
# Minimal AST interpreter used instead of eval(): supports constants, tuples/lists, calls to
# whitelisted constructors, and the two image dtype attributes — nothing else, by design.
eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)),
  ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}),
  ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]}
# Recursively evaluate a parsed expression node; unsupported node types raise KeyError.
def safe_eval(node): return eval_fxns[node.__class__](node)
class BatchRequest:
  """Accumulates RPC commands plus their content-addressed data blobs for one round trip."""
  def __init__(self):
    self._q: List[CloudRequest] = []  # ordered command queue
    self._h: Dict[str, bytes] = {}    # datahash -> framed blob (raw hash + length + payload)
  def h(self, d:bytes) -> str:
    # register a blob under its sha256; the frame is 0x20 bytes of raw hash,
    # 8 bytes of little-endian length, then the payload itself
    raw = hashlib.sha256(d).digest()
    key = binascii.hexlify(raw).decode()
    self._h[key] = raw + struct.pack("<Q", len(d)) + d
    return key
  def q(self, x:CloudRequest):
    self._q.append(x)
  def serialize(self) -> bytes:
    # the command list itself is stored as the final blob, then all frames are concatenated
    self.h(repr(self._q).encode())
    return b''.join(self._h.values())
  def deserialize(self, dat:bytes) -> BatchRequest:
    ptr = 0
    while ptr < len(dat):
      key = binascii.hexlify(dat[ptr:ptr+0x20]).decode()
      datalen = struct.unpack("<Q", dat[ptr+0x20:ptr+0x28])[0]
      self._h[key] = dat[ptr+0x28:ptr+0x28+datalen]
      ptr += 0x28 + datalen
    # the last frame parsed is the serialized command list (assumes dat is non-empty)
    self._q = safe_eval(ast.parse(self._h[key], mode="eval").body)
    return self
# ***** backend *****
@@ -21,7 +77,6 @@ class CloudSession:
programs: Dict[Tuple[str, str], Any] = field(default_factory=dict)
# TODO: the buffer should track this internally
buffers: Dict[int, Tuple[Any, int, Optional[BufferOptions]]] = field(default_factory=dict)
buffer_num = 0
class CloudHandler(BaseHTTPRequestHandler):
protocol_version = 'HTTP/1.1'
@@ -32,66 +87,47 @@ class CloudHandler(BaseHTTPRequestHandler):
super().setup()
print(f"connection established with {self.client_address}, socket: {self.connection.fileno()}")
def get_data(self):
content_len = self.headers.get('Content-Length')
assert content_len is not None
return self.rfile.read(int(content_len))
def get_json(self): return json.loads(self.get_data())
def _fail(self):
self.send_response(404)
self.end_headers()
return 0
def _do(self, method):
session = CloudHandler.sessions[unwrap(self.headers.get("Cookie")).split("session=")[1]]
ret = b""
if self.path == "/renderer" and method == "GET":
ret, status_code = b"", 200
if self.path == "/batch" and method == "POST":
# TODO: streaming deserialize?
req = BatchRequest().deserialize(self.rfile.read(int(unwrap(self.headers.get('Content-Length')))))
# the cmds are always last (currently in datahash)
for c in req._q:
if DEBUG >= 1: print(c)
match c:
case BufferAlloc():
assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated"
session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options)
case BufferFree():
buf,sz,buffer_options = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[c.buffer_num]
case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash])))
case CopyOut():
buf,sz,_ = session.buffers[c.buffer_num]
Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
case ProgramAlloc():
lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode())
session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib)
case ProgramFree(): del session.programs[(c.name, c.datahash)]
case ProgramExec():
bufs = [session.buffers[x][0] for x in c.bufs]
extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None}
r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args)
if r is not None: ret = str(r).encode()
elif self.path == "/renderer" and method == "GET":
cls, args = Device[CloudHandler.dname].renderer.__reduce__()
ret = json.dumps((cls.__module__, cls.__name__, args)).encode()
elif self.path.startswith("/alloc") and method == "POST":
size = int(self.path.split("=")[-1])
buffer_options: Optional[BufferOptions] = None
if 'image' in self.path:
image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")])
buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape))
session.buffer_num += 1
session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options)
ret = str(session.buffer_num).encode()
elif self.path.startswith("/buffer"):
key = int(self.path.split("/")[-1])
buf,sz,buffer_options = session.buffers[key]
if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf)
elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data())))
elif method == "DELETE":
Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options)
del session.buffers[key]
else: return self._fail()
elif self.path.startswith("/program"):
name, hsh = self.path.split("/")[-2:]
if method == "PUT":
src = self.get_data()
assert hashlib.sha256(src).hexdigest() == hsh
lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode())
session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib)
elif method == "POST":
j = self.get_json()
bufs = [session.buffers[x][0] for x in j['bufs']]
del j['bufs']
r = session.programs[(name, hsh)](*bufs, **j)
if r is not None: ret = str(r).encode()
elif method == "DELETE": del session.programs[(name, hsh)]
else: return self._fail()
else: return self._fail()
self.send_response(200)
else: status_code = 404
self.send_response(status_code)
self.send_header('Content-Length', str(len(ret)))
self.end_headers()
return self.wfile.write(ret)
def do_GET(self): return self._do("GET")
def do_POST(self): return self._do("POST")
def do_PUT(self): return self._do("PUT")
def do_DELETE(self): return self._do("DELETE")
def cloud_server(port:int):
multiprocessing.current_process().name = "MainProcess"
@@ -106,44 +142,46 @@ class CloudAllocator(Allocator):
def __init__(self, device:CloudDevice):
self.device = device
super().__init__()
def _alloc(self, size:int, options) -> int:
# TODO: ideally we shouldn't have to deal with images here
extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else ""
return int(self.device.send("POST", f"alloc?{extra}size={size}"))
def _free(self, opaque, options):
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected):
self.device.send("DELETE", f"buffer/{opaque}", data=b"")
def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src))
# TODO: ideally we shouldn't have to deal with images here
def _alloc(self, size:int, options:BufferOptions) -> int:
self.device.buffer_num += 1
self.device.req.q(BufferAlloc(self.device.buffer_num, size, options))
return self.device.buffer_num
# TODO: options should not be here in any Allocator
def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque))
def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src))))
def copyout(self, dest:memoryview, src:int):
resp = self.device.send("GET", f"buffer/{src}")
self.device.req.q(CopyOut(src))
resp = self.device.batch_submit()
assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}"
dest[:] = resp
class CloudProgram:
def __init__(self, device:CloudDevice, name:str, lib:bytes):
self.device = device
self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}"
self.device.send("PUT", "program/"+self.prgid, lib)
self.device, self.name = device, name
self.datahash = self.device.req.h(lib)
self.device.req.q(ProgramAlloc(self.name, self.datahash))
super().__init__()
def __del__(self): self.device.send("DELETE", "program/"+self.prgid)
def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash))
def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False):
args = {"bufs": bufs, "vals": vals, "wait": wait}
if global_size is not None: args["global_size"] = global_size
if local_size is not None: args["local_size"] = local_size
ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode())
if wait: return float(ret)
self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait))
if wait: return float(self.device.batch_submit())
class CloudDevice(Compiled):
def __init__(self, device:str):
if (host:=getenv("HOST", "")) != "":
self.host = host
if (host:=getenv("HOST", "")) != "": self.host = host
else:
p = multiprocessing.Process(target=cloud_server, args=(6667,))
p.daemon = True
p.start()
self.host = "127.0.0.1:6667"
self.cookie = binascii.hexlify(os.urandom(0x10)).decode()
# state for the connection
self.session = binascii.hexlify(os.urandom(0x10)).decode()
self.buffer_num = 0
self.req: BatchRequest = BatchRequest()
if DEBUG >= 1: print(f"cloud with host {self.host}")
while 1:
try:
@@ -155,13 +193,26 @@ class CloudDevice(Compiled):
time.sleep(0.1)
if DEBUG >= 1: print(f"remote has device {clouddev}")
# TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer
assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}"
renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2])
super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self))
if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}")
renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure?
if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}")
super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self))
def __del__(self):
# TODO: this is never being called
# TODO: should close the whole session
with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit()
def batch_submit(self):
data = self.req.serialize()
with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1):
ret = self.send("POST", "batch", data)
self.req = BatchRequest()
return ret
def send(self, method, path, data:Optional[bytes]=None) -> bytes:
# TODO: retry logic
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"})
self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"})
response = self.conn.getresponse()
assert response.status == 200, f"failed on {method} {path}"
return response.read()

View File

@@ -2,18 +2,24 @@ from __future__ import annotations
import ctypes, functools
from typing import Tuple
from tinygrad.device import Compiled, Compiler, MallocAllocator
from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump
from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump, getenv
from tinygrad.renderer.llvmir import LLVMRenderer
import llvmlite.binding as llvm
class LLVMCompiler(Compiler):
def __init__(self, device:LLVMDevice):
def __init__(self, device:LLVMDevice, opt:bool=False):
self.device = device
super().__init__("compile_llvm")
self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
self.device.target_machine.add_analysis_passes(self.optimizer)
if opt:
with llvm.create_pass_manager_builder() as builder:
builder.opt_level = 3; builder.size_level = 0; builder.loop_vectorize = True; builder.slp_vectorize = True # noqa: E702
builder.populate(self.optimizer)
super().__init__("compile_llvm_opt" if opt else "compile_llvm")
def compile(self, src:str) -> bytes:
mod = llvm.parse_assembly(src)
mod.verify()
self.device.optimizer.run(mod)
self.optimizer.run(mod)
if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod))
return self.device.target_machine.emit_object(mod)
@@ -23,6 +29,7 @@ class LLVMProgram:
self.name, self.lib = name, lib
device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
self.fxn = device.engine.get_function_address(name)
assert self.fxn != 0, "LLVM failed to get function address"
def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False):
if not hasattr(self, 'cfunc'):
@@ -35,12 +42,9 @@ class LLVMDevice(Compiled):
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()
llvm.initialize_native_asmparser()
self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager()
# this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA
self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2)
self.target_machine.add_analysis_passes(self.optimizer)
self.target_machine.set_asm_verbosity(True)
backing_mod = llvm.parse_assembly(str())
backing_mod.triple = llvm.get_process_triple()
self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine)
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self))
super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self, getenv("LLVMOPT")), functools.partial(LLVMProgram, self))

View File

@@ -155,7 +155,7 @@ class QCOMComputeQueue(HWComputeQueue):
if args_state.prg.tex_cnt > 0:
self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT,
state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt),
state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)),
*data64_le(args_state.ptr + args_state.prg.tex_off))
self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off))
@@ -247,14 +247,15 @@ class QCOMProgram(HCQProgram):
self.buf_info, self.consts_info = [], []
# Collect sampler info.
self.samp_cnt = _read_lib(image_desc_off + 0xdc)
self.samp_cnt = samp_cnt_in_file = _read_lib(image_desc_off + 0xdc)
assert self.samp_cnt <= 1, "Up to one sampler supported"
if self.samp_cnt:
self.samp_cnt += 1
self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode),
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0]
qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0, 0, 0, 0, 0]
# Collect kernel arguments (buffers) info.
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt
bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samp_cnt_in_file
while bdoff + 32 <= len(self.lib):
length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32])
if length == 0: break
@@ -263,7 +264,7 @@ class QCOMProgram(HCQProgram):
# Setting correct offsets to textures/ibos.
self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info)
self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt
self.ibo_off, self.tex_off, self.samp_off = 2048, 2048 + 0x40 * self.ibo_cnt, 2048 + 0x40 * self.tex_cnt + 0x40 * self.ibo_cnt
cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off
for x in self.buf_info:
if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40
@@ -307,10 +308,10 @@ class QCOMAllocator(HCQAllocator):
texture.pitch, texture.real_stride = pitch, real_stride
tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)
texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh)
texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6)
texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)]
texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)]
texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]]
return texture

View File

@@ -464,9 +464,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {}
@staticmethod
def _threefry_random_bits(key, counts0, counts1):
def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor):
x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64)
x = F.Threefry.apply(*x._broadcasted(key))
x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64))
counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
return counts0.cat(counts1)
@@ -494,9 +494,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
# generate per device seeds and rng counter if we haven't seen this device yet
if device not in Tensor._device_seeds:
Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \
| int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff],
device=device, dtype=dtypes.uint64, requires_grad=False)
Tensor._device_seeds[device] = Tensor(
[int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed],
device=device, dtype=dtypes.uint32, requires_grad=False)
Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
had_counter = False
else: had_counter = True
@@ -612,6 +612,26 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs)
return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype)
@staticmethod
  def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor:
    """
    Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive.

    You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor.
    Additionally, all other keyword arguments are passed to the constructor of the tensor.

    Raises ValueError if `steps` is negative or if `dtype` is bool.

    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.linspace(0, 10, 5).numpy())
    ```
    ```python exec="true" source="above" session="tensor" result="python"
    print(Tensor.linspace(-1, 1, 5).numpy())
    ```
    """
    if steps < 0: raise ValueError("number of steps must be non-negative")
    if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported")
    # single step: just the start point (also avoids the (steps - 1) division below)
    if steps == 1: return Tensor([start], dtype=dtype, **kwargs)
    # dtype was popped from kwargs above, so the arithmetic runs in the constructor's
    # default dtype; the result is cast to the requested dtype only at the end
    return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype)
@staticmethod
def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor:
"""
@@ -734,7 +754,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
```
"""
if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers")
dtype = kwargs.pop("dtype", dtypes.int32)
dtype = to_dtype(kwargs.pop("dtype", dtypes.int32))
if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int")
return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)