From ed4416ba348cae92b1326c357b1c6e6f946e095e Mon Sep 17 00:00:00 2001 From: protolambda Date: Sat, 25 May 2019 00:05:03 +0200 Subject: [PATCH] update SSZ implementation --- .../pyspec/eth2spec/utils/ssz/ssz_impl.py | 94 ++++-- .../pyspec/eth2spec/utils/ssz/ssz_switch.py | 105 ------- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 284 ++++++++++++++---- 3 files changed, 291 insertions(+), 192 deletions(-) delete mode 100644 test_libs/pyspec/eth2spec/utils/ssz/ssz_switch.py diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index 74bc1bf99..2b0328bfa 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -1,5 +1,27 @@ from eth2spec.utils.merkle_minimal import merkleize_chunks -from .ssz_switch import * +from .ssz_typing import * + + +# SSZ Defaults +# ----------------------------- + + +def get_zero_value(typ): + if is_uint(typ): + return 0 + if issubclass(typ, bool): + return False + if issubclass(typ, list): + return [] + if issubclass(typ, Vector): + return typ() + if issubclass(typ, BytesN): + return typ() + if issubclass(typ, bytes): + return b'' + if issubclass(typ, SSZContainer): + return typ(**{f: get_zero_value(t) for f, t in typ.get_fields().items()}), + # SSZ Helpers # ----------------------------- @@ -14,23 +36,10 @@ def chunkify(byte_string): return [byte_string[i:i + 32] for i in range(0, len(byte_string), 32)] -BYTES_PER_LENGTH_OFFSET = 4 - - -# SSZ Implementation +# SSZ Serialization # ----------------------------- -get_zero_value = ssz_type_switch({ - ssz_bool: lambda: False, - ssz_uint: lambda: 0, - ssz_list: lambda byte_form: b'' if byte_form else [], - ssz_vector: lambda length, elem_typ, byte_form: - (b'\x00' * length if length > 0 else b'') if byte_form else - [get_zero_value(elem_typ) for _ in range(length)], - ssz_container: lambda typ, field_names, field_types: - typ(**{f_name: get_zero_value(f_typ) for f_name, f_typ in zip(field_names, field_types)}), -}) - +BYTES_PER_LENGTH_OFFSET = 4 serialize = ssz_switch({ ssz_bool: lambda value: b'\x01' if value else b'\x00', @@ -40,13 +49,6 @@ serialize = ssz_switch({ ssz_container: lambda value, get_field_values, field_types: encode_series(get_field_values(value), field_types), }) -ssz_basic_type = (ssz_bool, ssz_uint) - -is_basic_type = ssz_type_switch({ - ssz_basic_type: lambda: True, - ssz_default: lambda: False, -}) - is_fixed_size = ssz_type_switch({ ssz_basic_type: lambda: True, ssz_vector: lambda elem_typ: is_fixed_size(elem_typ), @@ -55,6 +57,27 @@ is_fixed_size = ssz_type_switch({ }) +# SSZ Hash-tree-root +# ----------------------------- + +def serialize_basic(value, typ): + if is_uint(typ): + return value.to_bytes(typ.byte_len, 'little') + if issubclass(typ, bool): + if value: + return b'\x01' + else: + return b'\x00' + + +def pack(values, subtype): + return b''.join([serialize_basic(value, subtype) for value in values]) + + +def is_basic_type(typ): + return is_uint(typ) or issubclass(typ, bool) + + def hash_tree_root_list(value, elem_typ): if is_basic_type(elem_typ): return merkleize_chunks(chunkify(pack(value, elem_typ))) @@ -77,10 +100,9 @@ hash_tree_root = ssz_switch({ ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)), }) -signing_root = ssz_switch({ - ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)[:-1]), - ssz_default: lambda value, typ: hash_tree_root(value, typ), -}) +# todo: signing root +def signing_root(value, typ): + pass def encode_series(values, types): @@ -115,3 +137,21 @@ def encode_series(values, types): # Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts return b''.join(fixed_parts + variable_parts) + + +# Implementation notes: +# - SSZContainer,Vector/BytesN.hash_tree_root/serialize functions are for ease, implementation here +# - uint types have a 'byte_len' attribute +# - uint types are not classes. They use NewType(), for performance. +# This forces us to check type equivalence by exact reference. +# There's no class. The type data comes from an annotation/argument from the context of the value. +# - Vector is not valid to create instances with. Give it a elem-type and length: Vector[FooBar, 123] +# - *The class of* a Vector instance has a `elem_type` (type, may not be a class, see uint) and `length` (int) +# - BytesN is not valid to create instances with. Give it a length: BytesN[123] +# - *The class of* a BytesN instance has a `length` (int) +# Where possible, it is preferable to create helpers that just act on the type, and don't unnecessarily use a value +# E.g. is_basic_type(). This way, we can use them in type-only contexts and have no duplicate logic. +# For every class-instance, you can get the type with my_object.__class__ +# For uints, and other NewType related, you have to rely on type information. It cannot be retrieved from the value. +# Note: we may just want to box integers instead. And then we can do bounds checking too. But it is SLOW and MEMORY INTENSIVE. +# diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_switch.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_switch.py deleted file mode 100644 index 3da1e7cb1..000000000 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_switch.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Dict, Any - -from .ssz_typing import * - -# SSZ Switch statement runner factory -# ----------------------------- - - -def ssz_switch(sw: Dict[Any, Any], arg_names=None): - """ - Creates an SSZ switch statement: a function, that when executed, checks every switch-statement - """ - if arg_names is None: - arg_names = ["value", "typ"] - - # Runner, the function that executes the switch when called. - # Accepts a arguments based on the arg_names declared in the ssz_switch. - def run_switch(*args): - # value may be None - value = None - try: - value = args[arg_names.index("value")] - except ValueError: - pass # no value argument - - # typ may be None when value is not None - typ = None - try: - typ = args[arg_names.index("typ")] - except ValueError: - # no typ argument expected - pass - except IndexError: - # typ argument expected, but not passed. Try to get it from the class info - typ = value.__class__ - if hasattr(typ, '__forward_arg__'): - typ = typ.__forward_arg__ - - # Now, go over all switch cases - for matchers, worker in sw.items(): - if not isinstance(matchers, tuple): - matchers = (matchers,) - # for each matcher of the case key - for m in matchers: - data = m(typ) - # if we have data, the matcher matched, and we can return the result - if data is not None: - # Supply value and type by default, and any data presented by the matcher. - kwargs = {"value": value, "typ": typ, **data} - # Filter out unwanted arguments - filtered_kwargs = {k: kwargs[k] for k in worker.__code__.co_varnames} - # run the switch case and return result - return worker(**filtered_kwargs) - raise Exception("cannot find matcher for type: %s (value: %s)" % (typ, value)) - return run_switch - - -def ssz_type_switch(sw: Dict[Any, Any]): - return ssz_switch(sw, ["typ"]) - - -# SSZ Switch matchers -# ----------------------------- - -def ssz_bool(typ): - if typ == bool: - return {} - - -def ssz_uint(typ): - # Note: only the type reference exists, - # but it really resolves to 'int' during run-time for zero computational/memory overhead. - # Hence, we check equality to the type references (which are really just 'NewType' instances), - # and don't use any sub-classing like we normally would. - if typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64\ - or typ == uint128 or typ == uint256 or typ == byte: - return {"byte_len": typ.byte_len} - - -def ssz_list(typ): - if hasattr(typ, '__bases__') and List in typ.__bases__: - return {"elem_typ": read_list_elem_typ(typ), "byte_form": False} - if typ == bytes: - return {"elem_typ": uint8, "byte_form": True} - - -def ssz_vector(typ): - if hasattr(typ, '__bases__'): - if Vector in typ.__bases__: - return {"elem_typ": read_vec_elem_typ(typ), "length": read_vec_len(typ), "byte_form": False} - if BytesN in typ.__bases__: - return {"elem_typ": uint8, "length": read_bytesN_len(typ), "byte_form": True} - - -def ssz_container(typ): - if hasattr(typ, '__bases__') and SSZContainer in typ.__bases__: - def get_field_values(value): - return [getattr(value, field) for field in typ.__annotations__.keys()] - field_names = list(typ.__annotations__.keys()) - field_types = list(typ.__annotations__.values()) - return {"get_field_values": get_field_values, "field_names": field_names, "field_types": field_types} - - -def ssz_default(typ): - return {} diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 6a8a22586..dc23427b3 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -1,72 +1,18 @@ -from typing import Generic, List, TypeVar, Type, Iterable, NewType +from typing import List, Iterable, TypeVar, Type, NewType +from typing import Union +from inspect import isclass -# SSZ base length, to limit length generic type param of vector/bytesN -SSZLenAny = type('SSZLenAny', (), {}) - - -def SSZLen(length: int): - """ - SSZ length factory. Creates a type corresponding to a given length. To be used as parameter in type generics. - """ - assert length >= 0 - typ = type('SSZLen_%d' % length, (SSZLenAny,), {}) - typ.length = length - return typ - - -# SSZ element type T = TypeVar('T') -# SSZ vector/bytesN length -L = TypeVar('L', bound=SSZLenAny) - - -# SSZ vector -# ----------------------------- - -class Vector(Generic[T, L]): - def __init__(self, *args: Iterable[T]): - self.items = list(args) - - def __getitem__(self, key): - return self.items[key] - - def __setitem__(self, key, value): - self.items[key] = value - - def __iter__(self): - return iter(self.items) - - def __len__(self): - return len(self.items) - - -def read_vec_elem_typ(vec_typ: Type[Vector[T,L]]) -> T: - assert vec_typ.__args__ is not None - return vec_typ.__args__[0] - - -def read_vec_len(vec_typ: Type[Vector[T,L]]) -> int: - assert vec_typ.__args__ is not None - return vec_typ.__args__[1].length - # SSZ list # ----------------------------- + + def read_list_elem_typ(list_typ: Type[List[T]]) -> T: assert list_typ.__args__ is not None return list_typ.__args__[0] -# SSZ bytesN -# ----------------------------- -class BytesN(Generic[L]): - pass - - -def read_bytesN_len(bytesN_typ: Type[BytesN[L]]) -> int: - assert bytesN_typ.__args__ is not None - return bytesN_typ.__args__[0].length - # SSZ integer types, with 0 computational overhead (NewType) # ----------------------------- @@ -94,8 +40,9 @@ byte = NewType('byte', uint8) class SSZContainer(object): def __init__(self, **kwargs): + cls = self.__class__ from .ssz_impl import get_zero_value - for f, t in self.__annotations__.items(): + for f, t in cls.get_fields().items(): if f not in kwargs: setattr(self, f, get_zero_value(t)) else: @@ -113,3 +60,220 @@ class SSZContainer(object): from .ssz_impl import signing_root return signing_root(self, self.__class__) + def get_field_values(self): + cls = self.__class__ + return [getattr(self, field) for field in cls.get_field_names()] + + @classmethod + def get_fields(cls): + return dict(cls.__annotations__) + + @classmethod + def get_field_names(cls): + return list(cls.__annotations__.keys()) + + @classmethod + def get_field_types(cls): + # values of annotations are the types corresponding to the fields, not instance values. + return list(cls.__annotations__.values()) + + +def is_uint(typ): + # Note: only the type reference exists, + # but it really resolves to 'int' during run-time for zero computational/memory overhead. + # Hence, we check equality to the type references (which are really just 'NewType' instances), + # and don't use any sub-classing like we normally would. + return typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64 \ + or typ == uint128 or typ == uint256 or typ == byte + +# SSZ vector +# ----------------------------- + + +def _is_vector_instance_of(a, b): + if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): + # Vector (b) is not an instance of Vector[X, Y] (a) + return False + if not hasattr(a, 'elem_type') or not hasattr(a, 'length'): + # Vector[X, Y] (b) is an instance of Vector (a) + return True + + # Vector[X, Y] (a) is an instance of Vector[X, Y] (b) + return a.elem_type == b.elem_type and a.length == b.length + + +def _is_equal_vector_type(a, b): + if not hasattr(a, 'elem_type') or not hasattr(a, 'length'): + if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): + # Vector == Vector + return True + # Vector != Vector[X, Y] + return False + if not hasattr(b, 'elem_type') or not hasattr(b, 'length'): + # Vector[X, Y] != Vector + return False + # Vector[X, Y] == Vector[X, Y] + return a.elem_type == b.elem_type and a.length == b.length + + +class VectorMeta(type): + def __new__(cls, class_name, parents, attrs): + out = type.__new__(cls, class_name, parents, attrs) + if 'elem_type' in attrs and 'length' in attrs: + setattr(out, 'elem_type', attrs['elem_type']) + setattr(out, 'length', attrs['length']) + return out + + def __getitem__(self, params): + return self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]}) + + def __subclasscheck__(self, sub): + return _is_vector_instance_of(self, sub) + + def __instancecheck__(self, other): + return _is_vector_instance_of(self, other.__class__) + + def __eq__(self, other): + return _is_equal_vector_type(self, other) + + def __ne__(self, other): + return not _is_equal_vector_type(self, other) + + +class Vector(metaclass=VectorMeta): + + def __init__(self, *args: Iterable[T]): + + cls = self.__class__ + if not hasattr(cls, 'elem_type'): + raise TypeError("Type Vector without elem_type data cannot be instantiated") + if not hasattr(cls, 'length'): + raise TypeError("Type Vector without length data cannot be instantiated") + + if len(args) != cls.length: + if len(args) == 0: + from .ssz_impl import get_zero_value + args = [get_zero_value(cls.elem_type) for _ in range(cls.length)] + else: + raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args))) + + self.items = list(args) + + # cannot check non-class objects + if isclass(cls.elem_type): + for i, item in enumerate(self.items): + if not isinstance(item, cls.elem_type): + raise TypeError("Typed vector cannot hold differently typed value" + " at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type)) + + def serialize(self): + from .ssz_impl import serialize + return serialize(self, self.__class__) + + def hash_tree_root(self): + from .ssz_impl import hash_tree_root + return hash_tree_root(self, self.__class__) + + def __getitem__(self, key): + return self.items[key] + + def __setitem__(self, key, value): + self.items[key] = value + + def __iter__(self): + return iter(self.items) + + def __len__(self): + return len(self.items) + + +def _is_bytes_n_instance_of(a, b): + if not hasattr(b, 'length'): + # BytesN (b) is not an instance of BytesN[X] (a) + return False + if not hasattr(a, 'length'): + # BytesN[X] (b) is an instance of BytesN (a) + return True + + # BytesN[X] (a) is an instance of BytesN[X] (b) + return a.length == b.length + + +def _is_equal_bytes_n_type(a, b): + if not hasattr(a, 'length'): + if not hasattr(b, 'length'): + # BytesN == BytesN + return True + # BytesN != BytesN[X] + return False + if not hasattr(b, 'length'): + # BytesN[X] != BytesN + return False + # BytesN[X] == BytesN[X] + return a.length == b.length + + +class BytesNMeta(type): + def __new__(cls, class_name, parents, attrs): + out = type.__new__(cls, class_name, parents, attrs) + if 'length' in attrs: + setattr(out, 'length', attrs['length']) + return out + + def __getitem__(self, n): + return self.__class__(self.__name__, (BytesN,), {'length': n}) + + def __subclasscheck__(self, sub): + return _is_bytes_n_instance_of(self, sub) + + def __instancecheck__(self, other): + return _is_bytes_n_instance_of(self, other.__class__) + + def __eq__(self, other): + return _is_equal_bytes_n_type(self, other) + + def __ne__(self, other): + return not _is_equal_bytes_n_type(self, other) + + +def parse_bytes(val): + if val is None: + return None + if isinstance(val, str): + # TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val) + return None + if isinstance(val, bytes): + return val + if isinstance(val, int): + return bytes([val]) + return None + + +class BytesN(bytes, metaclass=BytesNMeta): + def __new__(cls, *args): + if not hasattr(cls, 'length'): + return + bytesval = None + if len(args) == 1: + val: Union[bytes, int, str] = args[0] + bytesval = parse_bytes(val) + elif len(args) > 1: + # TODO: each int is 1 byte, check size, create bytesval + bytesval = bytes(args) + + if bytesval is None: + if cls.length == 0: + bytesval = b'' + else: + bytesval = b'\x00' * cls.length + if len(bytesval) != cls.length: + raise TypeError("bytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval))) + return super().__new__(cls, bytesval) + + def serialize(self): + from .ssz_impl import serialize + return serialize(self, self.__class__) + + def hash_tree_root(self): + from .ssz_impl import hash_tree_root + return hash_tree_root(self, self.__class__)