diff --git a/test/imported/test_indexing.py b/test/imported/test_indexing.py
index a6e8598d9e..250e83ca81 100644
--- a/test/imported/test_indexing.py
+++ b/test/imported/test_indexing.py
@@ -79,7 +79,7 @@ def make_tensor(shape, dtype:dtypes, noncontiguous) -> Tensor:
   +---------------------------+------------+----------+
   """
   contiguous = not noncontiguous
-  if dtype is dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
+  if dtype == dtypes.bool: return Tensor.randint(shape=shape, low=0, high=2, contiguous=contiguous).cast(dtypes.bool)
   elif dtype.is_unsigned(): return Tensor.randint(shape=shape, low=0, high=10, contiguous=contiguous).cast(dtype)
   elif dtype.is_int(): return Tensor.randint(shape=shape, low=-9, high=10, contiguous=contiguous).cast(dtype) # signed int
   elif dtype.is_float(): return Tensor.rand(shape=shape, low=-9, high=9, dtype=dtype, contiguous=contiguous)
@@ -452,7 +452,7 @@ class TestIndexing(unittest.TestCase):

     def tensor_indices_to_np(tensor: Tensor, indices):
       npt = tensor.numpy()
-      idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype is dtypes.int64 else
+      idxs = tuple(i.numpy().tolist() if isinstance(i, Tensor) and i.dtype == dtypes.int64 else
                    i for i in indices)
       return npt, idxs

diff --git a/test/test_dtype.py b/test/test_dtype.py
index 571c8ab4af..53f89d87ba 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -113,7 +113,7 @@ class TestDType(unittest.TestCase):
   arr = np.asarray(data, dtype=dt)
   tin = Tensor(arr).numpy()
   tor = torch.as_tensor(arr).detach().numpy()
-  assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
+  assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
   np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)

 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index b11ced5935..35de8c127b 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -706,12 +706,12 @@ class TestSchedule(unittest.TestCase):
     a = shared * 2
     b = shared * 3
     sched = check_schedule([a, b], 1)
-    for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
+    for si in sched[:-2]: assert all(out.dtype == dtypes.half for out in si.outputs)

     # reduce
     a = z.sum(axis=0).half().float().sum(axis=0)
     sched = check_schedule(a, 2)
-    for si in sched[:-1]: assert all(out.dtype is dtypes.half for out in si.outputs)
+    for si in sched[:-1]: assert all(out.dtype == dtypes.half for out in si.outputs)

     # expand
     # expand will realize just after the .float(), so requires change to realize-before-expand
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 8a830d8caa..7452646861 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -224,7 +224,7 @@ constant_folder = PatternMatcher([
   ({"uop": UOps.ALU, "arg": BinaryOps.CMPLT, "vin": ({"uop": UOps.ALU, "arg": UnaryOps.NEG, "vin": ({"__name__": "x"},)},
     {"__name__": "c", "uop": UOps.CONST, "dtype": dtypes.int})},
    lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
-  # cast NOOP (NOTE: it's str to deal with PtrDType)
+  # cast NOOP (NOTE: it's `is` to deal with PtrDType)
   ({"__name__": "root", "uop": UOps.CAST}, lambda root: root.vin[0] if root.dtype is root.vin[0].dtype else None),
 ])
diff --git a/tinygrad/device.py b/tinygrad/device.py
index 8c3434b4a9..d6bf2b9d32 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -164,6 +164,8 @@ MallocAllocator = _MallocAllocator()

 # **************** for Compiled Devices ****************

+class CompileError(Exception): pass
+
 class Compiler:
   def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
   def compile(self, src:str) -> bytes: raise NotImplementedError("need a compile function")
diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py
index bd1d328b69..aca4b8db07 100644
--- a/tinygrad/engine/search.py
+++ b/tinygrad/engine/search.py
@@ -59,7 +59,7 @@ def _try_compile_linearized_w_idx(x:Tuple[int,Linearizer], compiler:Compiler) ->
     prog = compiler.compile(p.src)
     et = time.perf_counter() - st
     return x[0], (p, prog, et)
-  except Exception:
+  except CompileError:
     if DEBUG >= 4: traceback.print_exc()
     return x[0], None
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index 938a8fdf39..69d66ca2f2 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -76,7 +76,7 @@ class PTXRenderer(Renderer):
   def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

   def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]:
-    assert dtype is not dtypes.bool
+    assert dtype != dtypes.bool
     if gate: return [f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];",
                      f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"]
     return [f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}+{offset}];"]
@@ -154,7 +154,7 @@ class PTXRenderer(Renderer):
         kk(f"{r_label[vin[0]]}:")
       elif uop is UOps.STORE:
         assert vin[0].dtype is not None and vin[2].dtype is not None
-        assert vin[0].dtype is dtypes.int64, "store isn't int64"
+        assert vin[0].dtype == dtypes.int64, "store isn't int64"
         assert vin[1].uop is UOps.CONST, f"store isn't const {u}"
         mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
         if vin[2].dtype.count > 1:
@@ -187,7 +187,7 @@ class PTXRenderer(Renderer):
         else: r[u] = const(args, dtype, mov=True)
       elif uop is UOps.GEP: r[u] = r[vin[0]][u.arg]
       elif uop is UOps.LOAD:
-        assert vin[0].dtype is dtypes.int64, "load isn't int64"
+        assert vin[0].dtype == dtypes.int64, "load isn't int64"
         assert vin[1].uop is UOps.CONST, f"load isn't const {u}"
         mem_type = '.shared' if vin[0].uop is UOps.DEFINE_LOCAL or any(x.uop is UOps.DEFINE_LOCAL for x in vin[0].parents) else '.global'
         if dtype.count > 1:
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index d9d7a5fb86..79a1146ac4 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -24,7 +24,7 @@ class CStyleLanguage(Renderer):
   uses_ptr_arithmetic: bool = False
   type_map: Dict[DType, str] = {}
   code_for_op: Dict = {
-    UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype is dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
+    UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype == dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
     UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
     BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
     BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
@@ -43,8 +43,8 @@ class CStyleLanguage(Renderer):
   def render_const(self, x:ConstType, dtype:DType) -> str:
     if math.isnan(x): val = "NAN"
     elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY"
-    elif dtype is dtypes.bool: val = "1" if x else "0"
-    elif dtype is dtypes.float: val = f"{x}f"
+    elif dtype == dtypes.bool: val = "1" if x else "0"
+    elif dtype == dtypes.float: val = f"{x}f"
     else: val = str(x)
     return (self.render_cast([val] * dtype.count, dtype)
       if dtype.count > 1 or dtype not in [dtypes.float, dtypes.int, dtypes.bool] else val)
diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py
index 31b3d707df..1cfb9d298a 100644
--- a/tinygrad/renderer/llvmir.py
+++ b/tinygrad/renderer/llvmir.py
@@ -12,7 +12,7 @@ def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

 code_for_op: Final[Dict[Op, Callable]] = {
   UnaryOps.NEG: lambda builder, x, dtype: builder.neg(x) if dtypes.is_int(dtype) else \
-    (builder.not_(x) if dtype is dtypes.bool else builder.fneg(x, flags=MFLAGS)),
+    (builder.not_(x) if dtype == dtypes.bool else builder.fneg(x, flags=MFLAGS)),
   UnaryOps.EXP2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.LOG2: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
   UnaryOps.SIN: lambda builder, x, dtype: builder.call(builder.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
diff --git a/tinygrad/runtime/ops_amd.py b/tinygrad/runtime/ops_amd.py
index 78779860af..29d58b16da 100644
--- a/tinygrad/runtime/ops_amd.py
+++ b/tinygrad/runtime/ops_amd.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from typing import Tuple, List, Any, cast
 import os, fcntl, ctypes, ctypes.util, functools, re, pathlib, mmap, struct, errno, subprocess, time
-from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, DEBUG
 from tinygrad.renderer.cstyle import AMDRenderer
 from tinygrad.runtime.driver.hip_comgr import compile_hip
@@ -77,7 +77,9 @@ class AMDCompiler(Compiler):
   def __init__(self, arch:str):
     self.arch = arch
     super().__init__(f"compile_hip_{self.arch}")
-  def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)
+  def compile(self, src:str) -> bytes:
+    try: return compile_hip(src, self.arch)
+    except RuntimeError as e: raise CompileError(e)

 PAGE_SIZE = 0x1000
 SIGNAL_SIZE, SIGNAL_COUNT = ctypes.sizeof(hsa.amd_signal_t), 16384
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 3264570e6a..6aa899269b 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -4,7 +4,7 @@ from pathlib import Path
 from typing import Tuple, Optional, List
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, init_c_struct_t, colored, cpu_time_execution
-from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator, MallocAllocator
+from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator, MallocAllocator
 from tinygrad.renderer.cstyle import CUDARenderer
 from tinygrad.renderer.assembly import PTXRenderer
 if getenv("IOCTL"): import extra.nv_gpu_driver.nv_ioctl # noqa: F401
@@ -71,7 +71,7 @@ class CUDACompiler(Compiler):
     check(cuda.nvrtcCreateProgram(ctypes.byref(prog := cuda.nvrtcProgram()), src.encode(), "".encode(), 0, None, None))
     status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
-    if status != 0: raise RuntimeError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
+    if status != 0: raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, check).decode()}")
     return _get_bytes(prog, cuda.nvrtcGetPTX, cuda.nvrtcGetPTXSize, check)

 def cuda_disassemble(lib, arch):
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index e682f6fdda..25a4d8b4f2 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -4,7 +4,7 @@ import ctypes, functools, hashlib
 import tinygrad.runtime.autogen.opencl as cl
 from tinygrad.helpers import init_c_var, to_char_p_p, from_mv, OSX, DEBUG
 from tinygrad.renderer.cstyle import OpenCLRenderer
-from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler
+from tinygrad.device import BufferOptions, LRUAllocator, Compiled, Compiler, CompileError

 # see test/external/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
 OSX_TIMING_RATIO = (125/3) if OSX else 1.0
@@ -23,7 +23,7 @@ class CLCompiler(Compiler):
     if build_status != 0:
       cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, 0, None, log_size := ctypes.c_size_t())
       cl.clGetProgramBuildInfo(program, self.device.device_id, cl.CL_PROGRAM_BUILD_LOG, log_size.value, mstr := ctypes.create_string_buffer(log_size.value), None) # noqa: E501
-      raise RuntimeError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
+      raise CompileError(f"OpenCL Compile Error\n\n{mstr.value.decode()}")
     check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARY_SIZES, ctypes.sizeof(ctypes.c_size_t), binary_sizes := (ctypes.c_size_t * 1)(), None))
     check(cl.clGetProgramInfo(program, cl.CL_PROGRAM_BINARIES, ctypes.sizeof(ctypes.c_void_p), (ctypes.c_void_p * 1)(ctypes.addressof(binary := ctypes.create_string_buffer(binary_sizes[0]))), None)) # noqa: E501
     check(cl.clReleaseProgram(program))
diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py
index 86ff961101..8ddc7633ad 100644
--- a/tinygrad/runtime/ops_hsa.py
+++ b/tinygrad/runtime/ops_hsa.py
@@ -3,7 +3,7 @@ import ctypes, functools, subprocess, io, atexit, collections, json
 from typing import Tuple, TypeVar, List, Dict, Any
 import tinygrad.runtime.autogen.hsa as hsa
 from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t, getenv
-from tinygrad.device import Compiled, Compiler, BufferOptions, LRUAllocator
+from tinygrad.device import Compiled, Compiler, CompileError, BufferOptions, LRUAllocator
 from tinygrad.renderer.cstyle import HIPRenderer
 from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue
 from tinygrad.runtime.driver.hip_comgr import compile_hip
@@ -45,7 +45,9 @@ class HSACompiler(Compiler):
   def __init__(self, arch:str):
     self.arch = arch
     super().__init__(f"compile_hip_{self.arch}")
-  def compile(self, src:str) -> bytes: return compile_hip(src, self.arch)
+  def compile(self, src:str) -> bytes:
+    try: return compile_hip(src, self.arch)
+    except RuntimeError as e: raise CompileError(e)

 class HSAProgram:
   def __init__(self, device:HSADevice, name:str, lib:bytes):
diff --git a/tinygrad/runtime/ops_nv.py b/tinygrad/runtime/ops_nv.py
index ac3d78921d..41ad81afae 100644
--- a/tinygrad/runtime/ops_nv.py
+++ b/tinygrad/runtime/ops_nv.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
 from typing import Tuple, List, Any, cast
-from tinygrad.device import Compiled, Compiler, LRUAllocator, BufferOptions
+from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator, BufferOptions
 from tinygrad.helpers import getenv, from_mv, init_c_struct_t, to_mv, round_up, to_char_p_p, DEBUG, prod
 from tinygrad.renderer.cstyle import NVRenderer
 from tinygrad.runtime.ops_cuda import check as cuda_check, _get_bytes, CUDACompiler
@@ -80,7 +80,7 @@ class NVCompiler(Compiler):
     status = cuda.nvrtcCompileProgram(prog, len(self.compile_options), to_char_p_p([o.encode() for o in self.compile_options]))
     if status != 0:
-      raise RuntimeError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, cuda_check).decode()}")
+      raise CompileError(f"compile failed: {_get_bytes(prog, cuda.nvrtcGetProgramLog, cuda.nvrtcGetProgramLogSize, cuda_check).decode()}")
     return _get_bytes(prog, cuda.nvrtcGetCUBIN, cuda.nvrtcGetCUBINSize, cuda_check)

 class HWComputeQueue:
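
Not part of the patch, for context only: after this change every backend compiler signals a toolchain failure with the new CompileError from tinygrad/device.py instead of a bare RuntimeError, so callers such as the BEAM-search helper in tinygrad/engine/search.py can catch compile failures narrowly without swallowing unrelated bugs. Below is a minimal sketch of that pattern, assuming only what the patch adds (Compiler and CompileError in tinygrad.device); FakeCompiler and flaky_toolchain are hypothetical names used for illustration, not part of tinygrad.

from tinygrad.device import Compiler, CompileError

def flaky_toolchain(src:str) -> bytes:
  # hypothetical stand-in for a real toolchain call such as compile_hip or nvrtc
  if "bad" in src: raise RuntimeError("toolchain rejected source")
  return src.encode()

class FakeCompiler(Compiler):
  def compile(self, src:str) -> bytes:
    # translate the backend-specific failure into the shared CompileError, mirroring the AMD/HSA compilers above
    try: return flaky_toolchain(src)
    except RuntimeError as e: raise CompileError(e) from e

if __name__ == "__main__":
  try: FakeCompiler().compile("bad kernel")
  except CompileError as e: print(f"compile failed, candidate skipped: {e}")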