# Forked and modified from https://github.com/ajalt/python-sha1, whose license was # # The MIT License (MIT) # # Copyright (c) 2013-2015 AJ Alt # # Permission is hereby granted, free of charge, to any person obtaining a copy of # this software and associated documentation files (the "Software"), to deal in # the Software without restriction, including without limitation the rights to # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of # the Software, and to permit persons to whom the Software is furnished to do so, # subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
import io
import random
import string
import struct
import time
from hashlib import sha1 as hashlib_sha1

import numpy as np
from concrete import fhe


def _left_rotate(n, b):
    """Left rotate a 32-bit integer n by b bits."""
    return ((n << b) | (n >> (32 - b))) & 0xFFFFFFFF


def split(b):
    """Split an integer into its 32 low-order bits.

    Arguments:
        b: integer, expected to fit in 32 bits.

    Returns:
        A (32,) int8 NumPy array whose element i is bit i of b
        (least-significant bit first).
    """
    return np.array([(b >> i) & 1 for i in range(32)], dtype=np.int8)


def unsplit(bits):
    """Rebuild a uint32 from 32 bits given least-significant bit first.

    This is the inverse of `split`. Each bit is converted to a Python int before
    shifting, so the accumulation cannot wrap around even when `bits` has a
    narrow NumPy dtype.
    """
    return sum(int(bits[i]) << i for i in range(32))


def get_random_string(length):
    """Return a random string of `length` ASCII letters (not cryptographically secure)."""
    # ruff: noqa:S311
    return "".join(random.choice(string.ascii_letters) for _ in range(length))


def add_chunked_number(x, y):
    """Add two 32-bit words given as bit arrays (LSB first), modulo 2**32.

    Meant to be traced inside the FHE functions below: it performs a
    ripple-carry addition over the 32 bit positions and drops the final carry.
    """
    result = fhe.zeros((32,))
    carry = 0
    # Per-position sums without carries; for 0/1 inputs each entry is in {0, 1, 2}
    addition = x + y
    for i in range(32):
        addition_and_carry = addition[i] + carry
        carry = addition_and_carry >> 1
        result[i] = addition_and_carry - (carry * 2)
    return result


# FHE functions
# Every value handled here is a 32-bit word represented as a (32,) array of
# bits, least-significant bit first (see `split` / `unsplit`).
@fhe.module()
class MyModule:
    @staticmethod
    @fhe.function({"x": "encrypted", "y": "encrypted", "z": "encrypted"})
    def xor3(x, y, z):
        """Bitwise x ^ y ^ z (SHA-1 round function for rounds 20-39 and 60-79)."""
        return x ^ y ^ z

    @staticmethod
    @fhe.function({"x": "encrypted", "y": "encrypted", "z": "encrypted"})
    def iftern(x, y, z):
        """Bitwise choice: y where x is 1, z where x is 0 (SHA-1 rounds 0-19)."""
        return z ^ (x & (y ^ z))

    @staticmethod
    @fhe.function({"x": "encrypted", "y": "encrypted", "z": "encrypted"})
    def maj(x, y, z):
        """Bitwise majority of x, y and z (SHA-1 rounds 40-59)."""
        return (x & y) | (z & (x | y))

    @staticmethod
    @fhe.function({"x": "encrypted"})
    def rotate30(x):
        """Left-rotate a word by 30 bits: bit i of x moves to bit (i + 30) % 32."""
        ans = fhe.zeros((32,))
        ans[30:32] = x[0:2]
        ans[0:30] = x[2:32]
        return ans

    @staticmethod
    @fhe.function({"x": "encrypted"})
    def rotate5(x):
        """Left-rotate a word by 5 bits: bit i of x moves to bit (i + 5) % 32."""
        ans = fhe.zeros((32,))
        ans[5:32] = x[0:27]
        ans[0:5] = x[27:32]
        return ans

    @staticmethod
    @fhe.function({"x": "encrypted", "y": "encrypted"})
    def add2(x, y):
        """Return (x + y) mod 2**32 as a bit array."""
        # The sum bits are already 0/1; taking bit 0 keeps their values.
        # NOTE(review): presumably this narrows the output bit-width - confirm.
        return fhe.bits(add_chunked_number(x, y))[0]

    @staticmethod
    @fhe.function(
        {"x": "encrypted", "y": "encrypted", "u": "encrypted", "v": "encrypted", "w": "encrypted"}
    )
    def add5(x, y, u, v, w):
        """Return (x + y + u + v + w) mod 2**32 as a bit array."""
        result = add_chunked_number(x, y)
        result = add_chunked_number(result, u)
        result = add_chunked_number(result, v)
        result = add_chunked_number(result, w)
        return fhe.bits(result)[0]


