From 9dac505565634cdc079f1c1742862970ff637e91 Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Thu, 23 Oct 2025 14:10:21 -0700 Subject: [PATCH] variable bs keccak (#10731) --- test/unit/test_hashing.py | 49 +++++++++++++++++++++++++++++++++++++++ tinygrad/tensor.py | 2 +- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/test/unit/test_hashing.py b/test/unit/test_hashing.py index 1fd5b6f8d3..d35fac435c 100644 --- a/test/unit/test_hashing.py +++ b/test/unit/test_hashing.py @@ -3,6 +3,8 @@ import hashlib, random, unittest from tinygrad import Tensor, Device, getenv, dtypes from tinygrad.device import is_dtype_supported from tinygrad.helpers import CI +from tinygrad.uop.ops import UOp +from tinygrad.engine.jit import TinyJit @unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64") @unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashes in NV CI") @@ -72,5 +74,52 @@ class TestKeccak(unittest.TestCase): data = b"\x00" * 1000 self.assertEqual(bytes(Tensor(data).keccak("shake_128").tolist()), hashlib.shake_128(data).digest(16)) + def test_variable_bs(self): + data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1) + + bs = UOp.variable("bs", 1, 4096).bind(1) + out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(1, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + + bs = UOp.variable("bs", 1, 4096).bind(2) + out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(2, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + + bs = UOp.variable("bs", 1, 4096).bind(3) + data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1) + out = data.shrink_to(bs, data.shape[-1]).keccak().shrink_to(3, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df")) + + def test_variable_bs_jit(self): + def f(data): + return data.keccak() + jit_f = TinyJit(f) + + data = Tensor([b"abc", b"abc", b"abc"], dtype=dtypes.uint8).repeat(2048, 1) + + # initialize jit + for _ in range(3): + bs = UOp.variable("bs", 1, 4096).bind(4096) + _ = jit_f(data.shrink_to(bs, data.shape[-1])) + + bs = UOp.variable("bs", 1, 4096).bind(1) + out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(1, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + + bs = UOp.variable("bs", 1, 4096).bind(2) + out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(2, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + + bs = UOp.variable("bs", 1, 4096).bind(3) + data = Tensor([b"abc", b"abc", b"def"], dtype=dtypes.uint8).repeat(2048, 1) + out = jit_f(data.shrink_to(bs, data.shape[-1])).shrink_to(3, 32) + self.assertEqual(bytes(out[0].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[1].tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + self.assertEqual(bytes(out[2].tolist()), bytearray.fromhex("8e0d8f672252acb0 ffc5093db8653b18 1513bf9a2097e737 b4f73533dcaf46df")) + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 95ed3397e7..c26360852d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2090,7 +2090,7 @@ class Tensor(MathTrait): state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64) for k in range(int(data.shape[1])): - state = state.bitwise_xor(data[:,k].reshape(bs, 25)) + state = state ^ data.shrink((None, (k, k+1), None)).squeeze(1) for i in range(24): # f1600 # θ step p = state.reshape(bs, 5, 5).transpose(2, 1)