truncate consts early (#6741)

* truncate consts early

* ptx still fails

* Update dtype.py
George Hotz
2024-09-25 16:49:51 +08:00
committed by GitHub
parent e31552e2e0
commit cb22ef379a
6 changed files with 22 additions and 21 deletions
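
For context on what "truncate consts early" buys: Python ints and floats carry more range and precision than the dtype a constant will eventually have, so folding a constant at full precision can disagree with what the generated kernel computes. A standalone illustration (plain Python, not tinygrad code; the helper name is made up for the example):

import struct, ctypes

def to_fp16(x: float) -> float:
  # round-trip through a native half-precision value ("@e" format)
  return struct.unpack("@e", struct.pack("@e", x))[0]

print(to_fp16(2049.0))            # 2048.0: 2049 is not representable in fp16
print(ctypes.c_uint8(300).value)  # 44: 300 wraps modulo 256 as a uint8 const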


@@ -3,10 +3,9 @@ import numpy as np
 import torch
 from typing import Any, List
 from tinygrad.helpers import getenv, DEBUG, CI
-from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
+from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype, truncate_fp16
 from tinygrad import Device, Tensor, dtypes
 from tinygrad.tensor import _to_np_dtype
-from tinygrad.ops import truncate_fp16
 from hypothesis import given, settings, strategies as strat
 from test.helpers import is_dtype_supported, rand_for_dtype


@@ -358,7 +358,7 @@ class TestMultiTensor(unittest.TestCase):
 shard_output.backward()
 shard_grad = m.conv1.weight.grad.numpy()
 # sometimes there is zeros in these grads... why?
-np.testing.assert_allclose(grad, shard_grad, atol=3e-6, rtol=3e-6)
+np.testing.assert_allclose(grad, shard_grad, atol=1e-5, rtol=1e-5)

 def test_multi_tensor_jit_param(self):
 @TinyJit


@@ -1520,7 +1520,7 @@ class TestIndexing(unittest.TestCase):
 X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
 Y = Tensor([1, 2]).realize()
 loss = X.sparse_categorical_crossentropy(Y)
-self.check_schedule(loss, 6)
+self.check_schedule(loss, 4)
 np.testing.assert_allclose(loss.item(), 0.878309, atol=1e-5, rtol=1e-6)

 def test_mnist_val(self):
@@ -1531,7 +1531,7 @@
 yt = Tensor.randn(BS, 10)
 with Context(SPLIT_REDUCEOP=0):
 loss = yt.sparse_categorical_crossentropy(Y_train[samples])
-self.check_schedule(loss, 7)
+self.check_schedule(loss, 6)
 loss_fused = loss.numpy()
 loss_ref = torch.nn.CrossEntropyLoss()(torch.tensor(yt.numpy()), torch.tensor(Y_train.numpy())[torch.tensor(samples.numpy())])
 np.testing.assert_allclose(loss_fused, loss_ref.numpy(), atol=1e-6, rtol=1e-6)
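
The lower expected kernel counts follow from constants being truncated before scheduling, which presumably lets more of the indexing arithmetic fold away. A hedged way to observe the count outside the test harness, assuming a tinygrad checkout at this commit and that Tensor.schedule() returns the scheduled kernels (which is what check_schedule is assumed to count):

from tinygrad import Tensor

X = Tensor([[0, 2, 3], [1, 2, 3]]).realize()
Y = Tensor([1, 2]).realize()
loss = X.sparse_categorical_crossentropy(Y)
print(len(loss.schedule()))  # expected: 4 after this change (was 6)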


@@ -1,5 +1,6 @@
 from __future__ import annotations
-from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union, Callable
+import math, struct, ctypes
 from dataclasses import dataclass
 import functools
 from tinygrad.helpers import getenv
@@ -63,6 +64,7 @@ class dtypes:
 if isinstance(val, tuple):
 assert len(val) == dtype.count, f"mismatch {val} {dtype}"
 return tuple(dtypes.as_const(x, dtype) for x in val)
+# TODO: should truncate here
 return int(val) if dtypes.is_int(dtype) else float(val) if dtypes.is_float(dtype) else bool(val)
 @staticmethod
 @functools.lru_cache(None)
@@ -144,3 +146,15 @@ def sum_acc_dtype(dt:DType):
 if dtypes.is_unsigned(dt): return least_upper_dtype(dt, dtypes.uint)
 if dtypes.is_int(dt) or dt == dtypes.bool: return least_upper_dtype(dt, dtypes.int)
 return least_upper_dtype(dt, dtypes.float)
+
+def truncate_fp16(x):
+  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
+  except OverflowError: return math.copysign(math.inf, x)
+
+truncate: Dict[DType, Callable] = {dtypes.bool: bool,
+  # TODO: bfloat16
+  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
+    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}
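
The new helpers are small enough to sanity-check in isolation; the expected values in the comments follow from CPython's struct and ctypes semantics, not anything tinygrad-specific (truncate_fp16 is copied here so the snippet runs without tinygrad installed):

import math, struct, ctypes

def truncate_fp16(x):
  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
  except OverflowError: return math.copysign(math.inf, x)

print(truncate_fp16(1.0001))      # 1.0  (rounded to the nearest fp16)
print(truncate_fp16(1e5))         # inf  (past fp16's max finite value, 65504)
print(ctypes.c_uint8(300).value)  # 44   (wraps modulo 2**8)
print(ctypes.c_int8(200).value)   # -56  (wraps into the signed int8 range)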


@@ -1,10 +1,10 @@
 from __future__ import annotations
 from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
-import sys, time, functools, itertools, math, operator, ctypes, struct, hashlib
+import sys, time, functools, itertools, math, operator, hashlib
 from enum import auto, IntEnum, Enum
 from collections import defaultdict
 from dataclasses import dataclass, field
-from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType
+from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
 from tinygrad.helpers import _CURRENT_KERNEL, ContextVar, pretty_print, prod, getenv, all_same
 from tinygrad.shape.symbolic import Variable, sint
 if TYPE_CHECKING:
@@ -310,18 +310,6 @@ python_alu: Dict[Op, Callable] = {
 BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
 TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
-
-def truncate_fp16(x):
-  try: return struct.unpack("@e", struct.pack("@e", float(x)))[0]
-  except OverflowError: return math.copysign(math.inf, x)
-
-truncate: Dict[DType, Callable] = {dtypes.bool: bool,
-  # TODO: bfloat16
-  dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
-  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
-  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
-  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value, dtypes.int32: lambda x: ctypes.c_int32(x).value \
-    if isinstance(x,int) else x, dtypes.int64: lambda x: ctypes.c_int64(x).value}

 def exec_alu(op:Op, dtype:DType, operands):
 if dtype.count > 1:
 return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
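
The tail of exec_alu is cut off in this hunk, but importing truncate here suggests that constant folding now pushes each python_alu result through the per-dtype truncation so folded values match device arithmetic. A self-contained sketch of that pattern with made-up string keys (not the actual tinygrad code):

import ctypes
from typing import Callable, Dict

python_alu: Dict[str, Callable] = {"ADD": lambda x, y: x + y, "MUL": lambda x, y: x * y}
truncate: Dict[str, Callable] = {"uint8": lambda x: ctypes.c_uint8(x).value,
                                 "int32": lambda x: ctypes.c_int32(x).value}

def exec_alu(op: str, dtype: str, operands):
  # without the truncate step, 200 + 100 would fold to 300, which no real
  # uint8 kernel could ever produce
  return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))

print(exec_alu("ADD", "uint8", (200, 100)))      # 44
print(exec_alu("MUL", "int32", (2**20, 2**12)))  # 0 (2**32 wraps to 0 in int32)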


@@ -3164,7 +3164,7 @@ class Tensor:
 """
 assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]"
 assert reduction in ("mean", "sum", "none"), "reduction must be one of ['mean', 'sum', 'none']"
-log_probs, loss_mask = self.log_softmax(), (Y != ignore_index)
+log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) if ignore_index != -1 else Y.ones_like(dtype=dtypes.bool)
 y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
 y = ((y_counter == Y.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
 smoothing = label_smoothing * (log_probs.mean(-1) * loss_mask)
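
The loss_mask change builds the (Y != ignore_index) comparison only when a real ignore_index is passed; with the default of -1 the mask is a constant ones tensor, presumably so the common path schedules a trivial const instead of an elementwise compare. A numpy sketch of the two paths (hypothetical targets, not tinygrad code):

import numpy as np

Y = np.array([1, 2, 0])  # hypothetical integer targets
ignore_index = -1        # the default: nothing is ignored

if ignore_index != -1:
  loss_mask = (Y != ignore_index)          # elementwise compare, an extra op
else:
  loss_mask = np.ones_like(Y, dtype=bool)  # constant mask, nothing to compute

print(loss_mask)  # [ True  True  True]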