# Compilation of the FHE functions
size_of_inputsets = 1000


def _random_bit_inputset(arity):
    """Return `size_of_inputsets` tuples, each of `arity` random (32,) 0/1 arrays."""
    return [
        tuple(np.random.randint(2, size=(32,)) for _ in range(arity))
        for _ in range(size_of_inputsets)
    ]


inputset1 = _random_bit_inputset(1)
inputset2 = _random_bit_inputset(2)
inputset3 = _random_bit_inputset(3)
inputset5 = _random_bit_inputset(5)

# FIXME: remove the mypy and ruff exceptions once
# https://github.com/zama-ai/concrete-internal/issues/721 is fixed
# pylint: disable-next=no-member
my_module = MyModule.compile(  # type: ignore
    {
        "xor3": inputset3,
        "iftern": inputset3,
        "maj": inputset3,
        "rotate30": inputset1,
        "rotate5": inputset1,
        "add2": inputset2,
        "add5": inputset5,
    },
    show_mlir=False,
    bitwise_strategy_preference=fhe.BitwiseStrategy.ONE_TLU_PROMOTED,
    multivariate_strategy_preference=fhe.MultivariateStrategy.PROMOTED,
    p_error=10**-8,
)


# Split and encrypt on the client side
def message_schedule_and_split_and_encrypt(chunk):
    """Build the 80-word SHA-1 message schedule of a chunk, then split and encrypt it.

    Arguments:
        chunk: 64 bytes of (padded) message.

    Returns:
        A list of 80 ciphertexts, one per schedule word w[0..79], each
        encrypting the word's bits (LSB first).
    """
    assert len(chunk) == 64

    # Break chunk into sixteen 4-byte big-endian words w[i]
    w = list(struct.unpack(b">16I", chunk))

    # Extend the sixteen 4-byte words into eighty 4-byte words
    for i in range(16, 80):
        w.append(_left_rotate(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1))

    # Then split and encrypt. Encryption goes through rotate5's client; the
    # ciphertexts are later fed to the other functions of the same module.
    return [my_module.rotate5.encrypt(split(word)) for word in w]


# Perform SHA computation server side, completely in FHE
def _process_encrypted_chunk_server_side(
    wsplit_enc, h0split_enc, h1split_enc, h2split_enc, h3split_enc, h4split_enc
):
    """Process a chunk of data and return the new digest variables.

    Arguments:
        wsplit_enc: the 80 encrypted schedule words of the chunk.
        h0split_enc..h4split_enc: encrypted digest words before this chunk.

    Returns:
        The 5 encrypted digest words after this chunk, as a tuple.
    """
    # Initialize hash value for this chunk
    asplit_enc = h0split_enc
    bsplit_enc = h1split_enc
    csplit_enc = h2split_enc
    dsplit_enc = h3split_enc
    esplit_enc = h4split_enc

    for i in range(80):
        if i < 20:
            # Do f = d ^ (b & (c ^ d))
            fsplit_enc = my_module.iftern.run(bsplit_enc, csplit_enc, dsplit_enc)
            ksplit = split(0x5A827999)
        elif i < 40:
            # Do f = b ^ c ^ d
            fsplit_enc = my_module.xor3.run(bsplit_enc, csplit_enc, dsplit_enc)
            ksplit = split(0x6ED9EBA1)
        elif i < 60:
            # Do f = (b & c) | (b & d) | (c & d)
            fsplit_enc = my_module.maj.run(bsplit_enc, csplit_enc, dsplit_enc)
            ksplit = split(0x8F1BBCDC)
        else:
            # Do f = b ^ c ^ d
            fsplit_enc = my_module.xor3.run(bsplit_enc, csplit_enc, dsplit_enc)
            ksplit = split(0xCA62C1D6)

        # Do arot5 = _left_rotate(a, 5)
        arot5split_enc = my_module.rotate5.run(asplit_enc)

        # Do arot5 + f + e + k + w[i]
        ssplit_enc = my_module.add5.run(
            arot5split_enc,
            fsplit_enc,
            esplit_enc,
            wsplit_enc[i],
            my_module.rotate5.encrypt(ksplit),  # BCM: later remove the encryption on k
        )

        # Final update of the a, b, c, d and e registers
        newasplit_enc = ssplit_enc
        esplit_enc = dsplit_enc
        dsplit_enc = csplit_enc
        # Do c = _left_rotate(b, 30)
        csplit_enc = my_module.rotate30.run(bsplit_enc)
        bsplit_enc = asplit_enc
        asplit_enc = newasplit_enc

    # Add this chunk's hash to result so far
    h0split_enc = my_module.add2.run(h0split_enc, asplit_enc)
    h1split_enc = my_module.add2.run(h1split_enc, bsplit_enc)
    h2split_enc = my_module.add2.run(h2split_enc, csplit_enc)
    h3split_enc = my_module.add2.run(h3split_enc, dsplit_enc)
    h4split_enc = my_module.add2.run(h4split_enc, esplit_enc)

    return h0split_enc, h1split_enc, h2split_enc, h3split_enc, h4split_enc


