least_upper_float is at least default_float (#9303)

* least_upper_float is at least default_float

En route to the div rounding mode: the dtype of true int division changes from int32 to default_float, which matches torch's behavior too.

* fix bert acc
This commit is contained in:
chenyu
2025-02-28 10:41:56 -05:00
committed by GitHub
parent 3210b656b6
commit 3ae66e59a3
3 changed files with 17 additions and 6 deletions

View File

@@ -72,7 +72,7 @@ class BertForPretraining:
next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
# TODO: is it okay that next_sentence_loss is half here?
return masked_lm_correct.sum() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()
def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info

View File

@@ -633,16 +633,22 @@ class TestTypePromotion(unittest.TestCase):
assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16
assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16
@given(strat.sampled_from(dtype_floats))
def test_float_to_float(self, dt):
  # a float dtype is already its own least upper float, so it passes through unchanged
  self.assertEqual(least_upper_float(dt), dt)
class TestAutoCastType(unittest.TestCase):
def setUp(self):
  # snapshot the global default dtypes so each test may mutate them freely
  self.old_default_int = dtypes.default_int
  self.old_default_float = dtypes.default_float
def tearDown(self):
  # restore the global default dtypes saved before the test ran
  dtypes.default_int = self.old_default_int
  dtypes.default_float = self.old_default_float
@given(strat.sampled_from(dtype_floats), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_float(self, input_dtype, default_float):
  # whatever default_float is configured, a float input comes back unchanged
  dtypes.default_float = default_float
  assert least_upper_float(input_dtype) == input_dtype
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_least_upper_float_input_is_int(self, input_dtype, default_float):
  # an int input is promoted to whichever default_float is currently configured
  dtypes.default_float = default_float
  assert least_upper_float(input_dtype) == default_float
@given(strat.sampled_from([d for d in core_dtypes if dtypes.is_int(d) and is_dtype_supported(d)]))
def test_int_to_float_unary_func(self, dtype):
for func in [
@@ -667,6 +673,11 @@ class TestAutoCastType(unittest.TestCase):
assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
@given(strat.sampled_from(dtype_floats))
def test_int_div_int(self, default_float):
  # true division of int tensors yields the configured default float dtype
  dtypes.default_float = default_float
  assert Tensor([1]).div(Tensor([2])).dtype == default_float
def test_sum(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32

View File

@@ -169,7 +169,7 @@ def _get_recursive_parents(dtype:DType) -> set[DType]:
@functools.lru_cache(None)
def least_upper_dtype(*ds:DType) -> DType:
  """Return the least upper bound of the given dtypes; any ImageDType wins outright."""
  # an image dtype short-circuits promotion: the first one found is returned as-is
  images = [d for d in ds if isinstance(d, ImageDType)]
  if images: return images[0]
  # otherwise take the smallest dtype common to every input's promotion ancestry
  return min(set.intersection(*[_get_recursive_parents(d) for d in ds]))
# NOTE(review): pre-change version shown by this diff — promoted non-floats to float32 rather than default_float
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
def least_upper_float(dt:DType) -> DType:
  # floats pass through unchanged; anything else is promoted to at least default_float
  return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.default_float)
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if isinstance(v, DType) and not k.startswith(("default", "void"))}
INVERSE_DTYPES_DICT = {**{v.name:k for k,v in DTYPES_DICT.items()}, "void": "void"}