mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
Remove float64 (#1101)
* Refactor: Remove float64 * Refactor: Remove unused imports * Refactor: Remove float64 * Refactor: Remove float64 * Refactor: Exclude float64 onnx backend * Add: Skip jacobian and gradcheck tests;
This commit is contained in:
@@ -62,7 +62,7 @@ class RDNACodegen(AssemblyCodegen):
|
||||
return rtor[x]
|
||||
for uop, out, vin, arg in asm:
|
||||
if uop == UOps.DEFINE_REGISTER:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float64, dtypes._float4]:
|
||||
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes._float4]:
|
||||
for i in range(arg[2]):
|
||||
# TODO: Re-use gaps created by this to avoid wasting registers
|
||||
align = int(arg[0][0].itemsize / 4)
|
||||
|
||||
@@ -176,7 +176,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
|
||||
|
||||
if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
|
||||
if lang.double_prekernel and any(x.dtype == dtypes.float64 for x in bufs): prg = ''.join([f"{lang.double_prekernel}", "\n", prg])
|
||||
return prg, global_size, local_size
|
||||
|
||||
class CStyleCodegen(Linearizer):
|
||||
|
||||
@@ -37,7 +37,7 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
# create llvm function
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.float64:ir.DoubleType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}[buf.dtype] for buf in bufs]
|
||||
func_dtypes = [{dtypes.float16:ir.HalfType(), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}[buf.dtype] for buf in bufs]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
|
||||
|
||||
# force llvmlite to allow us to add function attribute then add the attribute
|
||||
@@ -90,8 +90,6 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
||||
if func_dtypes[args.i] != ir.FloatType():
|
||||
if dtypes.is_int(bufs[args.i].dtype):
|
||||
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].sitofp(val, ir.FloatType())
|
||||
elif bufs[args.i].dtype == dtypes.float64:
|
||||
val = bb[-1].fptrunc(val, ir.FloatType())
|
||||
else:
|
||||
val = bb[-1].fpext(val, ir.FloatType())
|
||||
lvars[newvar] = val
|
||||
@@ -102,8 +100,6 @@ def uops_to_llvm_ir(uops:List[UOp], bufs:List[LazyBuffer]) -> str:
|
||||
if func_dtypes[0] != ir.FloatType():
|
||||
if dtypes.is_int(bufs[args.i].dtype):
|
||||
element = bb[-1].fptoui(element, func_dtypes[0]) if dtypes.is_unsigned(bufs[args.i].dtype) else bb[-1].fptosi(element, func_dtypes[0])
|
||||
elif bufs[args.i].dtype == dtypes.float64:
|
||||
element = bb[-1].fpext(element, func_dtypes[0])
|
||||
else:
|
||||
element = bb[-1].fptrunc(element, func_dtypes[0])
|
||||
bb[-1].store(element, bb[-1].gep(func.args[args.i], [idx], inbounds=True))
|
||||
|
||||
@@ -82,7 +82,7 @@ class dtypes:
|
||||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.uint8, dtypes.int32, dtypes.int64)
|
||||
@staticmethod
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float4)
|
||||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes._half4, dtypes._float4)
|
||||
@staticmethod
|
||||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint32, dtypes.uint64)
|
||||
@staticmethod
|
||||
@@ -94,7 +94,6 @@ class dtypes:
|
||||
half = float16
|
||||
float32: Final[DType] = DType(4, 4, "float", np.float32)
|
||||
float = float32
|
||||
float64: Final[DType] = DType(5, 8, "double", np.float64)
|
||||
int8: Final[DType] = DType(0, 1, "char", np.int8)
|
||||
int32: Final[DType] = DType(1, 4, "int", np.int32)
|
||||
int64: Final[DType] = DType(2, 8, "long", np.int64)
|
||||
|
||||
@@ -39,7 +39,7 @@ class RawBufferMapped(RawBufferCopyIn):
|
||||
|
||||
# this one is simple enough that i moved it out of the runtimes
|
||||
class RawMallocBuffer(RawBufferMapped):
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.float64: ctypes.c_double, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
|
||||
def __init__(self, size, dtype: DType): super().__init__(size, dtype, ({dtypes.float32: ctypes.c_float, dtypes.float16: ctypes.c_int16, dtypes.int8: ctypes.c_int8, dtypes.uint8: ctypes.c_uint8, dtypes.bool: ctypes.c_uint8, dtypes.int32: ctypes.c_int32, dtypes.int64: ctypes.c_int64}[dtype] * size)())
|
||||
def _buffer(self): return memoryview(self._buf)
|
||||
|
||||
class RawBufferCopyInOut(RawBufferCopyIn):
|
||||
|
||||
@@ -3,7 +3,7 @@ import pathlib
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Optional, List
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, dtypes, fromimport
|
||||
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferCopyInOut
|
||||
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
|
||||
@@ -30,7 +30,6 @@ CL = _CL()
|
||||
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
assert not OSX or dtype != dtypes.float64, "OpenCL on Mac doesn't support float64"
|
||||
if isinstance(dtype, ImageDType):
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, {2: cl.channel_type.HALF_FLOAT, 4: cl.channel_type.FLOAT}[dtype.itemsize])
|
||||
buf = cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(dtype.shape[1], dtype.shape[0]))
|
||||
|
||||
@@ -3,7 +3,7 @@ import os, subprocess, pathlib
|
||||
import Metal, Cocoa, libdispatch # type: ignore
|
||||
from typing import List, Any
|
||||
from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes
|
||||
from tinygrad.helpers import prod, getenv, DEBUG, DType
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.runtime.lib import RawBufferMapped
|
||||
|
||||
@@ -22,7 +22,6 @@ METAL = _METAL()
|
||||
|
||||
class RawMetalBuffer(RawBufferMapped):
|
||||
def __init__(self, size:int, dtype:DType):
|
||||
assert dtype != dtypes.float64, "metal doesn't support float64"
|
||||
super().__init__(size, dtype, METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared))
|
||||
def __del__(self):
|
||||
self._buf.release()
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if getenv("MPS", 0) else "cpu"))
|
||||
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.float64: dtypes.float64, torch.double: dtypes.float64, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
|
||||
type_map = {torch.float16: dtypes.float16, torch.float32: dtypes.float32, torch.int8: dtypes.int8, torch.int32: dtypes.int32, torch.int64: dtypes.int64, torch.uint8: dtypes.uint8, torch.bool: dtypes.bool}
|
||||
inverse_type_map = {v:k for k,v in type_map.items()}
|
||||
|
||||
torch_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
|
||||
Reference in New Issue
Block a user