move is_dtype_supported to device [pr] (#7575)

Author: George Hotz
Date: 2024-11-07 20:38:03 +08:00
Committer: GitHub
Parent: 1f5ea1e412
Commit: 205befa788
23 changed files with 50 additions and 44 deletions
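The change is mechanical for callers: is_dtype_supported keeps its name and behavior, it just moves from the test package (test.helpers) into tinygrad.device, so it can be imported from the library proper. A minimal sketch of the new import path, assuming only a default device is configured:

from tinygrad import Device, dtypes
from tinygrad.device import is_dtype_supported  # previously: from test.helpers import is_dtype_supported

# device is optional and defaults to Device.DEFAULT; per the logic added to
# device.py below, bfloat16 only reports as supported on AMD, or on CUDA/NV
# outside CI when PTX is not set.
print(is_dtype_supported(dtypes.bfloat16, Device.DEFAULT))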

@@ -5,7 +5,7 @@ import onnx.backend.test
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import getenv, OSX
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
# pip3 install tabulate
pytest_plugins = 'onnx.backend.test.report',

@@ -26,7 +26,7 @@ from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UnaryOps, UOp, Ops
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
def on_linearizer_will_run(): pass
def on_linearizer_did_run(): pass

@@ -1,14 +1,14 @@
import sys, time, logging, difflib
import time, logging, difflib
from typing import Callable, Optional, Tuple
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad import Tensor, dtypes
from tinygrad.ops import UOp, Ops, sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import _to_np_dtype
from tinygrad.engine.realize import Runner
from tinygrad.dtype import ConstType, DType
from tinygrad.nn.state import get_parameters
from tinygrad.helpers import CI, OSX, T, getenv, colored
from tinygrad.helpers import T, getenv, colored
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.runtime.ops_python import PythonProgram, PythonRenderer, PythonCompiler, PythonAllocator
@@ -31,23 +31,6 @@ def assert_jit_cache_len(fxn, expected_len):
assert type(fxn.jit_cache[0].prg).__name__.endswith('Graph')
assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
if dtype == dtypes.half:
if device == "GPU": return not CI and not OSX
if device in ["CUDA", "NV"]: return not CI
if device == "LLVM": return OSX
if device == "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
def rand_for_dtype(dt:DType, size:int):
if dtypes.is_unsigned(dt):
return np.random.randint(0, 100, size=size, dtype=_to_np_dtype(dt))

@@ -1,12 +1,13 @@
import unittest, time, gc
import numpy as np
from tinygrad.device import is_dtype_supported
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.engine.jit import TinyJit
from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
from tinygrad.helpers import CI, Context
from extra.lr_scheduler import OneCycleLR
from test.helpers import derandomize_model, is_dtype_supported
from test.helpers import derandomize_model
from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
from examples.hlb_cifar10 import SpeedyResNet, hyp

@@ -3,7 +3,7 @@ import pathlib
from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform
from tinygrad.helpers import CI, fetch
from tinygrad import Device, dtypes
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
# Audio generated with the command on MacOS:
# say "Could you please let me out of the box?" --file-format=WAVE --data-format=LEUI8@16000 -o test

@@ -1,7 +1,7 @@
import unittest
from tinygrad import Tensor, dtypes, Device
from tinygrad.engine.realize import lower_schedule
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
class TestCompileFailures(unittest.TestCase):
def compile(self, out:Tensor):

@@ -4,7 +4,7 @@ from tinygrad.ops import Ops
from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import CI
import numpy as np
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once

@@ -2,12 +2,13 @@ import unittest, operator, subprocess, math
import numpy as np
import torch
from typing import Any, List
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported, rand_for_dtype
from test.helpers import rand_for_dtype
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")

@@ -10,7 +10,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.ops import GroupOp
from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
import pytest
pytestmark = pytest.mark.filterwarnings("ignore")

@@ -2,12 +2,13 @@
import unittest, random
import numpy as np
from tinygrad.codegen.kernel import Kernel, KernelOptError
from tinygrad.device import is_dtype_supported
from tinygrad.ops import UOp, Ops, BinaryOps
from tinygrad.engine.search import Opt, OptOps
from tinygrad import Device, dtypes, Tensor
from tinygrad.helpers import CI
from test.external.fuzz_linearizer import compare_linearizer
from test.helpers import is_dtype_supported, ast_const
from test.helpers import ast_const
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner
from tinygrad.multi import all_reduce, MultiLazyBuffer
import numpy as np
from hypothesis import given, strategies as strat, settings
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

@@ -10,7 +10,7 @@ from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNo
from tinygrad.nn.state import load_state_dict
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
@unittest.skipIf(CI and Device.DEFAULT in {"CUDA", "NV"}, "slow")
class TestNN(unittest.TestCase):

@@ -5,7 +5,7 @@ import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, Context, TRANSCENDENTAL
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
if CI:
import warnings

@@ -4,7 +4,7 @@ import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.nn.optim import Adam, SGD, AdamW
from tinygrad.helpers import CI
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
np.random.seed(1337)
x_init = np.random.randn(1,4).astype(np.float32)

@@ -5,7 +5,7 @@ import numpy as np
import torch
from tinygrad import nn, dtypes, Tensor, Device, TinyJit
from tinygrad.helpers import getenv, CI
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
from hypothesis import given, settings, strategies as strat
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))

@@ -9,6 +9,7 @@ import functools
from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.device import is_dtype_supported
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
@@ -18,7 +19,7 @@ from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
from tinygrad.engine.realize import CompiledRunner, run_schedule
from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
from test.helpers import ast_const, is_dtype_supported, timeit
from test.helpers import ast_const, timeit
from extra.models.llama import precompute_freqs_cis
class KernelCountException(Exception): pass

@@ -1,7 +1,7 @@
import unittest
from tinygrad.helpers import CI
from tinygrad import Tensor, Device, dtypes
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
# similar to test/external/external_test_gpu_ast.py, but universal
@unittest.skipIf(Device.DEFAULT in {"CUDA", "NV"} and CI, "slow on CUDA CI")

@@ -7,7 +7,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.helpers import getenv, temp, CI, _METADATA, mv_address
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")

@@ -4,7 +4,7 @@ from tinygrad.tensor import _to_np_dtype
from tinygrad.helpers import Context, getenv
from test.test_schedule import check_schedule
from test.test_dtype_alu import ht, dtypes_float
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
import numpy as np
from hypothesis import given, settings, strategies as strat

@@ -12,7 +12,7 @@ from tinygrad.engine.schedule import create_schedule, to_si
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)

@@ -4,7 +4,7 @@ from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import DType
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load, tar_extract
from tinygrad.helpers import Timing, fetch, temp, CI
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
def compare_weights_both(url):
import torch

@@ -2,7 +2,7 @@ import os, unittest, ctypes
from tinygrad import dtypes, Tensor, fetch, Device
import numpy as np
from tinygrad.nn.state import ggml_data_to_tensor, gguf_load
from test.helpers import is_dtype_supported
from tinygrad.device import is_dtype_supported
try:
import ggml
except ModuleNotFoundError:

@@ -2,9 +2,9 @@ from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Dict, Tuple, Any, Iterator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib
from tinygrad.helpers import getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType, PtrDType
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib, sys
from tinygrad.helpers import CI, OSX, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes
from tinygrad.renderer import Renderer
# **************** Device ****************
@@ -200,3 +200,22 @@ class Compiled:
This method ensures that all previously queued operations on the device have been completed before proceeding.
"""
# override this in your device implementation
# TODO: move this to each Device
def is_dtype_supported(dtype:DType, device:Optional[str]=None) -> bool:
if device is None: device = Device.DEFAULT
if dtype == dtypes.bfloat16:
# NOTE: this requires bf16 buffer support
return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32]
# for CI GPU and OSX, cl_khr_fp16 isn't supported
# for CI LLVM, it segfaults because it can't link to the casting function
# CI CUDA architecture is sm_35 but we need at least sm_70 to run fp16 ALUs
# PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751
if dtype == dtypes.half:
if device == "GPU": return not CI and not OSX
if device in ["CUDA", "NV"]: return not CI
if device == "LLVM": return OSX
if device == "PYTHON": return sys.version_info >= (3, 12)
if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU")
return True
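
Downstream usage is otherwise unchanged. A typical pattern is to gate a test on the helper's result; the sketch below is hypothetical (the test class, method, and values are made up for illustration) and assumes the default device:

import unittest
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.device import is_dtype_supported

class TestHalfSupport(unittest.TestCase):
  # skip on backends where fp16 is unavailable (e.g. CI GPU and OSX, per the notes above)
  @unittest.skipUnless(is_dtype_supported(dtypes.half, Device.DEFAULT), "half not supported on this device")
  def test_cast_to_half(self):
    # cast a small float32 tensor to half and check the realized numpy dtype
    t = Tensor([1.0, 2.0, 3.0]).cast(dtypes.half)
    self.assertEqual(t.numpy().dtype, np.float16)

if __name__ == "__main__":
  unittest.main()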