From eb7305e6a4d981ef2a734b1a5f53d6efa66e0504 Mon Sep 17 00:00:00 2001 From: leopf <43857362+leopf@users.noreply.github.com> Date: Sat, 7 Jun 2025 00:24:05 +0200 Subject: [PATCH] Tensor.keccak("sha3_256") (#7186) Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: George Hotz Co-authored-by: wozeparrot --- test/external/external_benchmark_keccak.py | 20 +++++++++ test/external/external_test_keccak.py | 31 ++++++++++++++ test/test_keccak.py | 42 ++++++++++++++++++ tinygrad/tensor.py | 50 ++++++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 test/external/external_benchmark_keccak.py create mode 100644 test/external/external_test_keccak.py create mode 100644 test/test_keccak.py diff --git a/test/external/external_benchmark_keccak.py b/test/external/external_benchmark_keccak.py new file mode 100644 index 0000000000..1365ca9b74 --- /dev/null +++ b/test/external/external_benchmark_keccak.py @@ -0,0 +1,20 @@ +from tinygrad import Tensor, dtypes +from tinygrad.engine.jit import TinyJit +from tinygrad.helpers import Timing, getenv + +if __name__ == "__main__": + BS = getenv("BS", 2**14) + BLOCKSIZE = getenv("BLOCKSIZE", 4096) + HASHFN = getenv("HASHFN", "shake_128") + NRUNS = getenv("NRUNS", 5) + + @TinyJit + def hasher(data: Tensor): return data.keccak(HASHFN) + + t = Tensor.randn(BS, BLOCKSIZE, dtype=dtypes.uint8).realize() + ds_mib = t.nbytes() / 1024**2 + + print(f"--- benchmarking (hash: {HASHFN}, data size: {ds_mib} MiB, block size: {BLOCKSIZE} B, batch size: {BS})") + for i in range(NRUNS): + with Timing(f"run: {i+1}, elapsed time: ", (lambda et: f", throughput: {ds_mib / (et*1e-9):.2f} MiB/s")): + hasher(t).realize() diff --git a/test/external/external_test_keccak.py b/test/external/external_test_keccak.py new file mode 100644 index 0000000000..bccffe8be4 --- /dev/null +++ b/test/external/external_test_keccak.py @@ -0,0 +1,31 @@ +import unittest, zipfile, re +from tinygrad import Tensor +from tinygrad.helpers import fetch, tqdm + +SHA3_URL = "https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/sha3/sha-3bytetestvectors.zip" +SHAKE_URL = "https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/sha3/shakebytetestvectors.zip" + +class TestExternalKeccak(unittest.TestCase): + def test_sha3_224(self): self.check_nist_vectors(SHA3_URL, ["SHA3_224LongMsg.rsp", "SHA3_224ShortMsg.rsp"], "sha3_224") + def test_sha3_256(self): self.check_nist_vectors(SHA3_URL, ["SHA3_256LongMsg.rsp", "SHA3_256ShortMsg.rsp"], "sha3_256") + def test_shake_128(self): self.check_nist_vectors(SHAKE_URL, ["SHAKE128LongMsg.rsp", "SHAKE128ShortMsg.rsp"], "shake_128") + + def check_nist_vectors(self, url: str, filenames: list[str], preset: str): + pattern = r"Len\s*=\s*(?P\d+)\s+Msg\s*=\s*(?P[0-9a-fA-F\s]+)\s+(MD|Output)\s*=\s*(?P[0-9a-fA-F]+)" + vecs_zip = fetch(url) + + for filename in filenames: + vecs = zipfile.ZipFile(vecs_zip).open(filename).read().decode() + + vectors = [ (l, bytes.fromhex(match["Msg"].lower()), bytes.fromhex(match["Output"].lower())) + for match in re.finditer(pattern, vecs) if (l:=int(match["Len"])) < 8192 ] + + self.assertTrue(len(vectors) > 0) + + print("file", filename) + for data_len, data, output in tqdm(vectors): + tinyout = bytes(Tensor(data[:data_len//8]).keccak(preset).data()) + self.assertEqual(tinyout, output) + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_keccak.py b/test/test_keccak.py new file mode 100644 index 0000000000..c32a74f8d6 --- /dev/null +++ b/test/test_keccak.py @@ -0,0 +1,42 @@ +from typing_extensions import Callable +import hashlib, random, unittest +from tinygrad import Tensor, Device, getenv, dtypes +from tinygrad.device import is_dtype_supported + +@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64") +@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashes in NV CI") +class TestKeccak(unittest.TestCase): + def setUp(self) -> None: random.seed(1337) + + def test_shape_keeping(self): + s = (1, 2, 3, 4) + for i in range(len(s)): + out_shape = Tensor.randint(*s[i:], high=255, dtype=dtypes.uint8).keccak().shape + self.assertTupleEqual(s[i:-1], out_shape[:-1]) + + def test_sha3_224(self): self._test_preset("sha3_224", [143, 144]) + def test_sha3_256(self): self._test_preset("sha3_256", [135, 136]) + def test_shake_128(self): self._test_preset("shake_128", [167, 168], lambda d: hashlib.shake_128(d).digest(16)) + + def _test_preset(self, name: str, special_sizes: list[int], hasher: Callable[[bytes], bytes] | None = None): + def default_hasher(d: bytes) -> bytes: return getattr(hashlib, name)(d).digest() + if hasher is None: hasher = default_hasher + + for n in (special_sizes + [special_sizes[0] - 1]): + a, b = random.randbytes(n), random.randbytes(n) + + ha_ref, hb_ref = hasher(a), hasher(b) + tres = Tensor.stack(*(Tensor(d) for d in (a, b))).keccak(name) + ha, hb = tres[0].data(), tres[1].data() + + self.assertEqual(ha_ref, ha) + self.assertEqual(ha_ref, Tensor(a).keccak(name).data()) + self.assertEqual(hb_ref, hb) + + def test_abc(self): + # https://www.di-mgt.com.au/sha_testvectors.html + out = Tensor(b"abc").keccak() + self.assertEqual(bytes(out.tolist()), bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")) + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9bbdcc7df0..1bf16a8f4f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1931,6 +1931,56 @@ class Tensor(MathTrait): """ return self.std(axis, keepdim, correction), self.mean(axis, keepdim) + def keccak(self, cfg:str|tuple[int, int] = "sha3_256"): + """ + Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default. + + ```python exec="false" source="above" session="tensor" result="python" + t = Tensor(b"Hello World!").keccak() + print(t.data().hex()) + ``` + """ + + # https://keccak.team/keccak_specs_summary.html + + def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64): return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)) + rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2] + rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets]) + + # calculated from π step + reorder_indexes = [0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21] + rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081, + 0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003, + 0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)] + + rate, dsbyte = { "sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31) }[cfg] if isinstance(cfg, str) else cfg + data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), -1), rate - (self.shape[-1] * self.dtype.itemsize % rate) + # pad batches then pad blocks + data = data.pad((None, (0, data_pad))).reshape(data.shape[0], -1, rate).pad((None, None, (0, 200 - rate))).flatten(1) + + # create pad mask + lbe = data.shape[1] + rate - data_pad - 200 + if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (data.shape[-1] - lbe - 1, 0)] + else: mb = [(lbe, 0), (1, dsbyte), (data.shape[-1] + rate - lbe - 202, 0), (1, 0x80), (200 - rate, 0)] + pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)) + + data = (data ^ pad_mask).reshape(data.shape[0], -1, 200).bitcast(dtypes.uint64) + + state = Tensor.zeros((data.shape[0], 25), device=self.device, dtype=dtypes.uint64) + for k in range(int(data.shape[1])): + state = state.bitwise_xor(data[:,k].reshape(-1, 25)) + for i in range(24): # f1600 + # θ step + p = state.reshape((-1, 5, 5)).transpose(2, 1) + t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce + state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand((-1, -1, 5)).transpose(2, 1).flatten(1)) + # ρ and π steps + state = state[:,reorder_indexes] + state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape((-1, 5, 5)) + # χ and ι step + state = state.bitwise_xor((state.roll(shifts=-1, dims=2) ^ -1) & state.roll(shifts=-2, dims=2)).flatten(1) ^ rnd_const_masks[i] + return state.bitcast(dtypes.uint8)[:,:(200 - rate) // 2].reshape(*self.shape[:-1], -1) + def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]: m = self - self.max(axis=axis, keepdim=True).detach() if dtype is not None: m = m.cast(dtype)