diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index fd1e6dd720..5b20d38067 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -47,55 +47,55 @@ jobs: echo "CACHEDB=/tmp/staging.db" >> $GITHUB_ENV rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal - name: reset process replay - run: test/external/process_replay/reset.py + run: python3.11 test/external/process_replay/reset.py - name: Run Stable Diffusion - run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt + run: JIT=2 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt - name: Run Stable Diffusion with fp16 - run: JIT=2 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt + run: JIT=2 python3.11 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd_fp16.txt - name: Run SDXL - run: JIT=2 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt + run: JIT=2 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt - name: Run model inference benchmark - run: METAL=1 python3 test/external/external_model_benchmark.py + run: METAL=1 python3.11 test/external/external_model_benchmark.py - name: Test speed vs torch - run: BIG=2 MPS=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt + run: BIG=2 MPS=1 python3.11 test/test_speed_v_torch.py | tee torch_speed.txt - name: Test tensor cores - run: METAL=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded + run: METAL=1 python3.11 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded - name: Run Tensor Core GEMM run: | - DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul.txt - DEBUG=2 HALF=1 python3 extra/gemm/simple_matmul.py | tee matmul_half.txt + DEBUG=2 python3.11 extra/gemm/simple_matmul.py | tee matmul.txt + DEBUG=2 HALF=1 python3.11 extra/gemm/simple_matmul.py | tee matmul_half.txt - name: Fuzz Padded Tensor Core GEMM - run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py + run: METAL=1 M_START=6 M_STOP=10 M_STEP=1 N_START=6 N_STOP=10 N_STEP=1 K_START=6 K_STOP=24 K_STEP=1 TC_OPT=2 DEBUG=2 python3.11 ./extra/gemm/fuzz_matmul.py - name: Run LLaMA run: | - JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt - JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt + JIT=0 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt + JIT=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt - name: Run LLaMA with BEAM - run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt + run: JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt - name: Run quantized LLaMA run: | - python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt - python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt + python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt + python3.11 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize nf4 | tee llama_nf4.txt - name: Run LLaMA 7B on 4 (virtual) GPUs - run: python3 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt + run: python3.11 examples/llama.py --gen 1 --size 7B --shard 4 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_four_gpu.txt - name: Run GPT2 run: | - JIT=0 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt - JIT=1 python3 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt + JIT=0 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_unjitted.txt + JIT=1 python3.11 examples/gpt2.py --prompt "Hello." --count 10 --temperature 0 --timing | tee gpt2_jitted.txt - name: Run GPT2 w HALF - run: HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt + run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt - name: Run GPT2 w HALF/BEAM - run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt + run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt - name: Train MNIST - run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt + run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt - name: Run 10 CIFAR training steps - run: JIT=2 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar.txt + run: JIT=2 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar.txt - name: Run 10 CIFAR training steps w HALF - run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3 examples/hlb_cifar10.py | tee train_cifar_half.txt + run: JIT=2 STEPS=10 DEFAULT_FLOAT=HALF python3.11 examples/hlb_cifar10.py | tee train_cifar_half.txt #- name: Run 10 CIFAR training steps w BF16 - # run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt + # run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3.11 examples/hlb_cifar10.py | tee train_cifar_bf16.txt - name: Run 10 CIFAR training steps w winograd - run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt + run: JIT=2 WINO=1 STEPS=10 python3.11 examples/hlb_cifar10.py | tee train_cifar_wino.txt - uses: actions/upload-artifact@v4 with: name: Speed (Mac) @@ -123,7 +123,7 @@ jobs: train_cifar_bf16.txt train_cifar_wino.txt - name: Run process replay tests - run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py + run: cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3.11 process_replay.py testnvidiabenchmark: name: tinybox green Benchmark @@ -157,7 +157,7 @@ jobs: - name: Run model inference benchmark run: NV=1 RUN_PROCESS_REPLAY=0 NOCLANG=1 python3 test/external/external_model_benchmark.py - name: Test speed vs torch - run: NV=1 RUN_PROCESS_REPLAY=0 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt + run: NV=1 RUN_PROCESS_REPLAY=0 HALF=1 BIG=2 TORCHCUDA=1 python3 test/test_speed_v_torch.py | tee torch_speed.txt - name: Test tensor cores run: | NV=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded @@ -170,8 +170,10 @@ jobs: run: NV=1 PTX=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_ptx.txt - name: Run Tensor Core GEMM (NV) run: NV=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_nv.txt - - name: Run Tensor Core GEMM (NV) with BEAM - run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py + # - name: Run Tensor Core GEMM (NV) with BEAM + # run: BEAM=4 NV=1 HALF=1 IGNORE_BEAM_CACHE=1 DEBUG=2 python3 extra/gemm/simple_matmul.py + - name: Test speed vs theoretical + run: NV=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py - name: Run Stable Diffusion run: NV=1 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt - name: Run SDXL @@ -328,6 +330,10 @@ jobs: rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal - name: reset process replay run: test/external/process_replay/reset.py + - name: setup perflevel + run: | + examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh + rocm-smi - name: Show off tinybox run: /opt/rocm/bin/rocm-bandwidth-test # TODO: unstable on AMD @@ -343,6 +349,8 @@ jobs: AMD=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded - name: Run Tensor Core GEMM (AMD) run: AMD=1 HALF=1 DEBUG=2 python3 extra/gemm/simple_matmul.py | tee matmul_amd.txt + - name: Test speed vs theoretical + run: AMD=1 IGNORE_BEAM_CACHE=1 DEBUG=1 python -m pytest -rA test/external/speed_v_theoretical.py # TODO: AMD compiler bug causes this to fail #- name: Fuzz Padded Tensor Core GEMM # run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py @@ -431,6 +439,10 @@ jobs: rm -f /tmp/staging.db /tmp/staging.db-shm /tmp/staging.db-wal - name: reset process replay run: test/external/process_replay/reset.py + - name: setup perflevel + run: | + examples/mlperf/training_submission_v4.1/tinycorp/benchmarks/bert/implementations/tinybox_red/setup.sh + rocm-smi - name: Train MNIST run: time PYTHONPATH=. AMD=1 TARGET_EVAL_ACC_PCT=96.0 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt - name: Run 10 CIFAR training steps diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ae74053fe..97b0264ea1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -209,15 +209,15 @@ jobs: steps: - name: Checkout Code uses: actions/checkout@v4 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v5 with: - python-version: 3.8 + python-version: "3.10" - name: Cache python packages uses: actions/cache@v4 with: - path: ${{ env.Python3_ROOT_DIR }}/lib/python3.8/site-packages - key: linting-packages-${{ hashFiles('**/setup.py') }}-3.8 + path: ${{ env.Python3_ROOT_DIR }}/lib/python3.10/site-packages + key: linting-packages-${{ hashFiles('**/setup.py') }}-3.10 - name: Install dependencies run: pip install -e '.[linting,testing]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Lint bad-indentation and trailing-whitespace with pylint diff --git a/docs/env_vars.md b/docs/env_vars.md index efc13c072f..d0376a9e26 100644 --- a/docs/env_vars.md +++ b/docs/env_vars.md @@ -46,4 +46,5 @@ FLOAT16 | [1] | use float16 for images instead of float32 PTX | [1] | enable the specialized [PTX](https://docs.nvidia.com/cuda/parallel-thread-execution/) assembler for Nvidia GPUs. If not set, defaults to generic CUDA codegen backend. PROFILE | [1] | enable output of [perfetto](https://ui.perfetto.dev/) compatible profile. This feature is supported in NV and AMD backends. VISIBLE_DEVICES | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0). -JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled \ No newline at end of file +JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled +VIZ | [1] | 0=disabled, 1=[viz enabled](../tinygrad/viz/README) \ No newline at end of file diff --git a/docs/tensor/creation.md b/docs/tensor/creation.md index 897c29ad38..722d1a41e0 100644 --- a/docs/tensor/creation.md +++ b/docs/tensor/creation.md @@ -5,6 +5,7 @@ ::: tinygrad.Tensor.ones ::: tinygrad.Tensor.full ::: tinygrad.Tensor.arange +::: tinygrad.Tensor.linspace ::: tinygrad.Tensor.eye ::: tinygrad.Tensor.full_like ::: tinygrad.Tensor.zeros_like diff --git a/extra/mockgpu/mockgpu.py b/extra/mockgpu/mockgpu.py index 5374feb7f9..4a118dce91 100644 --- a/extra/mockgpu/mockgpu.py +++ b/extra/mockgpu/mockgpu.py @@ -178,7 +178,7 @@ class TrackedMemoryView: def __len__(self): return len(self.mv) def __repr__(self): return repr(self.mv) -def _memoryview(mem): +def _memoryview(cls, mem): if isinstance(mem, int) or isinstance(mem, ctypes.Array): addr = ctypes.addressof(mem) if isinstance(mem, ctypes.Array) else mem for d in drivers: @@ -196,7 +196,7 @@ install_hook(libc.lseek64, _lseek64) install_hook(libc.stat64, _stat64) install_hook(libc.fstat64, _fstat64) install_hook(libc.getdents64, _getdents64) -builtins.memoryview = _memoryview # type: ignore +builtins.memoryview = type("memoryview", (), {'__new__': _memoryview}) # type: ignore # rewrite autogen's libc mmaps functions. import tinygrad.runtime.autogen.libc as autogen_libc diff --git a/extra/models/unet.py b/extra/models/unet.py index 92d4496320..94d2e95410 100644 --- a/extra/models/unet.py +++ b/extra/models/unet.py @@ -1,6 +1,6 @@ from tinygrad import Tensor, dtypes from tinygrad.nn import Linear, Conv2d, GroupNorm, LayerNorm - +from tinygrad.device import is_dtype_supported from typing import Optional, Union, List, Any, Tuple import math @@ -9,7 +9,8 @@ def timestep_embedding(timesteps:Tensor, dim:int, max_period=10000): half = dim // 2 freqs = (-math.log(max_period) * Tensor.arange(half, device=timesteps.device) / half).exp() args = timesteps.unsqueeze(1) * freqs.unsqueeze(0) - return Tensor.cat(args.cos(), args.sin(), dim=-1).cast(dtypes.float16) + out = Tensor.cat(args.cos(), args.sin(), dim=-1) + return out.cast(dtypes.float16) if is_dtype_supported(dtypes.float16) else out class ResBlock: def __init__(self, channels:int, emb_channels:int, out_channels:int): @@ -222,16 +223,17 @@ class UNetModel: ] def __call__(self, x:Tensor, tms:Tensor, ctx:Tensor, y:Optional[Tensor]=None) -> Tensor: - t_emb = timestep_embedding(tms, self.model_ch).cast(dtypes.float16) + t_emb = timestep_embedding(tms, self.model_ch) emb = t_emb.sequential(self.time_embed) if y is not None: assert y.shape[0] == x.shape[0] emb = emb + y.sequential(self.label_emb[0]) - emb = emb.cast(dtypes.float16) - ctx = ctx.cast(dtypes.float16) - x = x .cast(dtypes.float16) + if is_dtype_supported(dtypes.float16): + emb = emb.cast(dtypes.float16) + ctx = ctx.cast(dtypes.float16) + x = x .cast(dtypes.float16) def run(x:Tensor, bb) -> Tensor: if isinstance(bb, ResBlock): x = bb(x, emb) diff --git a/extra/onnx.py b/extra/onnx.py index de5925ce27..fd96387f53 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -7,6 +7,7 @@ from tinygrad import Tensor, dtypes, Device from tinygrad.tensor import _to_np_dtype from tinygrad.helpers import getenv, DEBUG, CI, OSX from tinygrad.dtype import ConstType, DType +from tinygrad.device import is_dtype_supported from onnx import AttributeProto, ModelProto, TensorProto, TypeProto try: from onnx.helper import tensor_dtype_to_np_dtype @@ -29,14 +30,6 @@ def to_python_const(t, tobytes=False) -> Union[List[ConstType], List[bytes], Uni cache_misses = info.misses return ret -# copied from helpers.py -def is_dtype_supported(dtype, device: str = Device.DEFAULT): - if dtype == dtypes.bfloat16: return False - if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] - if dtype == dtypes.half: return not (CI and device in {"GPU", "LLVM", "CUDA"}) - if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") - return True - # src: onnx/mapping.py https://onnx.ai/onnx/api/mapping.html#l-mod-onnx-mapping # not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15, UINT4 = 21, INT4 = 22 # TODO: use dtypes.float16 for FLOAT16 diff --git a/extra/qcom_gpu_driver/opencl_ioctl.py b/extra/qcom_gpu_driver/opencl_ioctl.py index 0e87a31789..ea1896bd01 100644 --- a/extra/qcom_gpu_driver/opencl_ioctl.py +++ b/extra/qcom_gpu_driver/opencl_ioctl.py @@ -1,6 +1,7 @@ # type: ignore import ctypes, ctypes.util, struct, fcntl, re from hexdump import hexdump +from copy import deepcopy import pathlib, sys from tinygrad.helpers import to_mv, getenv from tinygrad.runtime.autogen import adreno @@ -91,27 +92,36 @@ def parse_cmd_buf(dat): num_unit = vals[0]>>22 if IOCTL > 0: print(f"{num_unit=} {state_block=} {state_src=} {state_type=} {dst_off=}") - if state_block == SB6_CS_SHADER and IOCTL > 2: + if "LOAD_FRAGS" not in CAPTURED_STATE: CAPTURED_STATE['LOAD_FRAGS'] = [] + CAPTURED_STATE['LOAD_FRAGS'].append((state_block, state_type, num_unit, dst_off)) + + if state_block == SB6_CS_SHADER: from extra.disassemblers.adreno import disasm_raw - if state_type == ST6_SHADER: disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128)) - if state_type == ST6_CONSTANTS: hexdump(get_mem(((vals[2] << 32) | vals[1]), min(0x180, num_unit*4))) + if state_type == ST6_SHADER and IOCTL > 2: + disasm_raw(get_mem(((vals[2] << 32) | vals[1]), num_unit * 128)) + if state_type == ST6_CONSTANTS: + x = get_mem(((vals[2] << 32) | vals[1]), num_unit*4) + CAPTURED_STATE['constants'] = x[:] + if IOCTL > 2: + print('constants') + hexdump(x) if state_type == ST6_IBO: ibos_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4) CAPTURED_STATE['ibos'] = ibos_bytes[:] - if IOCTL > 0: + if IOCTL > 1: print('texture ibos') hexdump(ibos_bytes) - elif state_block == SB6_CS_TEX and IOCTL > 2: + elif state_block == SB6_CS_TEX: if state_type == ST6_SHADER: samplers_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 4 * 4) - CAPTURED_STATE['samplers'] = ibos_bytes[:] - if IOCTL > 0: + CAPTURED_STATE['samplers'] = samplers_bytes[:] + if IOCTL > 1: print('texture samplers') hexdump(samplers_bytes) if state_type == ST6_CONSTANTS: - descriptors_bytes = get_mem((vals[2] << 32) | vals[1], num_unit * 16 * 4) - CAPTURED_STATE['descriptors'] = ibos_bytes[:] - if IOCTL > 0: + descriptors_bytes = get_mem((vals[2] << 32) | vals[1], 1600) + CAPTURED_STATE['descriptors'] = descriptors_bytes[:] + if IOCTL > 1: print('texture descriptors') hexdump(descriptors_bytes) @@ -200,10 +210,19 @@ install_hook(libc.ioctl, ioctl) def before_launch(): global CAPTURED_STATE CAPTURED_STATE.clear() -def collect_last_launch_state(): return CAPTURED_STATE +def collect_last_launch_state(): + global CAPTURED_STATE + return deepcopy(CAPTURED_STATE) def compare_launch_state(state, good_state): cmp = [ - (adreno.REG_A6XX_SP_CS_CONFIG, 0xffffffff), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NTEX__MASK), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NSAMP__MASK), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_NIBO__MASK), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_ENABLED), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_TEX), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_SAMP), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_IBO), + (adreno.REG_A6XX_SP_CS_CONFIG, adreno.A6XX_SP_CS_CONFIG_BINDLESS_UBO), (adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_HALFREGFOOTPRINT__MASK), (adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_FULLREGFOOTPRINT__MASK), @@ -213,20 +232,53 @@ def compare_launch_state(state, good_state): (adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_EARLYPREAMBLE), (adreno.REG_A6XX_SP_CS_CTRL_REG0, adreno.A6XX_SP_CS_CTRL_REG0_MERGEDREGS), + (adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_MEMSIZEPERITEM__MASK), + (adreno.REG_A6XX_SP_CS_PVT_MEM_PARAM, adreno.A6XX_SP_CS_PVT_MEM_PARAM_HWSTACKSIZEPERTHREAD__MASK), + (adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK5), (adreno.REG_A6XX_SP_CS_UNKNOWN_A9B1, adreno.A6XX_SP_CS_UNKNOWN_A9B1_UNK6), (adreno.REG_A6XX_SP_CS_BRANCH_COND, 0xffffffff), + + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_KERNELDIM__MASK), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEX__MASK), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEY__MASK), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_0, adreno.A6XX_HLSQ_CS_NDRANGE_0_LOCALSIZEZ__MASK), + + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_1, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_2, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_3, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_4, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_5, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_NDRANGE_6, 0xffffffff), + + (adreno.REG_A6XX_HLSQ_CS_CNTL_0, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_CNTL_1, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_X, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Y, 0xffffffff), + (adreno.REG_A6XX_HLSQ_CS_KERNEL_GROUP_Z, 0xffffffff), ] for x,m in cmp: + print(f"Field {REGS[x]}, mask: 0x{m:X} cmp: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}") if state.get(x, 0) & m != good_state.get(x, 0) & m: - return False, f"Field {REGS[x]}, mask: {x:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}" + return False, f"Field {REGS[x]}, mask: 0x{m:X} mismatch: {state.get(x, 0) & m} vs {good_state.get(x, 0) & m}" - for n in ['ibos', 'samplers', 'descriptors']: + for n in ['descriptors', 'ibos']: if n not in good_state: continue mv1, mv2 = state.get(n), good_state.get(n) - if len(mv1) != len(mv2): return False, f"{n}: len mismatch" + + if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}" + mv1 = memoryview(bytearray(mv1)).cast('I') + mv2 = memoryview(bytearray(mv2)).cast('I') + for i in range(len(mv2)): + if i % 8 == 5 or i % 8 == 4: continue # addresses + if mv1[i]!=mv2[i]: return False, f"{n}: content mismatch {i} {mv1[i]} {mv2[i]}" + + for n in ['samplers']: + if n not in good_state: continue + mv1, mv2 = state.get(n), good_state.get(n) + if len(mv1) != len(mv2): return False, f"{n}: len mismatch {len(mv1)} != {len(mv2)}" if any(mv1[i]!=mv2[i] for i in range(len(mv1))): return False, f"{n}: content mismatch" return True, "PASS" diff --git a/setup.py b/setup.py index 1ea6e394a1..49c3e9db81 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setup(name='tinygrad', "License :: OSI Approved :: MIT License" ], install_requires=[], - python_requires='>=3.8', + python_requires='>=3.10', extras_require={ 'llvm': ["llvmlite"], 'arm': ["unicorn"], diff --git a/test/Dockerfile b/test/Dockerfile index 22be7b4a0c..260a0b0bf9 100644 --- a/test/Dockerfile +++ b/test/Dockerfile @@ -1,8 +1,8 @@ -FROM ubuntu:20.04 +FROM ubuntu:22.04 -# Install python3.8, and pip3 +# Install python3.10, and pip3 RUN apt-get update && apt-get install -y --no-install-recommends \ - python3.8 \ + python3.10 \ python3-pip \ && rm -rf /var/lib/apt/lists/* diff --git a/test/external/external_test_hcq_fuzz_failures.py b/test/external/external_test_hcq_fuzz_failures.py new file mode 100644 index 0000000000..9370b7661d --- /dev/null +++ b/test/external/external_test_hcq_fuzz_failures.py @@ -0,0 +1,62 @@ +# ruff: noqa: E501 +import os +os.environ["VALIDATE_HCQ"]="1" + +import unittest, random +import numpy as np +from tinygrad.codegen.kernel import Kernel, KernelOptError +from tinygrad.device import is_dtype_supported +from tinygrad.ops import UOp, Ops +from tinygrad.engine.search import Opt, OptOps +from tinygrad import Device, dtypes, Tensor +from test.external.fuzz_linearizer import compare_linearizer, compare_states, get_fuzz_rawbuf_like + +from tinygrad.shape.shapetracker import ShapeTracker +from tinygrad.shape.view import View + +def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1e-2, atol=1e-2): + if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return + if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return + + for opt in opts: + try: + lin.apply_opt(opt) + except KernelOptError: + # it's considered fixed if we invalidated the opts + assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}" + return + + (msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(lin, rtol=rtol, atol=atol) + if msg in ["PASS", "KernelOptError"]: + # it's considered fixed if we invalidated the opts + assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}" + else: + assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {msg}" + + validate_lin = lin.copy() + validate_lin.opts = validate_device.renderer + validate_rawbufs = [get_fuzz_rawbuf_like(x, copy=True, force_device=validate_device.dname) for x in rawbufs] + (_msg, _, _, _, state2) = compare_linearizer(validate_lin, validate_rawbufs, var_vals, ground_truth, rtol=rtol, atol=atol) + + if _msg in ["PASS"] and compare_states(state1, state2): + assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}" + else: + assert Device.DEFAULT in failed_platforms, f"failed on {Device.DEFAULT} with {msg}" + + return lin + +class TestHCQFuzzFailures(unittest.TestCase): + def setUp(self): + random.seed(42) + np.random.seed(42) + Tensor.manual_seed(42) + + @unittest.skipUnless(Device.DEFAULT in {"QCOM"}, "for QCOM") + def test_failure_1(self): + ast = UOp(Ops.SINK, dtypes.void, arg=None, src=( UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=1, src=()), x39:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=0, mask=((0, 1), (0, 6)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), x39,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=3, src=()), x46:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-6, mask=((0, 1), (6, 12)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), x46,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=5, src=()), x54:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (12, 13)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), x54,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-13, mask=((0, 1), (13, 17)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-17, mask=((0, 1), (17, 21)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=9, src=()), x68:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (21, 22)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), x68,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-22, mask=((0, 1), (22, 26)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-26, mask=((0, 1), (26, 30)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=13, src=()), x82:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (30, 31)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), x82,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=15, src=()), x90:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (31, 32)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), x90,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=17, src=()), x98:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (32, 33)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), x98,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=19, src=()), x106:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (33, 34)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=20, src=()), x106,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=21, src=()), x114:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (34, 35)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=22, src=()), x114,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=23, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-35, mask=((0, 1), (35, 39)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=24, src=()), x125:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (39, 40)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=25, src=()), x125,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=26, src=()), x133:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (40, 41)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=27, src=()), x133,)),)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=28, src=()), x140:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-41, mask=((0, 1), (41, 47)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=29, src=()), x140,)),)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 2, 4)), arg=30, src=()), x147:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-47, mask=((0, 1), (47, 53)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=31, src=()), x147,)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=32, src=()), x155:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (53, 54)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=33, src=()), x155,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=34, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-54, mask=((0, 1), (54, 58)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=35, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-58, mask=((0, 1), (58, 62)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=36, src=()), x169:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (62, 63)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=37, src=()), x169,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=38, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-63, mask=((0, 1), (63, 67)), contiguous=False),)), src=()),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=39, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-67, mask=((0, 1), (67, 71)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=40, src=()), x183:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (71, 72)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=41, src=()), x183,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=42, src=()), x191:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (72, 73)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=43, src=()), x191,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=44, src=()), x199:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (73, 74)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=45, src=()), x199,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=46, src=()), x207:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (74, 75)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=47, src=()), x207,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=48, src=()), x215:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (75, 76)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=49, src=()), x215,)),)),)),)),)), UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=50, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-76, mask=((0, 1), (76, 80)), contiguous=False),)), src=()),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=51, src=()), x226:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (80, 81)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=52, src=()), x226,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=53, src=()), x234:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (81, 82)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=54, src=()), x234,)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=55, src=()), x243:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (82, 83)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=56, src=()), x243,)),)),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 1, 4)), arg=57, src=()), x250:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 0), offset=0, mask=((0, 1), (83, 84)), contiguous=False),)), src=()),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=58, src=()), x250,)),)),)),)),)),)), UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.imageh((1, 128, 4)), arg=59, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 596), strides=(0, 1), offset=-84, mask=((0, 1), (84, 596)), contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 + + opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] + helper_test_lin(Kernel(ast), opts, failed_platforms=[], validate_device=Device["GPU"]) + +if __name__ == '__main__': + unittest.main() diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 3da89aa72b..296ec2c60e 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -30,7 +30,7 @@ from tinygrad.device import is_dtype_supported def on_linearizer_will_run(): pass def on_linearizer_did_run(): pass -def compare_states(x, y): return True +def compare_states(x, y): return (True, "") if getenv("VALIDATE_HCQ"): if Device.DEFAULT == "NV": diff --git a/test/external/speed_v_theoretical.py b/test/external/speed_v_theoretical.py new file mode 100644 index 0000000000..83ec9d81f7 --- /dev/null +++ b/test/external/speed_v_theoretical.py @@ -0,0 +1,51 @@ +import unittest, time +from tinygrad import Tensor, TinyJit, Device +from tinygrad.helpers import Context, DEBUG + +class TestKernelSpeed(unittest.TestCase): + def _test_matmul(self, M, N=None, K=None, nv=None, amd=None): + # (MxK) @ (KxN) + @TinyJit + def f(a, b): return (a @ b).realize() + + if N is None: N = M + if K is None: K = M + tms = [] + with Context(BEAM=3): + for _ in range(10): + with Context(BEAM=0, DEBUG=0): + a = Tensor.rand(M, K, dtype="half").realize() + b = Tensor.rand(K, N, dtype="half").realize() + Device.default.synchronize() + st = time.perf_counter() + _c = f(a, b) + Device.default.synchronize() + tms.append(time.perf_counter() - st) + + ops = 2 * M * N * K + tm = min(tms) + tflops = ops / tm / 1e12 + + if DEBUG >= 1: + print(f"{tm=}") + print(f"{tflops=}") + + if Device.DEFAULT == "NV": + if DEBUG >=1: print(f"target: {nv}") + self.assertGreater(tflops, nv) + if Device.DEFAULT == "AMD": + if DEBUG >=1: print(f'target: {amd}') + self.assertGreater(tflops, amd) + + # TODO: smaller ones has other overhead in synchronize + # TODO: AMD number can be better (perf level?) + def test_gemm_1024(self): self._test_matmul(1024, nv=8, amd=7) + def test_gemm_2048(self): self._test_matmul(2048, nv=50, amd=30) + def test_gemm_4096(self): self._test_matmul(4096, nv=95, amd=65) + def test_gemm_8192(self): self._test_matmul(8192, nv=130, amd=70) + + # TODO: add gemv, which is memory bounded + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 30a337759a..957f529c1b 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -1219,7 +1219,7 @@ class TestLinearizer(unittest.TestCase): assert len(sched) == 1 lin = Kernel(sched[0].ast) - assert sum(u.op is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg + assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg a = Tensor.empty((4,4)) b = Tensor.empty((4,4)) diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index ce45ee53f7..b0830f7004 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -94,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase): prg = k.to_program() print(prg.src) if_uops = [u for u in k.uops if u.op is Ops.IF] - self.assertIn(len(if_uops), {1,3}) + self.assertIn(len(if_uops), {1,2,3}) conditions = if_uops[0].src[0].sparents self.assertLessEqual(len(conditions), 9) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index a316aace76..05dbe3c826 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase): k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) assert k is not None ifs = [u for u in k.uops if u.op is Ops.IF] - self.assertEqual(len(ifs), 4) + self.assertEqual(len(ifs), 3) #for st in k.uops.sink.src: self.assertEqual(len(st.src), 4) self.assertLessEqual(len(ifs[0].src[0].sparents), 17) diff --git a/test/test_ops.py b/test/test_ops.py index 62e69992e3..d7e9e0fab9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -233,6 +233,20 @@ class TestOps(unittest.TestCase): def test_arange_4096(self): helper_test_op([], lambda: torch.arange(4096, dtype=torch.int32), lambda: Tensor.arange(4096), forward_only=True) + def test_linspace(self): + helper_test_op([], lambda: torch.linspace(5, 10, 3), lambda: Tensor.linspace(5, 10, 3), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 1), lambda: Tensor.linspace(5, 10, 1), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 0), lambda: Tensor.linspace(5, 10, 0), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 30), lambda: Tensor.linspace(5, 10, 30), forward_only=True) + helper_test_op([], lambda: torch.linspace(-5.5, 5.5, 10), lambda: Tensor.linspace(-5.5, 5.5, 10), forward_only=True) + helper_test_op([], lambda: torch.linspace(5.5, -5.5, 10), lambda: Tensor.linspace(5.5, -5.5, 10), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 3, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 3, dtype="int32"), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, 10, 20, dtype=torch.int32), lambda: Tensor.linspace(5, 10, 20, dtype="int32"), forward_only=True) + helper_test_op([], lambda: torch.linspace(5, -5, 20, dtype=torch.int32), lambda: Tensor.linspace(5, -5, 20, dtype="int32"), forward_only=True) + self.helper_test_exception([], lambda: torch.linspace(5, 10, 3, dtype=torch.bool), lambda: Tensor.linspace(5, 10, 3, dtype="bool"), + expected=(RuntimeError, ValueError)) + self.helper_test_exception([], lambda: torch.linspace(1, 2, -1), lambda: Tensor.linspace(1, 2, -1), expected=(RuntimeError, ValueError)) + def test_sum_fake(self): helper_test_op([(256, 1)], lambda x: x.sum(axis=1)) @@ -1483,6 +1497,11 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w).relu(), lambda x,w: Tensor.conv2d(x,w).relu(), grad_rtol=1e-5) + def test_simple_conv2d_bias(self): + helper_test_op([(1,4,9,9), (4,4,3,3), (4,)], + lambda x,w,b: torch.nn.functional.conv2d(x,w,b).relu(), + lambda x,w,b: Tensor.conv2d(x,w,b).relu(), grad_rtol=1e-5) + @unittest.skipIf(IMAGE>0, "no conv3d on images") def test_simple_conv3d(self): helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], diff --git a/test/test_randomness.py b/test/test_randomness.py index 4258b45b30..707e1e9cd4 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -6,6 +6,7 @@ import torch from tinygrad import nn, dtypes, Tensor, Device, TinyJit from tinygrad.helpers import getenv, CI from tinygrad.device import is_dtype_supported +from tinygrad.engine.realize import lower_schedule, CompiledRunner from hypothesis import given, settings, strategies as strat settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) @@ -92,10 +93,16 @@ class TestRandomness(unittest.TestCase): counts = Tensor.arange(20, dtype=dtypes.uint32) counts0, counts1 = counts.chunk(2) - r = Tensor._threefry_random_bits(1337 << 32, counts0, counts1).numpy() + r = Tensor._threefry_random_bits(Tensor([0, 1337], dtype='uint32'), counts0, counts1).numpy() np.testing.assert_allclose(jr, r) + def test_threefry_doesnt_use_long(self): + for ei in lower_schedule(Tensor.rand(20).schedule()): + if isinstance(ei.prg, CompiledRunner): + for u in ei.prg.p.uops: + self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}") + def test_threefry_against_reference_full(self): Tensor.manual_seed(1337) @@ -227,11 +234,15 @@ class TestRandomness(unittest.TestCase): def test_randint(self): self.assertFalse(normal_test(Tensor.randint)) - self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) + self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), + numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) + self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5, dtype="int32"), + numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x))) self.assertTrue(Tensor.randint(1, device="CLANG").device=="CLANG") # check types of args with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0.1, high=3) with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3.5) + with self.assertRaises(TypeError): Tensor.randint((3, 4), low=1, high=3, dtype="float") with self.assertRaises(TypeError): Tensor.randint((3, 4), low=0, high=3, dtype=dtypes.float32) def test_normal(self): diff --git a/test/test_schedule.py b/test/test_schedule.py index 2a48022ed2..7b8d719904 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast -from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left +from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left from tinygrad.engine.realize import CompiledRunner, run_schedule from tinygrad.engine.lazy import LazyBuffer, view_supported_devices from test.helpers import ast_const, timeit @@ -1626,13 +1626,14 @@ class TestIndexing(unittest.TestCase): ref = Tensor(X).interpolate(size=(2, 2), mode="linear").numpy() np.testing.assert_allclose(ref, compare, atol=1e-5, rtol=1e-6) - def test_recursive_st_fixup(self): + def test_recursive_swizzle(self): a = Tensor([1,2,3,4]).realize() for _ in range(24): a = a + a ast = a.schedule()[0].ast - new_uop, et = timeit(st_fixup, ast.src[0].src[2], lambda st:st.reshape((4, 1)), {}) + swizzle = ast.src[0].src[2].reshape((4, 1)) + new_uop = swizzle_rewrite(swizzle) self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1))) - self.assertLess(et, 1e3) + self.assertEqual(swizzle_cnt(new_uop), 0) def test_strongly_connected_DAG(self): val = 1.0 @@ -1864,5 +1865,47 @@ class TestSwizzle(unittest.TestCase): ret = swizzle_rewrite(sink) self.assertEqual(swizzle_cnt(ret), 0) + @unittest.expectedFailure + def test_swizzle_failure_permute(self): + sink = UOp(Ops.SINK, dtypes.void, arg=None, src=( + UOp(Ops.STORE, dtypes.void, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(20, ('METAL', 65, dtypes.float)), src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 65), strides=(0, 1), offset=0, mask=None, contiguous=True),)), src=()), + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.ADD, dtypes.float, arg=None, src=( + UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 2925, dtypes.float)), src=()), + x10:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()),)), + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x12:=UOp(Ops.VALID, dtypes.bool, arg=None, src=( + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), + UOp(Ops.CONST, dtypes.float, arg=1.0, src=()), + x15:=UOp(Ops.CONST, dtypes.float, arg=0.0, src=()),)),)), + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x12, + UOp(Ops.CONST, dtypes.float, arg=0.0003418803389649838, src=()), + x15,)),)), + x6,)),)), + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.WHERE, dtypes.float, arg=None, src=( + x12, + UOp(Ops.CONST, dtypes.float, arg=-1.0, src=()), + x15,)), + UOp(Ops.MUL, dtypes.float, arg=None, src=( + UOp(Ops.PRELOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 2925, dtypes.float)), src=()), + x10,)), + UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(1, 89), offset=44, mask=None, contiguous=False),)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( + UOp(Ops.LOAD, dtypes.float, arg=None, src=( + UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(4, ('METAL', 2925, dtypes.float)), src=()), + UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(65, 45, 90), strides=(1, 0, 65), offset=0, mask=((0, 65), (0, 45), (0, 45)), contiguous=False), View(shape=(65, 4094), strides=(4050, 1), offset=0, mask=((0, 65), (0, 4050)), contiguous=False), View(shape=(1, 65, 46, 89), strides=(0, 4094, 89, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)) + ret = swizzle_rewrite(sink) + self.assertEqual(swizzle_cnt(ret), 0) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index d2fc35e9b1..d911ac004d 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -38,6 +38,7 @@ def helper_test_speed(f1, *args): # operation cache defeats args = [(x+1).realize() if isinstance(x, Tensor) else (None if x is None else (x+1)) for x in args] + args = [(x-1).realize() if isinstance(x, Tensor) else (None if x is None else (x-1)) for x in args] # force syncing [x.numpy() if isinstance(x, Tensor) or str(torch_device) == "cpu" else x.cpu().numpy() for x in args if x is not None] @@ -91,8 +92,9 @@ def helper_test_generic(name, f1, f1_args, f2, f2_args): desc = "faster" if et_torch > et_tinygrad else "slower" flops = save_ops*1e-6 mem = save_mem*1e-6 - print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:8.2f} GFLOPS {mem/et_torch:8.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:8.2f} GFLOPS {mem/et_tinygrad:8.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501 - np.testing.assert_allclose(val_tinygrad, val_torch, atol=1e-3, rtol=1e-3) + print(("\r" if not CI else "")+f"{name:42s} {et_torch:7.2f} ms ({flops/et_torch:9.2f} GFLOPS {mem/et_torch:7.2f} GB/s) in torch, {et_tinygrad:7.2f} ms ({flops/et_tinygrad:9.2f} GFLOPS {mem/et_tinygrad:7.2f} GB/s) in tinygrad, {colorize_float(et_tinygrad/et_torch)} {desc} {flops:10.2f} MOPS {mem:8.2f} MB") # noqa: E501 + atol, rtol = (1e-2, 1e-2) if torch_dt == torch.float16 else (1e-3, 1e-3) + np.testing.assert_allclose(val_tinygrad, val_torch, atol=atol, rtol=rtol) def helper_test_conv(bs, in_chans, out_chans, kernel_size, img_size_y, img_size_x): torch.manual_seed(0) diff --git a/test/test_tiny.py b/test/test_tiny.py index e690c731ee..cd276f4141 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -29,6 +29,14 @@ class TestTiny(unittest.TestCase): self.assertListEqual((out:=a@b).flatten().tolist(), [1.0]*(N*N)) if IMAGE < 2: self.assertEqual(out.dtype, out_dtype) + # *** randomness *** + + def test_random(self): + out = Tensor.rand(10) + for x in out.tolist(): + self.assertGreaterEqual(x, 0.0) + self.assertLessEqual(x, 1.0) + # *** JIT (for Python speed) *** def test_jit(self): diff --git a/test/testextra/test_mockgpu.py b/test/testextra/test_mockgpu.py new file mode 100644 index 0000000000..ec6b76a3d4 --- /dev/null +++ b/test/testextra/test_mockgpu.py @@ -0,0 +1,13 @@ +from tinygrad.helpers import getenv +import unittest, importlib + +@unittest.skipUnless(getenv("MOCKGPU"), 'Testing mockgpu') +class TestMockGPU(unittest.TestCase): + # https://github.com/tinygrad/tinygrad/pull/7627 + def test_import_typing_extensions(self): + import extra.mockgpu.mockgpu # noqa: F401 # pylint: disable=unused-import + import typing_extensions + importlib.reload(typing_extensions) # pytest imports typing_extension before mockgpu + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 5b9f091898..1ee8585320 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -274,8 +274,17 @@ sym = symbolic_flat+PatternMatcher([ # tensor core cleanups (UPat.var("add") + UPat(Ops.WMMA, name="wmma"), lambda add, wmma: UOp(wmma.op, wmma.dtype, (wmma.src[0], wmma.src[1], wmma.src[2]+add), wmma.arg)), - # threefry + # threefry + remove longs (UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32), + (UPat.var('x', dtypes.uint32).cast(dtypes.uint64).cast(dtypes.uint32), lambda x: x), # cast there and back is noop (TODO: genericize) + ((UPat.var('x', dtypes.uint64)&0xFFFFFFFF).cast(dtypes.uint32), lambda x: x.cast(dtypes.uint32)), # cast does truncation + (((UPat.var(None, dtypes.uint64)*(1<<32)) | UPat.var('y', dtypes.uint32).cast(dtypes.uint64)).cast(dtypes.uint32), lambda y: y), + (((UPat.var('x', dtypes.uint64)*(1<<32)) | UPat.var(None, dtypes.uint32).cast(dtypes.uint64))//(1<<32), lambda x: x), + # hacks for threefry long removal when padded (TODO: genericize) + (UPat.var('x', dtypes.uint32).cast(dtypes.uint64) * UPat.var('y').where(UPat.const(dtypes.uint64, 1<<32), UPat.const(dtypes.uint64, 0)), + lambda x,y: y.where(x, UOp.const(dtypes.uint32, 0)).cast(dtypes.uint64) * (1<<32)), + ((UPat.var('x', dtypes.uint64)&(UPat.var('y').where(UPat.const(dtypes.uint64, 0xFFFFFFFF), UPat.const(dtypes.uint64, 0)))).cast(dtypes.uint32), + lambda x,y: y.where(x.cast(dtypes.uint32), UOp.const(dtypes.uint32, 0))), # arange loop folding (acc_pat.assign(UPat.any(arange_m, arange_m+UPat.var("extra"))+acc_pat), loop_collapse), # indexing, with cast or where diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py index 3a701d1caa..95b29bfa93 100644 --- a/tinygrad/engine/fuse.py +++ b/tinygrad/engine/fuse.py @@ -1,6 +1,7 @@ from collections import defaultdict, deque -from typing import Tuple, List, Dict, DefaultDict -from tinygrad.ops import GroupOp, MetaOps, ReduceOps, UOp, UnaryOps +from typing import Set, Tuple, List, Dict, DefaultDict +from tinygrad.device import Buffer +from tinygrad.ops import GroupOp, MetaOps, Ops, ReduceOps, UOp, UnaryOps from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.engine.lazy import LazyBuffer @@ -37,13 +38,15 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={}) return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) -def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], - double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], ctx) -> List[List[UOp]]: +def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], + double_reduces:Dict[LazyBuffer, None], ubuf_realizes:Dict[UOp, UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]: """search the graph for all the LazyBuffers that need to realize""" # get all the realizes from big graph realizes: Dict[LazyBuffer, None] = {} + assigns: Set[UOp] = set() for r in allbufs: - if ctx.buf_uops[r.buffer] in ubuf_realizes: realizes[r] = None + if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None + if r.op is Ops.ASSIGN: assigns.add(ubuf) # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {} reduce_of_const: List[LazyBuffer] = [] @@ -62,7 +65,7 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La parents = deque((r, *group)) while parents and not forced_realize: if (p:=parents.pop().base).is_realized() or p in realizes: - if p.is_realized() and p.buffer in ctx.assigns and not any(x.buffer is p.buffer for x in group): forced_realize, can_chase = True, False + if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False continue parents.extend(p.srcs) if forced_realize or not group: @@ -92,19 +95,19 @@ def get_realizes(outs:List[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[La top_reduce = reduceop.base.srcs[0].base if len(children[top_reduce]) == 1: del realizes[top_reduce] - if (ubuf:=ctx.buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] + if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] for r in reduce_of_const: group = {tr:None for tr,rop in reduce_for_op.items() if rop is r} - if any(tr.forced_realize for tr in group) or any(x.base in group for x in outs): continue + if any(tr.forced_realize for tr in group): continue kernel_children = {c for tr in group for c in children[tr] if c.op not in {MetaOps.COPY, MetaOps.BUFFER_VIEW}} if len(kernel_children) == 0: continue for tr in group: del realizes[tr] - if (ubuf:=ctx.buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] + if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] output_groups: DefaultDict[LazyBuffer, List[UOp]] = defaultdict(list) for buf in realizes: - output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=ctx.buf_uops[buf.buffer]) + output_groups[reduce_for_op.get(buf, buf)].append(ubuf:=buf_uops[buf.buffer]) ubuf_realizes[ubuf] = ubuf return list(output_groups.values()) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 74484d6153..54e6d43bca 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,9 +1,9 @@ import sys, atexit, functools, itertools from collections import defaultdict, deque from dataclasses import dataclass, field -from typing import Callable, Set, Tuple, List, Dict, Optional, DefaultDict, cast +from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint -from tinygrad.helpers import DEBUG, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap +from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap from tinygrad.dtype import ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape @@ -91,12 +91,8 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children, allbufs, double_reduce # ** helpers for doing movementops on uops -def st_fixup(u:UOp, apply_to_st:Callable[[ShapeTracker], ShapeTracker], cache:Dict[UOp, UOp]) -> UOp: - if (n:=cache.get(u)) is not None: return n - if u.op is Ops.VIEW: return u.replace(arg=apply_to_st(u.arg)) - if len(u.src) == 0 or (u.st is not None and u.st == apply_to_st(u.st)): return u - cache[u] = ret = u.replace(src=tuple(st_fixup(x, apply_to_st, cache) for x in u.src)) - return ret +def apply_swizzle(u:UOp, arg:ShapeTracker) -> UOp: + with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u.view(arg), view_left) def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]: permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis @@ -105,9 +101,8 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr # ** movementops rewrite rules -def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]: - if (st:=unwrap(view.st)).contiguous: return None - tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg) +def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: + tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg) prshape = prod(rshape) strides = strides_for_shape(rshape) nv: List[View] = [] @@ -118,10 +113,10 @@ def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]: new_input_st = tmp + ShapeTracker(tuple(nv)) _, new_rshape = permute_reduce(new_input_st, r.axis_arg) new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape))) - return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) + return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) -def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp: - swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st) +def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp: + swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st) assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE" assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE" output_shape = swizzle_st.reduce(root.axis_arg) @@ -134,9 +129,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles] assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}" new_shape, new_input_shape = swizzle_shapes[0] - fixup_cache: Dict[UOp, UOp] = {} - new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src] - ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg) + ret = root.replace(src=tuple(x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)) return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: @@ -149,8 +142,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, name="s0"),), # push VIEW to loads view_left = merge_views+PatternMatcher([ # VIEW before elementwise ops - (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), - lambda e,v: e.replace(src=tuple(s.view(v.st) if s.has_st else s for s in e.src))), + (UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"), lambda e,v: e.replace(src=tuple(s.view(v.st) for s in e.src))), # early merge VIEW buffer ops (UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.arg+v.arg).to_uop() if s.op is Ops.VIEW else s for s in b.src))), ]) @@ -160,10 +152,10 @@ view_right = merge_views+PatternMatcher([ # ASSIGN can override st (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))), lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None), - # VIEW on a reduce creates a new VIEW - (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), + # non contiguous VIEW on a reduce creates a new VIEW + (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), # push a VIEW down to STORE, through a reduce (ONLY reshapes) - (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), + (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes) (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), @@ -283,7 +275,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] # get realizes realizes: Dict[UOp, UOp] = {} graph_rewrite(big_graph, do_realize, realizes) - store_groups = get_realizes(outs, children, allbufs, double_reduces, realizes, ctx) + store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.buf_uops) # split realizes into small graphs graph_rewrite(big_graph, break_sched, realizes) sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups] diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index b51c5083cc..d1824f976c 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -2,9 +2,7 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip import urllib.request, subprocess, shutil, math, contextvars, types, copyreg, inspect, importlib from dataclasses import dataclass -from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence -if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10 - from typing_extensions import TypeGuard +from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard T = TypeVar("T") U = TypeVar("U") @@ -320,10 +318,7 @@ class trange(tqdm): def _reconstruct_code(*args): return types.CodeType(*args) def _serialize_code(code:types.CodeType): - # NOTE: this works in Python 3.8 and up - if sys.version_info >= (3, 10): args = inspect.signature(types.CodeType).parameters - else: args = ['argcount', 'posonlyargcount', 'kwonlyargcount', 'nlocals', 'stacksize', 'flags', 'codestring', - 'constants', 'names', 'varnames', 'filename', 'name', 'firstlineno', 'lnotab', 'freevars', 'cellvars'] + args = inspect.signature(types.CodeType).parameters # NOTE: this works in Python 3.10 and up return _reconstruct_code, tuple(code.__getattribute__('co_'+x.replace('codestring', 'code').replace('constants', 'consts')) for x in args) copyreg.pickle(types.CodeType, _serialize_code) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 06cbbeb952..393ea7bc47 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from collections import defaultdict from weakref import WeakValueDictionary from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate -from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, T +from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context, partition, temp, unwrap, T if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker @@ -140,7 +140,7 @@ class Ops(FastEnum): # BinaryOps ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702 - SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto() # noqa: E702 + SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702 # TernaryOps WHERE = auto(); MULACC = auto() # noqa: E702 @@ -168,7 +168,8 @@ class Ops(FastEnum): class GroupOp: Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG} - Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, Ops.SUB} + Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY, + Ops.SUB, Ops.FDIV} Ternary = {Ops.WHERE, Ops.MULACC} ALU = set.union(Unary, Binary, Ternary) @@ -310,7 +311,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass): return ret def sink(self, *srcs:UOp): return UOp(Ops.SINK, dtypes.void, (self,)+srcs) def index(self, idx:UOp, valid:Optional[UOp]=None): return UOp(Ops.INDEX, self.dtype, (self,idx,valid) if valid is not None else (self,idx)) - def view(self, st:ShapeTracker): return UOp(Ops.VIEW, self.dtype, (self,), st) def const_like(self, b:ConstLike): return UOp.const(self.dtype, b) def broadcast(self, count:int): assert self.dtype.count == 1 @@ -346,6 +346,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in GroupOp.Reduce else op, axis)) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) + # *** uop movement ops *** + + def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st) + def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg)) + # *** uop Variable stuff *** @staticmethod @@ -379,7 +384,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def const_factor(self) -> int: """largest known int that divides self""" if self.op is Ops.CONST: return self.arg - if self.op is Ops.VCONST: return functools.reduce(math.gcd, self.arg) + if self.op is Ops.VCONST: return math.gcd(*self.arg) if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 @@ -630,7 +635,6 @@ class PatternMatcher: for p,fxn in self.patterns: assert p.op is not None tuple_fxn = fxn if isinstance(fxn, tuple) else deconstruct_function(fxn) - tuple_fxn[1]['__builtins__'] = __builtins__ # NOTE: Python 3.8 requires this for "all" and "len" and friends real_fxn = types.FunctionType(*tuple_fxn) for uop in p.op: self.pdict.setdefault(uop, []).append((p, real_fxn, p.early_reject, 'ctx' in inspect.signature(real_fxn).parameters)) @@ -901,7 +905,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: def lt_folding(x:UOp, c:int) -> Optional[UOp]: p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1) - if np and (d:=functools.reduce(math.gcd, [u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: + if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d) return None @@ -1050,6 +1054,8 @@ symbolic = symbolic_simple+PatternMatcher([ (UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None), # group like ((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y), + # ** boolean algebra ** + (UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x # ** combine terms ** (UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1) (UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index 9e5061a5be..64e645b697 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -1,50 +1,77 @@ -from typing import Dict, Callable, List, Optional -from llvmlite import ir -from tinygrad.dtype import DType, PtrDType, dtypes -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, Ops, UOp, GroupOp +from typing import List, Dict, cast +import math, struct from tinygrad.renderer import Renderer +from tinygrad.ops import UOp, PatternMatcher, UPat, Ops, GroupOp +from tinygrad.dtype import dtypes, DType, PtrDType, truncate -MFLAGS = ('nsz', 'arcp', 'contract', 'afn') # All from fast math, but nnan and ninf and reassoc +def ldt(dt:DType): + if isinstance(dt, PtrDType): return ldt(dt.base) + "*" + return {dtypes.int8: "i8", dtypes.int16: "i16", dtypes.int32: "i32", dtypes.int64: "i64", + dtypes.uint8: "i8", dtypes.uint16: "i16", dtypes.uint32: "i32", dtypes.uint64: "i64", + dtypes.float16: "half", dtypes.float32: "float", dtypes.float64: "double", dtypes.bool: "i1", dtypes.void: "void"}[dt] -def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype) - -dtype_to_llvm_dtype = { dtypes.bool:ir.IntType(1), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.int16:ir.IntType(16), - dtypes.uint16:ir.IntType(16), dtypes.int32:ir.IntType(32), dtypes.uint32:ir.IntType(32), dtypes.int64:ir.IntType(64), dtypes.uint64:ir.IntType(64), - dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType() } - -def cast(bb, val, input_type, output_type, bitcast=False): - if input_type == output_type: return val - llvm_type = dtype_to_llvm_dtype[output_type] - if bitcast: return bb[-1].bitcast(val, llvm_type) - - if input_type == dtypes.bfloat16: - val = bb[-1].bitcast(bb[-1].shl(bb[-1].sext(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)),val, ir.FloatType()) - input_type = dtypes.float32 - if output_type == dtypes.bfloat16: - val = cast(bb, val, input_type, dtypes.float32) - return bb[-1].trunc(bb[-1].lshr(bb[-1].bitcast(val, ir.IntType(32)), ir.Constant(ir.IntType(32), 16)), ir.IntType(16)) +def lconst(x, dtype:DType): + if dtype in dtypes.floats: + if math.isinf(x) or math.isnan(x): return "0x%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1]) + return truncate[dtype](x) + return int(x) +def lcast(input_type:DType, output_type:DType): if dtypes.is_float(input_type): - if dtypes.is_float(output_type): - return bb[-1].fpext(val, llvm_type) if output_type.itemsize > input_type.itemsize else bb[-1].fptrunc(val, llvm_type) - if dtypes.is_int(output_type): return bb[-1].fptoui(val, llvm_type) if dtypes.is_unsigned(output_type) else bb[-1].fptosi(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].fcmp_unordered('!=', cast(bb, val, input_type, dtypes.float32), ir.Constant(ir.FloatType(), 0)) - + if dtypes.is_float(output_type): return 'fpext' if output_type.itemsize > input_type.itemsize else 'fptrunc' + if dtypes.is_int(output_type): return 'fptoui' if dtypes.is_unsigned(output_type) else 'fptosi' if dtypes.is_unsigned(input_type) or input_type == dtypes.bool: - if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].uitofp(val, ir.FloatType()), ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].uitofp(val, dtype_to_llvm_dtype[output_type]) - if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].zext(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].icmp_unsigned('!=', val, ir.Constant(val.type, 0)) - + if dtypes.is_float(output_type): return 'uitofp' + if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'zext' if dtypes.is_int(input_type): - if output_type == dtypes.float16: return bb[-1].fptrunc(bb[-1].sitofp(val, ir.FloatType()), ir.HalfType()) - if dtypes.is_float(output_type): return bb[-1].sitofp(val, llvm_type) - if dtypes.is_int(output_type): return bb[-1].trunc(val, llvm_type) if input_type.itemsize > output_type.itemsize else bb[-1].sext(val, llvm_type) - if output_type == dtypes.bool: return bb[-1].icmp_signed('!=', val, ir.Constant(val.type, 0)) - + if dtypes.is_float(output_type): return 'sitofp' + if dtypes.is_int(output_type): return 'trunc' if output_type.itemsize < input_type.itemsize else 'sext' raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented") -def const(args, dtype): return ir.Constant(dtype_to_llvm_dtype[dtype], args) +# llvm ops, lop[][] +unsigned_lop = { Ops.ADD: "add", Ops.MUL: "mul", Ops.IDIV: "udiv", Ops.MOD: "urem", + Ops.CMPLT: "icmp ult", Ops.CMPNE: "icmp ne", Ops.OR: "or", Ops.AND: "and", Ops.XOR: "xor", } +signed_lop = {**unsigned_lop, Ops.CMPLT: "icmp slt", Ops.IDIV: "sdiv", Ops.MOD: "srem"} +flags = " nsz arcp contract afn" +float_lop = {Ops.ADD: "fadd"+flags, Ops.MUL: "fmul"+flags, Ops.CMPLT: f"fcmp{flags} ult", Ops.CMPNE: f"fcmp{flags} une", Ops.FDIV: "fdiv"+flags} +lop = {**{x:unsigned_lop for x in (dtypes.bool,)+dtypes.uints}, **{x:signed_lop for x in dtypes.sints}, **{x:float_lop for x in dtypes.floats}} + +llvm_rewrite = PatternMatcher([ + # memory load/store + (UPat(Ops.INDEX, name="x"), lambda ctx,x: + f" {ctx[x]} = getelementptr inbounds {ldt(x.dtype.base)}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}"), + (UPat(Ops.LOAD, src=(UPat.var('idx'), UPat.var('alt'), UPat.var('mask')), name="x"), lambda ctx,x,idx,alt,mask: + f" br label {ctx[x]}_entry\n{ctx[x][1:]}_entry:\n" + f" br i1 {ctx[mask]}, label {ctx[x]}_load, label {ctx[x]}_exit\n{ctx[x][1:]}_load:\n" + f" {ctx[x]}_yes = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}\n" + f" br label {ctx[x]}_exit\n{ctx[x][1:]}_exit:\n" + f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x]}_yes, {ctx[x]}_load], [{ctx[alt]}, {ctx[x]}_entry]"), + (UPat(Ops.LOAD, src=(UPat.var('idx'),), name="x"), lambda ctx,x,idx: f" {ctx[x]} = load {ldt(x.dtype)}, {ldt(idx.dtype)} {ctx[idx]}"), + (UPat(Ops.STORE, name="x"), lambda ctx,x: f" store {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[0].dtype)} {ctx[x.src[0]]}"), + + # unary/binary/ternary ops + (UPat(Ops.SQRT, name="x"), lambda ctx,x: + f" {ctx[x]} = call{flags} {ldt(x.dtype)} @llvm.sqrt.{ldt(x.src[0].dtype)}({ldt(x.src[0].dtype)} {ctx[x.src[0]]})"), + (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f" {ctx[x]} = bitcast {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), + (UPat(Ops.CAST, name="x"), lambda ctx,x: f" {ctx[x]} = {lcast(x.src[0].dtype, x.dtype)} {ldt(x.src[0].dtype)} {ctx[x.src[0]]} to {ldt(x.dtype)}"), + (UPat(GroupOp.Binary, name="x"), lambda ctx,x: f" {ctx[x]} = {lop[x.src[0].dtype][x.op]} {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ctx[x.src[1]]}"), + (UPat(Ops.WHERE, name="x"), lambda ctx,x: + f" {ctx[x]} = select {ldt(x.src[0].dtype)} {ctx[x.src[0]]}, {ldt(x.src[1].dtype)} {ctx[x.src[1]]}, {ldt(x.src[2].dtype)} {ctx[x.src[2]]}"), + + # range + (UPat(Ops.RANGE, name="x"), lambda ctx,x: + f" br label %loop_entry_{x.arg[0]}\nloop_entry_{x.arg[0]}:\n" + f" br label %loop_body_{x.arg[0]}\nloop_body_{x.arg[0]}:\n" + f" {ctx[x]} = phi {ldt(x.dtype)} [{ctx[x.src[0]]}, %loop_entry_{x.arg[0]}], [{ctx[x]}phi, %loop_latch_{x.arg[0]}]"), + (UPat(Ops.ENDRANGE, name="x"), lambda ctx,x: + f" br label %loop_latch_{x.src[0].arg[0]}\nloop_latch_{x.src[0].arg[0]}:\n" + f" {ctx[x.src[0]]}phi = add i32 {ctx[x.src[0]]}, 1\n {ctx[x]} = icmp ult i32 {ctx[x.src[0]]}phi, {ctx[x.src[0].src[1]]}\n" + f" br i1 {ctx[x]}, label %loop_body_{x.src[0].arg[0]}, label %loop_exit_{x.src[0].arg[0]}\nloop_exit_{x.src[0].arg[0]}:"), + + # if + (UPat(Ops.IF, name="x"), lambda ctx,x: f" br i1 {ctx[x.src[0]]}, label %ifbody_{ctx[x][1:]}, label %ifskip_{ctx[x][1:]}\nifbody_{ctx[x][1:]}:"), + (UPat(Ops.ENDIF, name="x"), lambda ctx,x: f" br label %ifskip_{ctx[x.src[0]][1:]}\nifskip_{ctx[x.src[0]][1:]}:"), +]) class LLVMRenderer(Renderer): device = "LLVM" @@ -52,101 +79,64 @@ class LLVMRenderer(Renderer): has_local = False has_shared = False global_max = None - code_for_op: Dict[Ops, Callable] = { - UnaryOps.RECIP: lambda builder, x, dtype: builder.fdiv(const(1, dtype), x, flags=MFLAGS), - UnaryOps.SQRT: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sqrt', [x.type]), [x], fastmath=MFLAGS), - BinaryOps.ADD: lambda builder, x, y, dtype: builder.or_(x, y) if dtype == dtypes.bool else builder.add(x, y) if dtypes.is_int(dtype) else builder.fadd(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.MUL: lambda builder, x, y, dtype: builder.mul(x, y) if is_bool_or_unsigned(dtype) or dtypes.is_int(dtype) else builder.fmul(x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.IDIV: lambda builder, x, y, dtype: builder.udiv(x, y) if is_bool_or_unsigned(dtype) else builder.sdiv(x, y), - BinaryOps.CMPLT: lambda builder, x, y, dtype: builder.icmp_unsigned("<", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("<", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("<", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.CMPNE: lambda builder, x, y, dtype: builder.icmp_unsigned("!=", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed("!=", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered("!=", x, y, flags=MFLAGS), # noqa: E501 - BinaryOps.MAX: lambda builder, x, y, dtype: builder.select(builder.icmp_unsigned(">", x, y) if is_bool_or_unsigned(dtype) else builder.icmp_signed(">", x, y) if dtypes.is_int(dtype) else builder.fcmp_unordered(">", x, y, flags=MFLAGS), x, y), # noqa: E501 - BinaryOps.MOD: lambda builder, x, y, dtype: builder.urem(x, y) if is_bool_or_unsigned(dtype) else builder.srem(x, y) if dtypes.is_int(dtype) else builder.frem(x, y), # noqa: E501 - BinaryOps.XOR: lambda builder, x, y, dtype: builder.xor(x, y), BinaryOps.AND: lambda builder, x, y, dtype: builder.and_(x, y), BinaryOps.OR: lambda builder, x, y, dtype: builder.or_(x, y), # noqa: E501 - BinaryOps.SHL: lambda builder, x, y, dtype: builder.shl(x, y), BinaryOps.SHR: lambda builder, x, y, dtype: builder.lshr(x, y) if dtypes.is_unsigned(dtype) else builder.ashr(x, y), # noqa: E501 - TernaryOps.WHERE: lambda builder, x, y, z, dtype: builder.select(x, y, z)} - def render(self, name:str, uops:List[UOp]) -> str: - # all llvm stuff goes into a module - module = ir.Module(name=__file__) + extra_matcher = PatternMatcher([ + # rewrite RECIP with FDIV + (UPat(Ops.RECIP, name="x"), lambda x: UOp(Ops.FDIV, x.dtype, (x.const_like(1), x.src[0]))), + # rewrite cast to bool to CMPNE 0 + (UPat(Ops.CAST, dtype=dtypes.bool, name="x"), lambda x: x.src[0] != x.src[0].const_like(0)), + # *** also in cstyle *** + # gate any stores that aren't gated with ifs + (UPat(Ops.STORE, dtype=dtypes.void, src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="store"), + lambda store: UOp(Ops.STORE, src=store.src[:2]+(UOp(Ops.IF, src=(store.src[2],)),))), + # rewrite MAX to CMPLT + WHERE + (UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])), + ]) - # extract global buffers (NOTE: this isn't right if DEFINE_GLOBAL is out of order) - buf_to_dtype = {u.arg:u.dtype for u in uops if u.op in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}} - buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())} + def render(self, name: str, uops: List[UOp]) -> str: + r: Dict[UOp, str] = {} + args: List[str] = [] + kernel: List[str] = [] + end_lines: Dict[str, None] = {} + vc = -1 - # create llvm function - func_dtypes = [(dtype_to_llvm_dtype[dtype.base if isinstance(dtype, PtrDType) else dtype],dtype) for dtype in buf_to_dtype.values()] - func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() if isinstance(dt, PtrDType) else x for x,dt in func_dtypes]), name=name) - for a in func.args: - if a.type.is_pointer: a.add_attribute("noalias") - - bb = [ir.IRBuilder(func.append_basic_block("entry"))] - loop_blocks: List = [] - reduce_phis: List = [] - lvars: Dict[Optional[UOp], ir.Instruction] = {} - - for bufname,dtype in buf_to_dtype.items(): - if not isinstance(dtype, PtrDType) and dtype == dtypes.int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32)) + # prealloc all assigns + acc_to_assign: Dict[UOp, UOp] = {} + for u in uops: + if u.op is Ops.ASSIGN: + vc += 1 + r[u] = r[u.src[1]] = f"%assign{vc}" + assert u.src[0] not in acc_to_assign, "can't assign to DEFINE_ACC twice" + acc_to_assign[u.src[0]] = u.src[1] for u in uops: - uop,dtype,src,args = u.op,u.dtype,u.src,u.arg - if uop is Ops.INDEX: - lvars[u] = bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True) - elif uop is Ops.STORE: - if len(src) > 2: - with bb[-1].if_then(lvars[src[2]]): bb[-1].store(lvars[src[1]], lvars[src[0]]) - else: - bb[-1].store(lvars[src[1]], lvars[src[0]]) - elif uop is Ops.ENDRANGE: - loop_entry_bb, phis = loop_blocks.pop() - idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1)) - lvars[src[0]].add_incoming(idx_p1, bb[-1].block) - for n,phi in phis: phi.add_incoming(lvars[n], bb[-1].block) - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_exit_{len(loop_blocks)}"))) - bb[-2].cbranch(bb[-2].icmp_unsigned("<", idx_p1, lvars[src[0].src[1]]), loop_entry_bb, bb[-1].block) + # hack for defining sqrt function (TODO: can we get a transcendental for this?) + if u.op is Ops.SQRT: end_lines[f'declare {ldt(u.dtype)} @llvm.sqrt.{ldt(u.dtype)}({ldt(u.dtype)} %".1")'] = None + + if u.op in (Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR): + r[u] = f"%data{u.arg}" if u.op is Ops.DEFINE_GLOBAL else f"%{u.arg[0]}" + args.append(f"{ldt(u.dtype)}{' noalias' if isinstance(u.dtype, PtrDType) else ''} {r[u]}") + elif u.op is Ops.ASSIGN: pass # assign is already handled by the first pass + elif u.op is Ops.DEFINE_ACC: r[u] = r[u.src[0]] # a define acc can be used and never be assigned to + elif u.op is Ops.CONST: r[u] = lconst(u.arg, u.dtype) + elif u.op is Ops.CAST and ldt(u.dtype) == ldt(u.src[0].dtype): r[u] = r[u.src[0]] # cast from signed to unsigned of the same size is a noop else: - if uop is Ops.RANGE: - bb.append(ir.IRBuilder(func.append_basic_block(f"loop_body_{len(loop_blocks)}"))) - bb[-2].branch(bb[-1].block) + # if it's an assign target, it's already preallocated + if u not in r: + vc += 1 + r[u] = f"%v{vc}" - phis = [] - for rp in reduce_phis: - incoming = lvars[rp] - lvars[rp] = bb[-1].phi(dtype_to_llvm_dtype[rp.dtype]) - lvars[rp].add_incoming(incoming, bb[-2].block) - phis.append((rp, lvars[rp])) + # do the rendering of the llvm ir code + if (l:=llvm_rewrite.rewrite(u, ctx=r)) is None: raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.src]}") + kernel.append(cast(str, l)) - lvars[u] = bb[-1].phi(ir.IntType(32), name=f"loop{len(loop_blocks)}") - lvars[u].add_incoming(lvars[src[0]], bb[-2].block) - loop_blocks.append((bb[-1].block, phis)) - elif uop is Ops.DEFINE_ACC: - lvars[u] = const(src[0].arg, dtype) - reduce_phis.append(u) - elif uop is Ops.LOAD: - if len(src) > 1: - with bb[-1].if_else(lvars[src[2]]) as (then, otherwise): - with then: - val1 = bb[-1].load(lvars[src[0]]) - then_blk = bb[-1].block - with otherwise: otherwise_blk = bb[-1].block - val = bb[-1].phi(val1.type) - val.add_incoming(val1, then_blk) - val.add_incoming(lvars[src[1]], otherwise_blk) - else: - val = bb[-1].load(lvars[src[0]]) - lvars[u] = val - elif uop is Ops.ASSIGN: - lvars[u] = lvars[src[1]] - # ASSIGN UOps can link to other ASSIGN Uops, backtrace this to DEFINE_ACC - backward = src[0] - while backward.op is Ops.ASSIGN: backward = backward.src[0] - lvars[backward] = lvars[u] - elif uop in GroupOp.ALU: - lvars[u] = self.code_for_op[uop](bb[-1], *[lvars[x] for x in src], src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype) - elif uop in {Ops.CAST, Ops.BITCAST}: lvars[u] = cast(bb, lvars[src[0]], src[0].dtype, dtype, bitcast=uop is Ops.BITCAST) - elif uop in {Ops.DEFINE_GLOBAL, Ops.DEFINE_VAR}: lvars[u] = func.args[buf_index[args]] - elif uop is Ops.CONST: lvars[u] = const(args, dtype) - else: raise RuntimeError(f"failed to render {uop}") + # generate the phi nodes for the assigns + if u.op is Ops.RANGE: + for x in acc_to_assign: + if u in x.src: # if this range is relevent for this acc + vc += 1 + kernel.append(f" %acc{vc} = phi {ldt(x.dtype)}" f"[{r[x]}, %loop_entry_{u.arg[0]}], [{r[acc_to_assign[x]]}, %loop_latch_{u.arg[0]}]") + r[x] = f"%acc{vc}" - bb[-1].ret_void() - return str(module) + # output the function + return f"define void @{name}({','.join(args)}) {{\n" + '\n'.join(kernel) + "\n ret void\n}\n"+'\n'.join(end_lines.keys()) diff --git a/tinygrad/runtime/ops_cloud.py b/tinygrad/runtime/ops_cloud.py index 501338436b..a6d60ad571 100644 --- a/tinygrad/runtime/ops_cloud.py +++ b/tinygrad/runtime/ops_cloud.py @@ -5,14 +5,70 @@ # it should be a secure (example: no use of pickle) boundary. HTTP is used for RPC from __future__ import annotations -from typing import Tuple, Optional, Dict, Any, DefaultDict +from typing import Tuple, Optional, Dict, Any, DefaultDict, List from collections import defaultdict -import multiprocessing, functools, http.client, hashlib, json, time, contextlib, os, binascii from dataclasses import dataclass, field -from tinygrad.dtype import dtypes -from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, prod -from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions +import multiprocessing, functools, http.client, hashlib, json, time, os, binascii, struct, ast, contextlib from http.server import HTTPServer, BaseHTTPRequestHandler +from tinygrad.renderer import Renderer +from tinygrad.dtype import dtypes +from tinygrad.helpers import getenv, DEBUG, fromimport, unwrap, Timing +from tinygrad.device import Compiled, Allocator, Compiler, Device, BufferOptions + +# ***** API ***** + +class CloudRequest: pass + +@dataclass(frozen=True) +class BufferAlloc(CloudRequest): buffer_num: int; size: int; options: BufferOptions # noqa: E702 + +@dataclass(frozen=True) +class BufferFree(CloudRequest): buffer_num: int # noqa: E702 + +@dataclass(frozen=True) +class CopyIn(CloudRequest): buffer_num: int; datahash: str # noqa: E702 + +@dataclass(frozen=True) +class CopyOut(CloudRequest): buffer_num: int + +@dataclass(frozen=True) +class ProgramAlloc(CloudRequest): name: str; datahash: str # noqa: E702 + +@dataclass(frozen=True) +class ProgramFree(CloudRequest): name: str; datahash: str # noqa: E702 + +@dataclass(frozen=True) +class ProgramExec(CloudRequest): + name: str; datahash: str; bufs: Tuple[int, ...]; vals: Tuple[int, ...] # noqa: E702 + global_size: Optional[Tuple[int, ...]]; local_size: Optional[Tuple[int, ...]]; wait: bool # noqa: E702 + +# for safe deserialization +whitelist = {x.__name__:x for x in [BufferAlloc, BufferFree, CopyIn, CopyOut, ProgramAlloc, ProgramFree, ProgramExec, BufferOptions]} +eval_fxns = {ast.Constant: lambda x: x.value, ast.Tuple: lambda x: tuple(map(safe_eval, x.elts)), ast.List: lambda x: list(map(safe_eval, x.elts)), + ast.Call: lambda x: safe_eval(x.func)(*[safe_eval(arg) for arg in x.args], **{kwarg.arg: safe_eval(kwarg.value) for kwarg in x.keywords}), + ast.Name: lambda x: whitelist[x.id], ast.Attribute: lambda x: {"imagef": dtypes.imagef, "imageh": dtypes.imageh}[x.attr]} +def safe_eval(node): return eval_fxns[node.__class__](node) + +class BatchRequest: + def __init__(self): + self._q: List[CloudRequest] = [] + self._h: Dict[str, bytes] = {} + def h(self, d:bytes) -> str: + binhash = hashlib.sha256(d).digest() + self._h[datahash:=binascii.hexlify(binhash).decode()] = binhash+struct.pack(" bytes: + self.h(repr(self._q).encode()) + return b''.join(self._h.values()) + def deserialize(self, dat:bytes) -> BatchRequest: + ptr = 0 + while ptr < len(dat): + datahash, datalen = binascii.hexlify(dat[ptr:ptr+0x20]).decode(), struct.unpack("= 1: print(c) + match c: + case BufferAlloc(): + assert c.buffer_num not in session.buffers, f"buffer {c.buffer_num} already allocated" + session.buffers[c.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(c.size, c.options), c.size, c.options) + case BufferFree(): + buf,sz,buffer_options = session.buffers[c.buffer_num] + Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options) + del session.buffers[c.buffer_num] + case CopyIn(): Device[CloudHandler.dname].allocator.copyin(session.buffers[c.buffer_num][0], memoryview(bytearray(req._h[c.datahash]))) + case CopyOut(): + buf,sz,_ = session.buffers[c.buffer_num] + Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf) + case ProgramAlloc(): + lib = Device[CloudHandler.dname].compiler.compile_cached(req._h[c.datahash].decode()) + session.programs[(c.name, c.datahash)] = Device[CloudHandler.dname].runtime(c.name, lib) + case ProgramFree(): del session.programs[(c.name, c.datahash)] + case ProgramExec(): + bufs = [session.buffers[x][0] for x in c.bufs] + extra_args = {k:v for k,v in [("global_size", c.global_size), ("local_size", c.local_size)] if v is not None} + r = session.programs[(c.name, c.datahash)](*bufs, vals=c.vals, wait=c.wait, **extra_args) + if r is not None: ret = str(r).encode() + elif self.path == "/renderer" and method == "GET": cls, args = Device[CloudHandler.dname].renderer.__reduce__() ret = json.dumps((cls.__module__, cls.__name__, args)).encode() - elif self.path.startswith("/alloc") and method == "POST": - size = int(self.path.split("=")[-1]) - buffer_options: Optional[BufferOptions] = None - if 'image' in self.path: - image_shape = tuple([int(x) for x in self.path.split("=")[-2].split("&")[0].split(",")]) - buffer_options = BufferOptions(image=dtypes.imageh(image_shape) if prod(image_shape)*2 == size else dtypes.imagef(image_shape)) - session.buffer_num += 1 - session.buffers[session.buffer_num] = (Device[CloudHandler.dname].allocator.alloc(size, buffer_options), size, buffer_options) - ret = str(session.buffer_num).encode() - elif self.path.startswith("/buffer"): - key = int(self.path.split("/")[-1]) - buf,sz,buffer_options = session.buffers[key] - if method == "GET": Device[CloudHandler.dname].allocator.copyout(memoryview(ret:=bytearray(sz)), buf) - elif method == "PUT": Device[CloudHandler.dname].allocator.copyin(buf, memoryview(bytearray(self.get_data()))) - elif method == "DELETE": - Device[CloudHandler.dname].allocator.free(buf,sz,buffer_options) - del session.buffers[key] - else: return self._fail() - elif self.path.startswith("/program"): - name, hsh = self.path.split("/")[-2:] - if method == "PUT": - src = self.get_data() - assert hashlib.sha256(src).hexdigest() == hsh - lib = Device[CloudHandler.dname].compiler.compile_cached(src.decode()) - session.programs[(name, hsh)] = Device[CloudHandler.dname].runtime(name, lib) - elif method == "POST": - j = self.get_json() - bufs = [session.buffers[x][0] for x in j['bufs']] - del j['bufs'] - r = session.programs[(name, hsh)](*bufs, **j) - if r is not None: ret = str(r).encode() - elif method == "DELETE": del session.programs[(name, hsh)] - else: return self._fail() - else: return self._fail() - self.send_response(200) + else: status_code = 404 + self.send_response(status_code) self.send_header('Content-Length', str(len(ret))) self.end_headers() return self.wfile.write(ret) def do_GET(self): return self._do("GET") def do_POST(self): return self._do("POST") - def do_PUT(self): return self._do("PUT") - def do_DELETE(self): return self._do("DELETE") def cloud_server(port:int): multiprocessing.current_process().name = "MainProcess" @@ -106,44 +142,46 @@ class CloudAllocator(Allocator): def __init__(self, device:CloudDevice): self.device = device super().__init__() - def _alloc(self, size:int, options) -> int: - # TODO: ideally we shouldn't have to deal with images here - extra = ("image="+','.join([str(x) for x in options.image.shape])+"&") if options.image is not None else "" - return int(self.device.send("POST", f"alloc?{extra}size={size}")) - def _free(self, opaque, options): - with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): - self.device.send("DELETE", f"buffer/{opaque}", data=b"") - def copyin(self, dest:int, src:memoryview): self.device.send("PUT", f"buffer/{dest}", data=bytes(src)) + # TODO: ideally we shouldn't have to deal with images here + def _alloc(self, size:int, options:BufferOptions) -> int: + self.device.buffer_num += 1 + self.device.req.q(BufferAlloc(self.device.buffer_num, size, options)) + return self.device.buffer_num + # TODO: options should not be here in any Allocator + def _free(self, opaque:int, options): self.device.req.q(BufferFree(opaque)) + def copyin(self, dest:int, src:memoryview): self.device.req.q(CopyIn(dest, self.device.req.h(bytes(src)))) def copyout(self, dest:memoryview, src:int): - resp = self.device.send("GET", f"buffer/{src}") + self.device.req.q(CopyOut(src)) + resp = self.device.batch_submit() assert len(resp) == len(dest), f"buffer length mismatch {len(resp)} != {len(dest)}" dest[:] = resp class CloudProgram: def __init__(self, device:CloudDevice, name:str, lib:bytes): - self.device = device - self.prgid = f"{name}/{hashlib.sha256(lib).hexdigest()}" - self.device.send("PUT", "program/"+self.prgid, lib) + self.device, self.name = device, name + self.datahash = self.device.req.h(lib) + self.device.req.q(ProgramAlloc(self.name, self.datahash)) super().__init__() - def __del__(self): self.device.send("DELETE", "program/"+self.prgid) + def __del__(self): self.device.req.q(ProgramFree(self.name, self.datahash)) def __call__(self, *bufs, global_size=None, local_size=None, vals:Tuple[int, ...]=(), wait=False): - args = {"bufs": bufs, "vals": vals, "wait": wait} - if global_size is not None: args["global_size"] = global_size - if local_size is not None: args["local_size"] = local_size - ret = self.device.send("POST", "program/"+self.prgid, json.dumps(args).encode()) - if wait: return float(ret) + self.device.req.q(ProgramExec(self.name, self.datahash, bufs, vals, global_size, local_size, wait)) + if wait: return float(self.device.batch_submit()) class CloudDevice(Compiled): def __init__(self, device:str): - if (host:=getenv("HOST", "")) != "": - self.host = host + if (host:=getenv("HOST", "")) != "": self.host = host else: p = multiprocessing.Process(target=cloud_server, args=(6667,)) p.daemon = True p.start() self.host = "127.0.0.1:6667" - self.cookie = binascii.hexlify(os.urandom(0x10)).decode() + + # state for the connection + self.session = binascii.hexlify(os.urandom(0x10)).decode() + self.buffer_num = 0 + self.req: BatchRequest = BatchRequest() + if DEBUG >= 1: print(f"cloud with host {self.host}") while 1: try: @@ -155,13 +193,26 @@ class CloudDevice(Compiled): time.sleep(0.1) if DEBUG >= 1: print(f"remote has device {clouddev}") # TODO: how to we have BEAM be cached on the backend? this should just send a specification of the compute. rethink what goes in Renderer - assert clouddev[0].startswith("tinygrad.renderer."), f"bad renderer {clouddev}" - renderer = fromimport(clouddev[0], clouddev[1])(*clouddev[2]) - super().__init__(device, CloudAllocator(self), renderer, Compiler(), functools.partial(CloudProgram, self)) + if not clouddev[0].startswith("tinygrad.renderer.") or not clouddev[1].endswith("Renderer"): raise RuntimeError(f"bad renderer {clouddev}") + renderer_class = fromimport(clouddev[0], clouddev[1]) # TODO: is this secure? + if not issubclass(renderer_class, Renderer): raise RuntimeError(f"renderer isn't a Renderer {clouddev}") + super().__init__(device, CloudAllocator(self), renderer_class(*clouddev[2]), Compiler(), functools.partial(CloudProgram, self)) + + def __del__(self): + # TODO: this is never being called + # TODO: should close the whole session + with contextlib.suppress(ConnectionRefusedError, http.client.CannotSendRequest, http.client.RemoteDisconnected): self.batch_submit() + + def batch_submit(self): + data = self.req.serialize() + with Timing(f"*** send {len(self.req._q):-3d} requests {len(self.req._h):-3d} hashes with len {len(data)/1024:.2f} kB in ", enabled=DEBUG>=1): + ret = self.send("POST", "batch", data) + self.req = BatchRequest() + return ret def send(self, method, path, data:Optional[bytes]=None) -> bytes: # TODO: retry logic - self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.cookie}"}) + self.conn.request(method, "/"+path, data, headers={"Cookie": f"session={self.session}"}) response = self.conn.getresponse() assert response.status == 200, f"failed on {method} {path}" return response.read() diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index c897b2e574..0078c74fec 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -2,18 +2,24 @@ from __future__ import annotations import ctypes, functools from typing import Tuple from tinygrad.device import Compiled, Compiler, MallocAllocator -from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump +from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump, getenv from tinygrad.renderer.llvmir import LLVMRenderer import llvmlite.binding as llvm class LLVMCompiler(Compiler): - def __init__(self, device:LLVMDevice): + def __init__(self, device:LLVMDevice, opt:bool=False): self.device = device - super().__init__("compile_llvm") + self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager() + self.device.target_machine.add_analysis_passes(self.optimizer) + if opt: + with llvm.create_pass_manager_builder() as builder: + builder.opt_level = 3; builder.size_level = 0; builder.loop_vectorize = True; builder.slp_vectorize = True # noqa: E702 + builder.populate(self.optimizer) + super().__init__("compile_llvm_opt" if opt else "compile_llvm") def compile(self, src:str) -> bytes: mod = llvm.parse_assembly(src) mod.verify() - self.device.optimizer.run(mod) + self.optimizer.run(mod) if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod)) return self.device.target_machine.emit_object(mod) @@ -23,6 +29,7 @@ class LLVMProgram: self.name, self.lib = name, lib device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) self.fxn = device.engine.get_function_address(name) + assert self.fxn != 0, "LLVM failed to get function address" def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False): if not hasattr(self, 'cfunc'): @@ -35,12 +42,9 @@ class LLVMDevice(Compiled): llvm.initialize_native_target() llvm.initialize_native_asmprinter() llvm.initialize_native_asmparser() - self.optimizer: llvm.passmanagers.ModulePassManager = llvm.create_module_pass_manager() # this opt actually can change things. ex: opt=3 means no FMA, opt=2 means FMA self.target_machine: llvm.targets.TargetMachine = llvm.Target.from_triple(llvm.get_process_triple()).create_target_machine(opt=2) - self.target_machine.add_analysis_passes(self.optimizer) - self.target_machine.set_asm_verbosity(True) backing_mod = llvm.parse_assembly(str()) backing_mod.triple = llvm.get_process_triple() self.engine: llvm.executionengine.ExecutionEngine = llvm.create_mcjit_compiler(backing_mod, self.target_machine) - super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self), functools.partial(LLVMProgram, self)) + super().__init__(device, MallocAllocator, LLVMRenderer(), LLVMCompiler(self, getenv("LLVMOPT")), functools.partial(LLVMProgram, self)) diff --git a/tinygrad/runtime/ops_qcom.py b/tinygrad/runtime/ops_qcom.py index 903860f07d..f9a6128f4f 100644 --- a/tinygrad/runtime/ops_qcom.py +++ b/tinygrad/runtime/ops_qcom.py @@ -155,7 +155,7 @@ class QCOMComputeQueue(HWComputeQueue): if args_state.prg.tex_cnt > 0: self.cmd(adreno.CP_LOAD_STATE6_FRAG, qreg.cp_load_state6_0(state_type=adreno.ST_CONSTANTS, state_src=adreno.SS6_INDIRECT, - state_block=adreno.SB6_CS_TEX, num_unit=args_state.prg.tex_cnt), + state_block=adreno.SB6_CS_TEX, num_unit=min(16, args_state.prg.tex_cnt)), *data64_le(args_state.ptr + args_state.prg.tex_off)) self.reg(adreno.REG_A6XX_SP_CS_TEX_CONST, *data64_le(args_state.ptr + args_state.prg.tex_off)) @@ -247,14 +247,15 @@ class QCOMProgram(HCQProgram): self.buf_info, self.consts_info = [], [] # Collect sampler info. - self.samp_cnt = _read_lib(image_desc_off + 0xdc) + self.samp_cnt = samp_cnt_in_file = _read_lib(image_desc_off + 0xdc) assert self.samp_cnt <= 1, "Up to one sampler supported" if self.samp_cnt: + self.samp_cnt += 1 self.samplers = [qreg.a6xx_tex_samp_0(wrap_s=(clamp_mode:=adreno.A6XX_TEX_CLAMP_TO_BORDER), wrap_t=clamp_mode, wrap_r=clamp_mode), - qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0] + qreg.a6xx_tex_samp_1(unnorm_coords=True, cubemapseamlessfiltoff=True), 0, 0, 0, 0, 0, 0] # Collect kernel arguments (buffers) info. - bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * self.samp_cnt + bdoff = round_up(image_desc_off + 0x158 + len(self.name), 4) + 8 * samp_cnt_in_file while bdoff + 32 <= len(self.lib): length, _, _, offset_words, _, _, _, typ = struct.unpack("IIIIIIII", self.lib[bdoff:bdoff+32]) if length == 0: break @@ -263,7 +264,7 @@ class QCOMProgram(HCQProgram): # Setting correct offsets to textures/ibos. self.tex_cnt, self.ibo_cnt = sum(x.type is BUFTYPE_TEX for x in self.buf_info), sum(x.type is BUFTYPE_IBO for x in self.buf_info) - self.samp_off, self.ibo_off, self.tex_off = 2048, 2048 + 0x10 * self.samp_cnt, 2048 + 0x10 * self.samp_cnt + 0x40 * self.ibo_cnt + self.ibo_off, self.tex_off, self.samp_off = 2048, 2048 + 0x40 * self.ibo_cnt, 2048 + 0x40 * self.tex_cnt + 0x40 * self.ibo_cnt cur_ibo_off, cur_tex_off = self.ibo_off, self.tex_off for x in self.buf_info: if x.type is BUFTYPE_IBO: x.offset, cur_ibo_off = cur_ibo_off, cur_ibo_off + 0x40 @@ -307,10 +308,10 @@ class QCOMAllocator(HCQAllocator): texture.pitch, texture.real_stride = pitch, real_stride tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT - texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt) + texture.desc[0] = qreg.a6xx_tex_const_0(0x8, swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt) texture.desc[1] = qreg.a6xx_tex_const_1(width=imgw, height=imgh) texture.desc[2] = qreg.a6xx_tex_const_2(type=adreno.A6XX_TEX_2D, pitch=texture.pitch, pitchalign=pitchalign-6) - texture.desc[4:7] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000)] + texture.desc[4:8] = [*data64_le(texture.va_addr), qreg.a6xx_tex_const_6(plane_pitch=0x400000), qreg.a6xx_tex_const_7(13)] texture.ibo = [texture.desc[0] & (~0xffff), *texture.desc[1:len(texture.desc)]] return texture diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cc76bad61d..ed94e33d0d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -464,9 +464,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method Tensor._seed, Tensor._device_seeds, Tensor._device_rng_counters = seed, {}, {} @staticmethod - def _threefry_random_bits(key, counts0, counts1): + def _threefry_random_bits(key:Tensor, counts0:Tensor, counts1:Tensor): x = (counts1.cast(dtypes.uint64) << 32) | counts0.cast(dtypes.uint64) - x = F.Threefry.apply(*x._broadcasted(key)) + x = F.Threefry.apply(x, (key[1]._broadcast_to(x.shape).cast(dtypes.uint64) << 32) | key[0]._broadcast_to(x.shape).cast(dtypes.uint64)) counts0, counts1 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32) return counts0.cat(counts1) @@ -494,9 +494,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # generate per device seeds and rng counter if we haven't seen this device yet if device not in Tensor._device_seeds: - Tensor._device_seeds[device] = Tensor([((Tensor._seed & 0xffffffff) << 32) \ - | int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff], - device=device, dtype=dtypes.uint64, requires_grad=False) + Tensor._device_seeds[device] = Tensor( + [int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big"), Tensor._seed], + device=device, dtype=dtypes.uint32, requires_grad=False) Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False) had_counter = False else: had_counter = True @@ -612,6 +612,26 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if (output_len:=ceildiv(stop-start, step)) <= 0: return Tensor([], dtype=dtype, **kwargs) return (Tensor.full((output_len,), step, dtype=dtype, **kwargs)._cumsum() + (start - step)).cast(dtype) + @staticmethod + def linspace(start:Union[int, float], stop:Union[int, float], steps:int, **kwargs) -> Tensor: + """ + Returns a 1-D tensor of `steps` evenly spaced values from `start` to `stop`, inclusive. + + You can pass in `dtype` and `device` keyword arguments to control the data type and device of the tensor. + Additionally, all other keyword arguments are passed to the constructor of the tensor. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.linspace(0, 10, 5).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.linspace(-1, 1, 5).numpy()) + ``` + """ + if steps < 0: raise ValueError("number of steps must be non-negative") + if (dtype := to_dtype(kwargs.pop("dtype", dtypes.default_float))) == dtypes.bool: raise ValueError("linspace with bool dtype is not supported") + if steps == 1: return Tensor([start], dtype=dtype, **kwargs) + return (start + Tensor.arange(steps, **kwargs) * ((stop - start) / (steps - 1))).cast(dtype) + @staticmethod def eye(n:int, m:Optional[int]=None, **kwargs) -> Tensor: """ @@ -734,7 +754,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ if not isinstance(low, int) or not isinstance(high, int): raise TypeError(f"{low=} and {high=} must be integers") - dtype = kwargs.pop("dtype", dtypes.int32) + dtype = to_dtype(kwargs.pop("dtype", dtypes.int32)) if not dtypes.is_int(dtype): raise TypeError(f"{dtype=} must be int") return Tensor.uniform(*shape, low=low, high=high, dtype=dtype, **kwargs)