Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-07 03:00:26 -04:00.
least_upper_float is at least default_float (#9303)
* least_upper_float now returns at least default_float, en route to a div rounding mode. The dtype of true int division therefore changes from int32 to default_float, which matches torch's behavior. * fix bert accuracy
This commit is contained in:
@@ -72,7 +72,7 @@ class BertForPretraining:
|
||||
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
|
||||
|
||||
# TODO: is it okay that next_sentence_loss is half here?
|
||||
return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
|
||||
return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
|
||||
|
||||
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
|
||||
|
||||
@@ -633,16 +633,22 @@ class TestTypePromotion(unittest.TestCase):
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
|
||||
assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
|
||||
|
||||
@given(strat.sampled_from(dtype_floats))
def test_float_to_float(self, float_dtype):
  # least_upper_float must be the identity on dtypes that are already floats.
  assert least_upper_float(float_dtype) == float_dtype
|
||||
|
||||
class TestAutoCastType(unittest.TestCase):
|
||||
def setUp(self):
  # Snapshot the global default dtypes so each test may mutate them freely.
  self.old_default_int = dtypes.default_int
  self.old_default_float = dtypes.default_float
|
||||
def tearDown(self):
  # Restore the global defaults captured in setUp.
  dtypes.default_int = self.old_default_int
  dtypes.default_float = self.old_default_float
|
||||
|
||||
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
  # Whatever the global default, a float input is already its own least upper float.
  dtypes.default_float = default_float
  result = least_upper_float(input_dtype)
  self.assertEqual(result, input_dtype)
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
  # An int input must promote to whatever default_float is currently configured.
  dtypes.default_float = default_float
  result = least_upper_float(input_dtype)
  self.assertEqual(result, default_float)
|
||||
|
||||
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
|
||||
def test_int_to_float_unary_func(self, dtype):
|
||||
for func in [
|
||||
@@ -667,6 +673,11 @@ class TestAutoCastType(unittest.TestCase):
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
|
||||
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
|
||||
|
||||
@given(strat.sampled_from(dtype_floats))
def test_int_div_int(self, default_float):
  # True int/int division should yield the currently-configured default float dtype.
  dtypes.default_float = default_float
  result_dtype = Tensor([1]).div(Tensor([2])).dtype
  self.assertEqual(result_dtype, default_float)
|
||||
|
||||
def test_sum(self):
|
||||
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
|
||||
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32
|
||||
|
||||
@@ -169,7 +169,7 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
|
||||
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType:
  """Smallest common promotion target of *ds*; any ImageDType input wins outright."""
  images = [d for d in ds if isinstance(d, ImageDType)]
  if images: return images[0]
  # intersect every dtype's promotion ancestors, then take the minimal element
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
|
||||
def least_upper_float(dt:DType) -> DType:
  """Return dt unchanged if it is already a float dtype, else promote it to at
  least the current dtypes.default_float via least_upper_dtype."""
  # Fix: the extraction left a stale duplicate definition (the pre-change
  # float32 variant) immediately shadowed by this one; only the current
  # default_float-based definition is kept.
  return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
|
||||
|
||||
# Public dtype attribute name -> DType object, taken from the dtypes namespace;
# the mutable "default*" aliases and the "void" placeholder are excluded.
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
# Reverse lookup from a DType's .name back to its attribute name; "void" is
# restored explicitly since it was filtered out of DTYPES_DICT above.
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}
|
||||
|
||||
Reference in New Issue
Block a user