mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
update RANGEIFY kernel count and test_masked_select (#12435)
This commit is contained in:
@@ -94,7 +94,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
@TinyJit
|
||||
def test(t, v):
|
||||
with Context(JIT=0): return model(t, v).realize()
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 137 if CI else 396, all_jitted=True)
|
||||
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 160 if CI else 396, all_jitted=True)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT == "CPU", "slow")
|
||||
def test_train_mnist(self):
|
||||
@@ -112,7 +112,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93)
|
||||
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 102)
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"CPU", "CL"}, "slow")
|
||||
def test_forward_cifar(self):
|
||||
@@ -176,7 +176,7 @@ class TestRealWorld(unittest.TestCase):
|
||||
for v in data.values(): v.to_(Device.DEFAULT)
|
||||
|
||||
helper_test("train_bert", lambda: (data["input_ids"], data["segment_ids"], data["input_mask"], data["masked_lm_positions"], \
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 347)
|
||||
data["masked_lm_ids"], data["masked_lm_weights"], data["next_sentence_labels"]), train, 0.25, 357)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -3164,8 +3164,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x<0.1).detach(), -math.inf))
|
||||
|
||||
@unittest.skipIf(RANGEIFY and ((getenv("MOCKGPU") and Device.DEFAULT == "AMD") or Device.DEFAULT == "PYTHON"),
|
||||
"very slow on MOCKGPU because reduce does not fold")
|
||||
@unittest.skipIf(RANGEIFY and (getenv("MOCKGPU") or Device.DEFAULT == "PYTHON"), "very slow on MOCKGPU because reduce does not fold")
|
||||
@unittest.skipIf(RANGEIFY and Device.DEFAULT == "WEBGPU", "webgpu runtime issue")
|
||||
def test_masked_select(self):
|
||||
helper_test_op([(32, 10)], lambda x: x.masked_select(x>0.5), lambda x: x.masked_select(x>0.5), forward_only=True)
|
||||
|
||||
Reference in New Issue
Block a user