cleanup mlops (#1521)

* cleanup mlops

* that line belongs there
This commit is contained in:
George Hotz
2023-08-10 19:53:28 -07:00
committed by GitHub
parent 47f18f4d60
commit 38fe84d92b
4 changed files with 53 additions and 42 deletions

View File

@@ -1143,9 +1143,9 @@ class TestOps(unittest.TestCase):
self.helper_test_exception([], lambda: tor[tb,:,:,:,:].sum().backward(), lambda: ten.gather(ta, dim=0).sum().backward(), expected=(IndexError, RuntimeError)) # torch raises IndexError, Tensor raises RuntimeError
def test_scaled_product_attention(self):
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64), (32,8,128,128)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
helper_test_op([(32,8,128,64), (32,8,128,64), (32,8,128,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
if __name__ == '__main__':
np.random.seed(1337)