mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
variable bs keccak (#10731)
This commit is contained in:
@@ -3,6 +3,8 @@ import hashlib, random, unittest
|
||||
from tinygrad import Tensor, Device, getenv, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
|
||||
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashes in NV CI")
|
||||
@@ -72,5 +74,52 @@ class TestKeccak(unittest.TestCase):
|
||||
data = b"\x00" * 1000
|
||||
self.assertEqual(bytes(Tensor(data).keccak("shake_128").tolist()), hashlib.shake_128(data).digest(16))
|
||||
|
||||
def test_variable_bs(self):
|
||||
data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(1)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(1, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(2)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(2, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(3)
|
||||
data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(3, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df"))
|
||||
|
||||
def test_variable_bs_jit(self):
|
||||
def f(data):
|
||||
return data.keccak()
|
||||
jit_f = TinyJit(f)
|
||||
|
||||
data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
|
||||
# initialize jit
|
||||
for _ in range(3):
|
||||
bs = UOp.variable("bs", 1, 4096).bind(4096)
|
||||
_ = jit_f(data.shrink_to(bs, data.shape[-1]))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(1)
|
||||
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(1, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(2)
|
||||
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(2, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
|
||||
bs = UOp.variable("bs", 1, 4096).bind(3)
|
||||
data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1)
|
||||
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(3, 32)
|
||||
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
|
||||
self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user