mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
Tensor.keccak("sha3_256") (#7186)
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> Co-authored-by: George Hotz <geohot@gmail.com> Co-authored-by: wozeparrot <wozeparrot@gmail.com>
This commit is contained in:
20
test/external/external_benchmark_keccak.py
vendored
Normal file
20
test/external/external_benchmark_keccak.py
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.helpers import Timing, getenv
|
||||
|
||||
if __name__ == "__main__":
  # Benchmark knobs; each one is read through getenv with the default shown here.
  BS, BLOCKSIZE = getenv("BS", 2**14), getenv("BLOCKSIZE", 4096)
  HASHFN, NRUNS = getenv("HASHFN", "shake_128"), getenv("NRUNS", 5)

  @TinyJit
  def hasher(data: Tensor):
    # Hash every row of `data` with the selected preset.
    return data.keccak(HASHFN)

  # One (BS, BLOCKSIZE) uint8 buffer, realized up front and reused by every run.
  payload = Tensor.randn(BS, BLOCKSIZE, dtype=dtypes.uint8).realize()
  ds_mib = payload.nbytes() / 1024**2

  def throughput(et):
    # `et` is the elapsed time handed over by Timing; the 1e-9 factor presumably
    # converts nanoseconds to seconds -- TODO confirm against tinygrad.helpers.Timing.
    return f", throughput: {ds_mib / (et*1e-9):.2f} MiB/s"

  print(f"--- benchmarking (hash: {HASHFN}, data size: {ds_mib} MiB, block size: {BLOCKSIZE} B, batch size: {BS})")
  for i in range(NRUNS):
    with Timing(f"run: {i+1}, elapsed time: ", throughput):
      hasher(payload).realize()
|
||||
31
test/external/external_test_keccak.py
vendored
Normal file
31
test/external/external_test_keccak.py
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
import unittest, zipfile, re
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import fetch, tqdm
|
||||
|
||||
SHA3_URL = "https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/sha3/sha-3bytetestvectors.zip"
|
||||
SHAKE_URL = "https://csrc.nist.gov/CSRC/media/Projects/Cryptographic-Algorithm-Validation-Program/documents/sha3/shakebytetestvectors.zip"
|
||||
|
||||
class TestExternalKeccak(unittest.TestCase):
  """Checks `Tensor.keccak` against the NIST byte-oriented SHA-3 and SHAKE response files."""

  def test_sha3_224(self): self.check_nist_vectors(SHA3_URL, ["SHA3_224LongMsg.rsp", "SHA3_224ShortMsg.rsp"], "sha3_224")
  def test_sha3_256(self): self.check_nist_vectors(SHA3_URL, ["SHA3_256LongMsg.rsp", "SHA3_256ShortMsg.rsp"], "sha3_256")
  def test_shake_128(self): self.check_nist_vectors(SHAKE_URL, ["SHAKE128LongMsg.rsp", "SHAKE128ShortMsg.rsp"], "shake_128")

  def check_nist_vectors(self, url: str, filenames: list[str], preset: str):
    """
    Downloads the zip archive at `url` and, for every response file in `filenames`, hashes each
    test message with `Tensor.keccak(preset)` and compares it with the expected digest.

    Only vectors whose `Len` field is below 8192 are checked. Each file must yield at least one vector.
    """
    # One record is `Len = <n>`, `Msg = <hex>`, then `MD = <hex>` (or `Output = <hex>`).
    pattern = re.compile(r"Len\s*=\s*(?P<Len>\d+)\s+Msg\s*=\s*(?P<Msg>[0-9a-fA-F\s]+)\s+(MD|Output)\s*=\s*(?P<Output>[0-9a-fA-F]+)")
    vecs_zip = fetch(url)

    # Open the archive once and close it (and every member handle) deterministically,
    # even when an assertion fails part-way through.
    with zipfile.ZipFile(vecs_zip) as archive:
      for filename in filenames:
        with archive.open(filename) as rsp: vecs = rsp.read().decode()

        vectors = [(l, bytes.fromhex(match["Msg"].lower()), bytes.fromhex(match["Output"].lower()))
                   for match in pattern.finditer(vecs) if (l:=int(match["Len"])) < 8192]

        # Guards against a silently passing test when the file format no longer matches the pattern.
        self.assertGreater(len(vectors), 0)

        print("file", filename)
        for data_len, data, output in tqdm(vectors):
          # `Len` is divided by 8 to get the number of message bytes to hash; the slice drops any extra
          # bytes in `Msg` (presumably the placeholder byte of the zero-length vector -- verify against the files).
          tinyout = bytes(Tensor(data[:data_len//8]).keccak(preset).data())
          self.assertEqual(tinyout, output)
|
||||
|
||||
# Run the NIST vector checks when this file is executed directly.
if __name__ == '__main__':
  unittest.main()
|
||||
42
test/test_keccak.py
Normal file
42
test/test_keccak.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing_extensions import Callable
|
||||
import hashlib, random, unittest
|
||||
from tinygrad import Tensor, Device, getenv, dtypes
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uint8) and is_dtype_supported(dtypes.uint64), "Device must support uint8 and uint64")
@unittest.skipIf(getenv("MOCKGPU") and Device.DEFAULT == "NV", "crashes in NV CI")
class TestKeccak(unittest.TestCase):
  """Compares `Tensor.keccak` with `hashlib` digests and with a published "abc" vector."""

  def setUp(self) -> None:
    # Fixed seed so the random messages are identical on every run.
    random.seed(1337)

  def test_shape_keeping(self):
    """All dimensions except the last (hashed) one must survive `keccak()`."""
    dims = (1, 2, 3, 4)
    for start in range(len(dims)):
      sub_shape = dims[start:]
      out_shape = Tensor.randint(*sub_shape, high=255, dtype=dtypes.uint8).keccak().shape
      self.assertTupleEqual(sub_shape[:-1], out_shape[:-1])

  def test_sha3_224(self):
    self._test_preset("sha3_224", [143, 144])

  def test_sha3_256(self):
    self._test_preset("sha3_256", [135, 136])

  def test_shake_128(self):
    self._test_preset("shake_128", [167, 168], lambda d: hashlib.shake_128(d).digest(16))

  def _test_preset(self, name: str, special_sizes: list[int], hasher: Callable[[bytes], bytes] | None = None):
    """
    Hashes random messages of each length in `special_sizes`, plus `special_sizes[0] - 1`, with preset `name`,
    both batched and unbatched, and compares the results with `hasher`.

    When `hasher` is None the reference is `getattr(hashlib, name)(d).digest()`.
    """
    reference = hasher if hasher is not None else (lambda d: getattr(hashlib, name)(d).digest())

    for n in [*special_sizes, special_sizes[0] - 1]:
      # Draw order matters: it fixes which bytes come out of the seeded RNG.
      first = random.randbytes(n)
      second = random.randbytes(n)

      want_first, want_second = reference(first), reference(second)
      batched = Tensor.stack(Tensor(first), Tensor(second)).keccak(name)
      got_first, got_second = batched[0].data(), batched[1].data()

      self.assertEqual(want_first, got_first)
      # The unbatched path must agree with the batched one.
      self.assertEqual(want_first, Tensor(first).keccak(name).data())
      self.assertEqual(want_second, got_second)

  def test_abc(self):
    # https://www.di-mgt.com.au/sha_testvectors.html
    expected = bytearray.fromhex("3a985da74fe225b2 045c172d6bd390bd 855f086e3e9d525b 46bfe24511431532")
    digest = Tensor(b"abc").keccak()
    self.assertEqual(bytes(digest.tolist()), expected)
|
||||
|
||||
# Run the keccak unit tests when this file is executed directly.
if __name__ == "__main__":
  unittest.main()
|
||||
@@ -1931,6 +1931,56 @@ class Tensor(MathTrait):
|
||||
"""
|
||||
return self.std(axis, keepdim, correction), self.mean(axis, keepdim)
|
||||
|
||||
def keccak(self, cfg:str|tuple[int, int] = "sha3_256") -> Tensor:
  """
  Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.

  `cfg` is either a preset name (`"sha3_224"`, `"sha3_256"` or `"shake_128"`) or a `(rate, dsbyte)` tuple, where
  `rate` is the number of bytes absorbed per 200-byte block and `dsbyte` is the first byte of the padding.
  The leading dimensions are kept; the last dimension of the result holds `(200 - rate) // 2` bytes.
  An unknown preset name raises `KeyError` from the preset lookup.

  ```python exec="false" source="above" session="tensor" result="python"
  t = Tensor(b"Hello World!").keccak()
  print(t.data().hex())
  ```
  """

  # https://keccak.team/keccak_specs_summary.html

  # helper: stacks python scalars into a 1-D tensor on this tensor's device
  def ctensor(l: Sequence[ConstType], dtype: DType = dtypes.uint64): return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l))
  rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
  # per-lane factors used in the ρ step: `state * 2**v` supplies the left-shifted bits and `state // 2**(64 - v)` the
  # right-shifted ones; the leading (0, 1) pair leaves lane 0 unchanged
  # NOTE(review): relies on uint64 multiplication wrapping modulo 2**64 -- confirm for every backend
  rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets])

  # calculated from π step
  reorder_indexes = [0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21]
  # one mask per round: the round constant sits in lane 0 and the other 24 lanes are zero, so the xor only touches lane 0
  rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081,
    0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
    0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]

  # preset name -> (rate in bytes, domain-separation byte); a tuple cfg is used as given
  rate, dsbyte = { "sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31) }[cfg] if isinstance(cfg, str) else cfg
  # view the input as rows of bytes (one row per hash); data_pad is the number of padding bytes, always in 1..rate,
  # so an input that is already a multiple of rate gets a whole extra block
  data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), -1), rate - (self.shape[-1] * self.dtype.itemsize % rate)
  # pad batches then pad blocks
  # (each row is zero-padded to a multiple of rate, split into rate-sized blocks, and each block zero-padded to 200 bytes)
  data = data.pad((None, (0, data_pad))).reshape(data.shape[0], -1, rate).pad((None, None, (0, 200 - rate))).flatten(1)

  # create pad mask
  # lbe: offset inside a flattened row of the first padding byte (start of the last block + message bytes in it)
  lbe = data.shape[1] + rate - data_pad - 200
  # a single padding byte has to carry both the domain-separation bits and the final 0x80 bit
  if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (data.shape[-1] - lbe - 1, 0)]
  # otherwise: dsbyte, then data_pad - 2 zero bytes, then 0x80 as the last byte of the rate part, then the zero tail
  else: mb = [(lbe, 0), (1, dsbyte), (data.shape[-1] + rate - lbe - 202, 0), (1, 0x80), (200 - rate, 0)]
  # build the mask from (length, value) runs, skipping empty runs
  pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0))

  # the padded region is zero, so the xor writes the padding bytes; each 200-byte block then becomes 25 uint64 lanes
  data = (data ^ pad_mask).reshape(data.shape[0], -1, 200).bitcast(dtypes.uint64)

  # one 25-lane state per row, starting from zero
  state = Tensor.zeros((data.shape[0], 25), device=self.device, dtype=dtypes.uint64)
  for k in range(int(data.shape[1])):
    # absorb block k, then run the 24-round permutation
    state = state.bitwise_xor(data[:,k].reshape(-1, 25))
    for i in range(24): # f1600
      # θ step
      p = state.reshape((-1, 5, 5)).transpose(2, 1)
      t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
      # `(t1 << 1) ^ (t1 >> 63)` rotates each 64-bit value left by one bit
      state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand((-1, -1, 5)).transpose(2, 1).flatten(1))
      # ρ and π steps
      state = state[:,reorder_indexes]
      state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape((-1, 5, 5))
      # χ and ι step
      # (`^ -1` is a bitwise NOT; the round-constant mask is xored in after flattening back to 25 lanes)
      state = state.bitwise_xor((state.roll(shifts=-1, dims=2) ^ -1) & state.roll(shifts=-2, dims=2)).flatten(1) ^ rnd_const_masks[i]
  # keep the first (200 - rate) // 2 bytes of every state and restore the leading dimensions of the input
  return state.bitcast(dtypes.uint8)[:,:(200 - rate) // 2].reshape(*self.shape[:-1], -1)
|
||||
|
||||
def _softmax(self, axis, dtype:DTypeLike|None=None) -> tuple[Tensor, Tensor, Tensor]:
|
||||
m = self - self.max(axis=axis, keepdim=True).detach()
|
||||
if dtype is not None: m = m.cast(dtype)
|
||||
|
||||
Reference in New Issue
Block a user