class Sha1Hash:
    """A class that mimics the hashlib API and implements the SHA-1 algorithm.

    The message schedule is computed and encrypted on the client side, while
    the compression function runs on encrypted bits through `my_module`.
    """

    name = "python-sha1"
    digest_size = 20
    block_size = 64

    def __init__(self):
        # Initial digest variables h0, h1, h2, h3, h4, split into bits and encrypted
        initial_digest = (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0)
        self._hsplit_enc = tuple(my_module.rotate5.encrypt(split(h)) for h in initial_digest)

        # bytes object with 0 <= len < 64 used to store the end of the message
        # if the message length is not congruent to 64
        self._unprocessed = b""
        # Length in bytes of all data that has been processed so far
        self._message_byte_length = 0

    def update(self, arg):
        """Update the current digest.

        This may be called repeatedly, even after calling digest or hexdigest.

        Arguments:
            arg: bytes, bytearray, or a binary file-like object (e.g. BytesIO)
                to read from.

        Returns:
            self, so calls can be chained.
        """
        if isinstance(arg, (bytes, bytearray)):
            arg = io.BytesIO(arg)

        # Try to build a chunk out of the unprocessed data, if any
        chunk = self._unprocessed + arg.read(64 - len(self._unprocessed))

        # Read the rest of the data, 64 bytes at a time
        while len(chunk) == 64:
            wsplit_enc = message_schedule_and_split_and_encrypt(chunk)
            self._hsplit_enc = _process_encrypted_chunk_server_side(wsplit_enc, *self._hsplit_enc)
            self._message_byte_length += 64
            chunk = arg.read(64)

        self._unprocessed = chunk
        return self

    def digest(self):
        """Produce the final hash value (big-endian) as a bytes object"""
        return b"".join(struct.pack(b">I", h) for h in self._produce_digest())

    def hexdigest(self):
        """Produce the final hash value (big-endian) as a hex string"""
        return "".join(f"{h:08x}" for h in self._produce_digest())

    def _produce_digest(self):
        """Return finalized digest variables for the data processed so far.

        Pads the buffered tail (0x80 byte, zeros, then the 64-bit big-endian
        message bit length), processes the one or two resulting chunks in FHE,
        then decrypts and reassembles the five 32-bit words. The object's state
        is not modified, so more data may still be fed through `update`.
        """
        # Pre-processing:
        message = self._unprocessed
        message_byte_length = self._message_byte_length + len(message)

        # append the bit '1' to the message
        message += b"\x80"

        # append 0 <= k < 512 bits '0', so that the resulting message length (in bytes)
        # is congruent to 56 (mod 64)
        message += b"\x00" * ((56 - (message_byte_length + 1) % 64) % 64)

        # append length of message (before pre-processing), in bits, as 64-bit big-endian integer
        message_bit_length = message_byte_length * 8
        message += struct.pack(b">Q", message_bit_length)

        # Process the final chunk
        # At this point, the length of the message is either 64 or 128 bytes.
        wsplit_enc = message_schedule_and_split_and_encrypt(message[:64])
        hsplit_enc = _process_encrypted_chunk_server_side(wsplit_enc, *self._hsplit_enc)
        if len(message) != 64:
            wsplit_enc = message_schedule_and_split_and_encrypt(message[64:])
            hsplit_enc = _process_encrypted_chunk_server_side(wsplit_enc, *hsplit_enc)

        # Decrypt and unsplit
        return tuple(unsplit(my_module.rotate5.decrypt(h_enc)) for h_enc in hsplit_enc)


