truncate consts early (#6741)
* truncate consts early
* ptx still fails
* Update dtype.py
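A rough reading of the hunks below: the truncate_fp16 helper and the per-dtype truncate table move out of tinygrad.ops (where they sat next to the Python ALU emulation) into tinygrad.dtype, so constant values can be clamped to what their dtype actually represents as soon as they are created, rather than only when exec_alu runs. A minimal standalone sketch of the truncation semantics, using the same struct/ctypes round-trips the diff adds:

import math, struct, ctypes

def truncate_fp16(x):
  # round-trip through an IEEE half; values outside the fp16 range overflow to +/- inf
  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
  except OverflowError: return math.copysign(math.inf, x)

print(truncate_fp16(1.1))         # 1.099609375, the nearest representable fp16 value
print(truncate_fp16(1e5))         # inf, since the largest finite fp16 is 65504
print(ctypes.c_uint8(300).value)  # 44, unsigned ints wrap modulo 2**bits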
@@ -3,10 +3,9 @@ import numpy as np
 import torch
 from typing import Any, List
 from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
+from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
 from tinygrad import Device, Tensor, dtypes
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.ops import truncate_fp16
 from hypothesis import given, settings, strategies as strat
 from test.helpers import is_dtype_supported, rand_for_dtype
@@ -358,7 +358,7 @@ class TestMultiTensor(unittest.TestCase):
     shard_output.backward()
     shard_grad = m.conv1.weight.grad.numpy()
     # sometimes there is zeros in these grads... why?
-    np.testing.assert_allclose(grad, shard_grad, atol=3e-6, rtol=3e-6)
+    np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

   def test_multi_tensor_jit_param(self):
     @TinyJit
@@ -1520,7 +1520,7 @@ class TestIndexing(unittest.TestCase):
     X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
     Y = Tensor([1, 2]).realize()
     loss = X.sparse_categorical_crossentropy(Y)
-    self.check_schedule(loss, 6)
+    self.check_schedule(loss, 4)
     np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)

   def test_mnist_val(self):
@@ -1531,7 +1531,7 @@ class TestIndexing(unittest.TestCase):
     yt = Tensor.randn(BS, 10)
     with Context(SPLIT_REDUCEOP=0):
       loss = yt.sparse_categorical_crossentropy(Y_train[samples])
-      self.check_schedule(loss, 7)
+      self.check_schedule(loss, 6)
       loss_fused = loss.numpy()
     loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
     np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+import math, struct, ctypes
 from dataclasses import dataclass
 import functools
 from tinygrad.helpers import getenv
@@ -63,6 +64,7 @@ class dtypes:
     if isinstance(val, tuple):
       assert len(val) == dtype.count, f"mismatch {val} {dtype}"
       return tuple(dtypes.as_const(x, dtype) for x in val)
+    # TODO: should truncate here
     return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
   @staticmethod
   @functools.lru_cache(None)
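The added comment is only a TODO; as_const still returns the raw Python value. If it were implemented, the return would presumably route through the new truncate table, something like this hypothetical one-liner (not part of this commit):

  # hypothetical: clamp the cast value to what the dtype can represent
  return truncate.get(dtype, lambda x: x)(int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val))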
@@ -144,3 +146,15 @@ def sum_acc_dtype(dt:DType):
   if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
   if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
   return least_upper_dtype(dt, dtypes.float)
+
+def truncate_fp16(x):
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
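With the table living in tinygrad.dtype it can be used anywhere a const is built. A small usage sketch (example values are mine, assuming a tinygrad checkout on the path):

from tinygrad.dtype import dtypes, truncate

print(truncate[dtypes.int8](200))        # -56, two's-complement wrap of an out-of-range int8
print(truncate[dtypes.uint32](2**32+5))  # 5, wraps modulo 2**32
print(truncate[dtypes.float32](0.1))     # 0.10000000149011612, the nearest float32 to 0.1
print(truncate[dtypes.float16](65519))   # 65504.0, rounded to the largest finite fp16

Note that the int32 entry only wraps genuine Python ints and passes anything else through unchanged (the "if isinstance(x,int) else x" guard above).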
@@ -1,10 +1,10 @@
 from __future__ import annotations
 from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
-import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
+import sys, time, functools, itertools, math, operator, hashlib
 from enum import auto, IntEnum, Enum
 from collections import defaultdict
 from dataclasses import dataclass, field
-from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
 from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
 from tinygrad.shape.symbolic import Variable, sint
 if TYPE_CHECKING:
@@ -310,18 +310,6 @@ python_alu: Dict[Op, Callable] = {
   BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
   TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}

-def truncate_fp16(x):
-  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
-  except OverflowError: return math.copysign(math.inf, x)
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: bfloat16
-  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
-    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
-
 def exec_alu(op:Op, dtype:DType, operands):
   if dtype.count > 1:
     return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
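This hunk only shows the helpers being deleted from this module and truncate now arriving via the tinygrad.dtype import; the call site inside exec_alu is outside the diff context. Presumably the scalar tail still follows the usual pattern of running the Python emulation and then clamping the result, a guess using the names visible above (not verbatim source):

  # inside exec_alu, after the vector-dtype recursion shown above:
  return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))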
@@ -3164,7 +3164,7 @@ class Tensor:
     """
     assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
     assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
-    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
+    log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
     y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
     y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
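This sparse_categorical_crossentropy tweak pairs with the schedule-count drops in the tests above: with the default ignore_index=-1 the loss mask is now a constant all-true tensor instead of an elementwise comparison against -1, so the scheduler has one less kernel's worth of work (class labels are non-negative, so nothing is actually ignored in that case). A small illustration of the two mask forms (example values are mine):

from tinygrad import Tensor, dtypes

Y = Tensor([1, 2, 0])
print((Y != -1).numpy())                       # [ True  True  True], needs an elementwise compare
print(Y.ones_like(dtype=dtypes.bool).numpy())  # [ True  True  True], just a constant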