variable bs keccak (#10731)

This commit is contained in:
wozeparrot
2025-10-23 14:10:21 -07:00
committed by GitHub
parent 154b4f9f40
commit 9dac505565
2 changed files with 50 additions and 1 deletions

View File

@@ -3,6 +3,8 @@ import hashlib, random, unittest
from tinygrad import Tensor, Device, getenv, dtypes
from tinygrad.device import is_dtype_supported
from tinygrad.helpers import CI
from tinygrad.uop.ops import UOp
from tinygrad.engine.jit import TinyJit
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashes in NV CI")
@@ -72,5 +74,52 @@ class TestKeccak(unittest.TestCase):
data = b"\x00" * 1000
self.assertEqual(bytes(Tensor(data).keccak("shake_128").tolist()), hashlib.shake_128(data).digest(16))
def test_variable_bs(self):
data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1)
bs = UOp.variable("bs", 1, 4096).bind(1)
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(1, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
bs = UOp.variable("bs", 1, 4096).bind(2)
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(2, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
bs = UOp.variable("bs", 1, 4096).bind(3)
data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1)
out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(3, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df"))
def test_variable_bs_jit(self):
def f(data):
return data.keccak()
jit_f = TinyJit(f)
data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1)
# initialize jit
for _ in range(3):
bs = UOp.variable("bs", 1, 4096).bind(4096)
_ = jit_f(data.shrink_to(bs, data.shape[-1]))
bs = UOp.variable("bs", 1, 4096).bind(1)
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(1, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
bs = UOp.variable("bs", 1, 4096).bind(2)
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(2, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
bs = UOp.variable("bs", 1, 4096).bind(3)
data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1)
out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(3, 32)
self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532"))
self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df"))
if __name__ == "__main__":
unittest.main()