def sha1(local_data):
    """SHA-1 Hashing Function

    A SHA-1 hashing function whose compression step is computed homomorphically
    (on encrypted data) with Concrete FHE.

    Arguments:
        local_data: A bytes, bytearray, or binary file-like object containing
            the input message to hash.

    Returns:
        A hex SHA-1 digest of the input message.
    """
    return Sha1Hash().update(local_data).hexdigest()


def print_timed_sha1(local_data):
    """Hash `local_data` with `sha1`, print the digest and elapsed time, and return the digest."""
    time_begin = time.time()
    ans = sha1(local_data)
    print(f"sha1-digest: {ans}")
    print(f"computed in: {time.time() - time_begin:.2f} seconds")
    return ans


if __name__ == "__main__":
    # Imports required for command line parsing. No need for these elsewhere
    import argparse
    import os
    import sys

    # Parse the incoming arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("input", nargs="*", help="input file or message to hash")
    parser.add_argument("--autotest", action="store_true", help="autotest")
    args = parser.parse_args()

    if args.autotest:
        filename = "tmp_sha1_test_file.txt"

        # Checking random patterns
        for _ in range(20):
            string_length = np.random.randint(100)

            # Take a random string
            hash_input = get_random_string(string_length)
            print(f"Checking SHA1({hash_input}) for an input length {string_length}")

            # Hash it with hashlib_sha1
            # ruff: noqa:S324
            h = hashlib_sha1()
            h.update(bytes(hash_input, encoding="utf-8"))
            expected_ans = h.hexdigest()

            # Hash it in FHE
            with open(filename, "w", encoding="utf-8") as file:
                file.write(f"{hash_input}")

            with open(filename, "rb") as data:
                # Show the final digest
                actual_ans = print_timed_sha1(data)

            # And compare
            assert (
                actual_ans == expected_ans
            ), f"Wrong computation: {actual_ans} vs expected {expected_ans} for input {hash_input}"

        # Checking a few patterns
        for hash_input, expected_ans in [
            ("", "da39a3ee5e6b4b0d3255bfef95601890afd80709"),
            (
                "The quick brown fox jumps over the lazy dog",
                "2fd4e1c67a2d28fced849ee1bb76e7391b93eb12",
            ),
        ]:
            with open(filename, "w", encoding="utf-8") as file:
                file.write(f"{hash_input}")

            print(f"Checking SHA1({hash_input})")

            with open(filename, "rb") as data:
                # Show the final digest
                actual_ans = print_timed_sha1(data)

            assert (
                actual_ans == expected_ans
            ), f"Wrong computation: {actual_ans} vs expected {expected_ans}"

        sys.exit(0)

    if len(args.input) == 0:
        # No argument given, assume message comes from standard input
        try:
            # sys.stdin is opened in text mode, which can change line endings,
            # leading to incorrect results. Detach fixes this issue, but it's
            # new in Python 3.1
            data = sys.stdin.detach()  # type: ignore

        except AttributeError:
            # Linux and OSX both use \n line endings, so only windows is a
            # problem.
            if sys.platform == "win32":
                import msvcrt

                msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
            data = sys.stdin  # type: ignore

        # Output to console
        print_timed_sha1(data)

    else:
        # Loop through arguments list
        for argument in args.input:
            if os.path.isfile(argument):
                # An argument is given and it's a valid file. Read it.
                # (`filename` only exists in the --autotest branch, so the
                # file to open is the argument itself.)
                with open(argument, "rb") as data:
                    # Show the final digest
                    print_timed_sha1(data)
            else:
                print("Error, could not find " + argument + " file.")