mirror of
https://github.com/ethereum/consensus-specs.git
synced 2026-02-02 17:35:03 -05:00
update SSZ implementation
This commit is contained in:
@@ -1,5 +1,27 @@
|
||||
from eth2spec.utils.merkle_minimal import merkleize_chunks
|
||||
from .ssz_switch import *
|
||||
from .ssz_typing import *
|
||||
|
||||
|
||||
# SSZ Defaults
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def get_zero_value(typ):
|
||||
if is_uint(typ):
|
||||
return 0
|
||||
if issubclass(typ, bool):
|
||||
return False
|
||||
if issubclass(typ, list):
|
||||
return []
|
||||
if issubclass(typ, Vector):
|
||||
return typ()
|
||||
if issubclass(typ, BytesN):
|
||||
return typ()
|
||||
if issubclass(typ, bytes):
|
||||
return b''
|
||||
if issubclass(typ, SSZContainer):
|
||||
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields().items()}),
|
||||
|
||||
|
||||
# SSZ Helpers
|
||||
# -----------------------------
|
||||
@@ -14,23 +36,10 @@ def chunkify(byte_string):
|
||||
return [byte_string[i:i + 32] for i in range(0, len(byte_string), 32)]
|
||||
|
||||
|
||||
BYTES_PER_LENGTH_OFFSET = 4
|
||||
|
||||
|
||||
# SSZ Implementation
|
||||
# SSZ Serialization
|
||||
# -----------------------------
|
||||
|
||||
get_zero_value = ssz_type_switch({
|
||||
ssz_bool: lambda: False,
|
||||
ssz_uint: lambda: 0,
|
||||
ssz_list: lambda byte_form: b'' if byte_form else [],
|
||||
ssz_vector: lambda length, elem_typ, byte_form:
|
||||
(b'\x00' * length if length > 0 else b'') if byte_form else
|
||||
[get_zero_value(elem_typ) for _ in range(length)],
|
||||
ssz_container: lambda typ, field_names, field_types:
|
||||
typ(**{f_name: get_zero_value(f_typ) for f_name, f_typ in zip(field_names, field_types)}),
|
||||
})
|
||||
|
||||
BYTES_PER_LENGTH_OFFSET = 4
|
||||
|
||||
serialize = ssz_switch({
|
||||
ssz_bool: lambda value: b'\x01' if value else b'\x00',
|
||||
@@ -40,13 +49,6 @@ serialize = ssz_switch({
|
||||
ssz_container: lambda value, get_field_values, field_types: encode_series(get_field_values(value), field_types),
|
||||
})
|
||||
|
||||
ssz_basic_type = (ssz_bool, ssz_uint)
|
||||
|
||||
is_basic_type = ssz_type_switch({
|
||||
ssz_basic_type: lambda: True,
|
||||
ssz_default: lambda: False,
|
||||
})
|
||||
|
||||
is_fixed_size = ssz_type_switch({
|
||||
ssz_basic_type: lambda: True,
|
||||
ssz_vector: lambda elem_typ: is_fixed_size(elem_typ),
|
||||
@@ -55,6 +57,27 @@ is_fixed_size = ssz_type_switch({
|
||||
})
|
||||
|
||||
|
||||
# SSZ Hash-tree-root
|
||||
# -----------------------------
|
||||
|
||||
def serialize_basic(value, typ):
|
||||
if is_uint(typ):
|
||||
return value.to_bytes(typ.byte_len, 'little')
|
||||
if issubclass(typ, bool):
|
||||
if value:
|
||||
return b'\x01'
|
||||
else:
|
||||
return b'\x00'
|
||||
|
||||
|
||||
def pack(values, subtype):
|
||||
return b''.join([serialize_basic(value, subtype) for value in values])
|
||||
|
||||
|
||||
def is_basic_type(typ):
|
||||
return is_uint(typ) or issubclass(typ, bool)
|
||||
|
||||
|
||||
def hash_tree_root_list(value, elem_typ):
|
||||
if is_basic_type(elem_typ):
|
||||
return merkleize_chunks(chunkify(pack(value, elem_typ)))
|
||||
@@ -77,10 +100,9 @@ hash_tree_root = ssz_switch({
|
||||
ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)),
|
||||
})
|
||||
|
||||
signing_root = ssz_switch({
|
||||
ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)[:-1]),
|
||||
ssz_default: lambda value, typ: hash_tree_root(value, typ),
|
||||
})
|
||||
# todo: signing root
|
||||
def signing_root(value, typ):
|
||||
pass
|
||||
|
||||
|
||||
def encode_series(values, types):
|
||||
@@ -115,3 +137,21 @@ def encode_series(values, types):
|
||||
|
||||
# Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts
|
||||
return b''.join(fixed_parts + variable_parts)
|
||||
|
||||
|
||||
# Implementation notes:
|
||||
# - SSZContainer,Vector/BytesN.hash_tree_root/serialize functions are for ease, implementation here
|
||||
# - uint types have a 'byte_len' attribute
|
||||
# - uint types are not classes. They use NewType(), for performance.
|
||||
# This forces us to check type equivalence by exact reference.
|
||||
# There's no class. The type data comes from an annotation/argument from the context of the value.
|
||||
# - Vector is not valid to create instances with. Give it a elem-type and length: Vector[FooBar, 123]
|
||||
# - *The class of* a Vector instance has a `elem_type` (type, may not be a class, see uint) and `length` (int)
|
||||
# - BytesN is not valid to create instances with. Give it a length: BytesN[123]
|
||||
# - *The class of* a BytesN instance has a `length` (int)
|
||||
# Where possible, it is preferable to create helpers that just act on the type, and don't unnecessarily use a value
|
||||
# E.g. is_basic_type(). This way, we can use them in type-only contexts and have no duplicate logic.
|
||||
# For every class-instance, you can get the type with my_object.__class__
|
||||
# For uints, and other NewType related, you have to rely on type information. It cannot be retrieved from the value.
|
||||
# Note: we may just want to box integers instead. And then we can do bounds checking too. But it is SLOW and MEMORY INTENSIVE.
|
||||
#
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
from .ssz_typing import *
|
||||
|
||||
# SSZ Switch statement runner factory
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def ssz_switch(sw: Dict[Any, Any], arg_names=None):
|
||||
"""
|
||||
Creates an SSZ switch statement: a function, that when executed, checks every switch-statement
|
||||
"""
|
||||
if arg_names is None:
|
||||
arg_names = ["value", "typ"]
|
||||
|
||||
# Runner, the function that executes the switch when called.
|
||||
# Accepts a arguments based on the arg_names declared in the ssz_switch.
|
||||
def run_switch(*args):
|
||||
# value may be None
|
||||
value = None
|
||||
try:
|
||||
value = args[arg_names.index("value")]
|
||||
except ValueError:
|
||||
pass # no value argument
|
||||
|
||||
# typ may be None when value is not None
|
||||
typ = None
|
||||
try:
|
||||
typ = args[arg_names.index("typ")]
|
||||
except ValueError:
|
||||
# no typ argument expected
|
||||
pass
|
||||
except IndexError:
|
||||
# typ argument expected, but not passed. Try to get it from the class info
|
||||
typ = value.__class__
|
||||
if hasattr(typ, '__forward_arg__'):
|
||||
typ = typ.__forward_arg__
|
||||
|
||||
# Now, go over all switch cases
|
||||
for matchers, worker in sw.items():
|
||||
if not isinstance(matchers, tuple):
|
||||
matchers = (matchers,)
|
||||
# for each matcher of the case key
|
||||
for m in matchers:
|
||||
data = m(typ)
|
||||
# if we have data, the matcher matched, and we can return the result
|
||||
if data is not None:
|
||||
# Supply value and type by default, and any data presented by the matcher.
|
||||
kwargs = {"value": value, "typ": typ, **data}
|
||||
# Filter out unwanted arguments
|
||||
filtered_kwargs = {k: kwargs[k] for k in worker.__code__.co_varnames}
|
||||
# run the switch case and return result
|
||||
return worker(**filtered_kwargs)
|
||||
raise Exception("cannot find matcher for type: %s (value: %s)" % (typ, value))
|
||||
return run_switch
|
||||
|
||||
|
||||
def ssz_type_switch(sw: Dict[Any, Any]):
|
||||
return ssz_switch(sw, ["typ"])
|
||||
|
||||
|
||||
# SSZ Switch matchers
|
||||
# -----------------------------
|
||||
|
||||
def ssz_bool(typ):
|
||||
if typ == bool:
|
||||
return {}
|
||||
|
||||
|
||||
def ssz_uint(typ):
|
||||
# Note: only the type reference exists,
|
||||
# but it really resolves to 'int' during run-time for zero computational/memory overhead.
|
||||
# Hence, we check equality to the type references (which are really just 'NewType' instances),
|
||||
# and don't use any sub-classing like we normally would.
|
||||
if typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64\
|
||||
or typ == uint128 or typ == uint256 or typ == byte:
|
||||
return {"byte_len": typ.byte_len}
|
||||
|
||||
|
||||
def ssz_list(typ):
|
||||
if hasattr(typ, '__bases__') and List in typ.__bases__:
|
||||
return {"elem_typ": read_list_elem_typ(typ), "byte_form": False}
|
||||
if typ == bytes:
|
||||
return {"elem_typ": uint8, "byte_form": True}
|
||||
|
||||
|
||||
def ssz_vector(typ):
|
||||
if hasattr(typ, '__bases__'):
|
||||
if Vector in typ.__bases__:
|
||||
return {"elem_typ": read_vec_elem_typ(typ), "length": read_vec_len(typ), "byte_form": False}
|
||||
if BytesN in typ.__bases__:
|
||||
return {"elem_typ": uint8, "length": read_bytesN_len(typ), "byte_form": True}
|
||||
|
||||
|
||||
def ssz_container(typ):
|
||||
if hasattr(typ, '__bases__') and SSZContainer in typ.__bases__:
|
||||
def get_field_values(value):
|
||||
return [getattr(value, field) for field in typ.__annotations__.keys()]
|
||||
field_names = list(typ.__annotations__.keys())
|
||||
field_types = list(typ.__annotations__.values())
|
||||
return {"get_field_values": get_field_values, "field_names": field_names, "field_types": field_types}
|
||||
|
||||
|
||||
def ssz_default(typ):
|
||||
return {}
|
||||
@@ -1,72 +1,18 @@
|
||||
from typing import Generic, List, TypeVar, Type, Iterable, NewType
|
||||
from typing import List, Iterable, TypeVar, Type, NewType
|
||||
from typing import Union
|
||||
from inspect import isclass
|
||||
|
||||
# SSZ base length, to limit length generic type param of vector/bytesN
|
||||
SSZLenAny = type('SSZLenAny', (), {})
|
||||
|
||||
|
||||
def SSZLen(length: int):
|
||||
"""
|
||||
SSZ length factory. Creates a type corresponding to a given length. To be used as parameter in type generics.
|
||||
"""
|
||||
assert length >= 0
|
||||
typ = type('SSZLen_%d' % length, (SSZLenAny,), {})
|
||||
typ.length = length
|
||||
return typ
|
||||
|
||||
|
||||
# SSZ element type
|
||||
T = TypeVar('T')
|
||||
# SSZ vector/bytesN length
|
||||
L = TypeVar('L', bound=SSZLenAny)
|
||||
|
||||
|
||||
# SSZ vector
|
||||
# -----------------------------
|
||||
|
||||
class Vector(Generic[T, L]):
|
||||
def __init__(self, *args: Iterable[T]):
|
||||
self.items = list(args)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.items[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.items[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
|
||||
def read_vec_elem_typ(vec_typ: Type[Vector[T,L]]) -> T:
|
||||
assert vec_typ.__args__ is not None
|
||||
return vec_typ.__args__[0]
|
||||
|
||||
|
||||
def read_vec_len(vec_typ: Type[Vector[T,L]]) -> int:
|
||||
assert vec_typ.__args__ is not None
|
||||
return vec_typ.__args__[1].length
|
||||
|
||||
|
||||
# SSZ list
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def read_list_elem_typ(list_typ: Type[List[T]]) -> T:
|
||||
assert list_typ.__args__ is not None
|
||||
return list_typ.__args__[0]
|
||||
|
||||
|
||||
# SSZ bytesN
|
||||
# -----------------------------
|
||||
class BytesN(Generic[L]):
|
||||
pass
|
||||
|
||||
|
||||
def read_bytesN_len(bytesN_typ: Type[BytesN[L]]) -> int:
|
||||
assert bytesN_typ.__args__ is not None
|
||||
return bytesN_typ.__args__[0].length
|
||||
|
||||
|
||||
# SSZ integer types, with 0 computational overhead (NewType)
|
||||
# -----------------------------
|
||||
@@ -94,8 +40,9 @@ byte = NewType('byte', uint8)
|
||||
class SSZContainer(object):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
cls = self.__class__
|
||||
from .ssz_impl import get_zero_value
|
||||
for f, t in self.__annotations__.items():
|
||||
for f, t in cls.get_fields().items():
|
||||
if f not in kwargs:
|
||||
setattr(self, f, get_zero_value(t))
|
||||
else:
|
||||
@@ -113,3 +60,220 @@ class SSZContainer(object):
|
||||
from .ssz_impl import signing_root
|
||||
return signing_root(self, self.__class__)
|
||||
|
||||
def get_field_values(self):
|
||||
cls = self.__class__
|
||||
return [getattr(self, field) for field in cls.get_field_names()]
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return dict(cls.__annotations__)
|
||||
|
||||
@classmethod
|
||||
def get_field_names(cls):
|
||||
return list(cls.__annotations__.keys())
|
||||
|
||||
@classmethod
|
||||
def get_field_types(cls):
|
||||
# values of annotations are the types corresponding to the fields, not instance values.
|
||||
return list(cls.__annotations__.values())
|
||||
|
||||
|
||||
def is_uint(typ):
|
||||
# Note: only the type reference exists,
|
||||
# but it really resolves to 'int' during run-time for zero computational/memory overhead.
|
||||
# Hence, we check equality to the type references (which are really just 'NewType' instances),
|
||||
# and don't use any sub-classing like we normally would.
|
||||
return typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64 \
|
||||
or typ == uint128 or typ == uint256 or typ == byte
|
||||
|
||||
# SSZ vector
|
||||
# -----------------------------
|
||||
|
||||
|
||||
def _is_vector_instance_of(a, b):
|
||||
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector (b) is not an instance of Vector[X, Y] (a)
|
||||
return False
|
||||
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
|
||||
# Vector[X, Y] (b) is an instance of Vector (a)
|
||||
return True
|
||||
|
||||
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
|
||||
return a.elem_type == b.elem_type and a.length == b.length
|
||||
|
||||
|
||||
def _is_equal_vector_type(a, b):
|
||||
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
|
||||
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector == Vector
|
||||
return True
|
||||
# Vector != Vector[X, Y]
|
||||
return False
|
||||
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
|
||||
# Vector[X, Y] != Vector
|
||||
return False
|
||||
# Vector[X, Y] == Vector[X, Y]
|
||||
return a.elem_type == b.elem_type and a.length == b.length
|
||||
|
||||
|
||||
class VectorMeta(type):
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
if 'elem_type' in attrs and 'length' in attrs:
|
||||
setattr(out, 'elem_type', attrs['elem_type'])
|
||||
setattr(out, 'length', attrs['length'])
|
||||
return out
|
||||
|
||||
def __getitem__(self, params):
|
||||
return self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
|
||||
|
||||
def __subclasscheck__(self, sub):
|
||||
return _is_vector_instance_of(self, sub)
|
||||
|
||||
def __instancecheck__(self, other):
|
||||
return _is_vector_instance_of(self, other.__class__)
|
||||
|
||||
def __eq__(self, other):
|
||||
return _is_equal_vector_type(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not _is_equal_vector_type(self, other)
|
||||
|
||||
|
||||
class Vector(metaclass=VectorMeta):
|
||||
|
||||
def __init__(self, *args: Iterable[T]):
|
||||
|
||||
cls = self.__class__
|
||||
if not hasattr(cls, 'elem_type'):
|
||||
raise TypeError("Type Vector without elem_type data cannot be instantiated")
|
||||
if not hasattr(cls, 'length'):
|
||||
raise TypeError("Type Vector without length data cannot be instantiated")
|
||||
|
||||
if len(args) != cls.length:
|
||||
if len(args) == 0:
|
||||
from .ssz_impl import get_zero_value
|
||||
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
|
||||
else:
|
||||
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
|
||||
|
||||
self.items = list(args)
|
||||
|
||||
# cannot check non-class objects
|
||||
if isclass(cls.elem_type):
|
||||
for i, item in enumerate(self.items):
|
||||
if not isinstance(item, cls.elem_type):
|
||||
raise TypeError("Typed vector cannot hold differently typed value"
|
||||
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
return serialize(self, self.__class__)
|
||||
|
||||
def hash_tree_root(self):
|
||||
from .ssz_impl import hash_tree_root
|
||||
return hash_tree_root(self, self.__class__)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.items[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.items[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.items)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
|
||||
|
||||
def _is_bytes_n_instance_of(a, b):
|
||||
if not hasattr(b, 'length'):
|
||||
# BytesN (b) is not an instance of BytesN[X] (a)
|
||||
return False
|
||||
if not hasattr(a, 'length'):
|
||||
# BytesN[X] (b) is an instance of BytesN (a)
|
||||
return True
|
||||
|
||||
# BytesN[X] (a) is an instance of BytesN[X] (b)
|
||||
return a.length == b.length
|
||||
|
||||
|
||||
def _is_equal_bytes_n_type(a, b):
|
||||
if not hasattr(a, 'length'):
|
||||
if not hasattr(b, 'length'):
|
||||
# BytesN == BytesN
|
||||
return True
|
||||
# BytesN != BytesN[X]
|
||||
return False
|
||||
if not hasattr(b, 'length'):
|
||||
# BytesN[X] != BytesN
|
||||
return False
|
||||
# BytesN[X] == BytesN[X]
|
||||
return a.length == b.length
|
||||
|
||||
|
||||
class BytesNMeta(type):
|
||||
def __new__(cls, class_name, parents, attrs):
|
||||
out = type.__new__(cls, class_name, parents, attrs)
|
||||
if 'length' in attrs:
|
||||
setattr(out, 'length', attrs['length'])
|
||||
return out
|
||||
|
||||
def __getitem__(self, n):
|
||||
return self.__class__(self.__name__, (BytesN,), {'length': n})
|
||||
|
||||
def __subclasscheck__(self, sub):
|
||||
return _is_bytes_n_instance_of(self, sub)
|
||||
|
||||
def __instancecheck__(self, other):
|
||||
return _is_bytes_n_instance_of(self, other.__class__)
|
||||
|
||||
def __eq__(self, other):
|
||||
return _is_equal_bytes_n_type(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not _is_equal_bytes_n_type(self, other)
|
||||
|
||||
|
||||
def parse_bytes(val):
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, str):
|
||||
# TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
if isinstance(val, int):
|
||||
return bytes([val])
|
||||
return None
|
||||
|
||||
|
||||
class BytesN(bytes, metaclass=BytesNMeta):
|
||||
def __new__(cls, *args):
|
||||
if not hasattr(cls, 'length'):
|
||||
return
|
||||
bytesval = None
|
||||
if len(args) == 1:
|
||||
val: Union[bytes, int, str] = args[0]
|
||||
bytesval = parse_bytes(val)
|
||||
elif len(args) > 1:
|
||||
# TODO: each int is 1 byte, check size, create bytesval
|
||||
bytesval = bytes(args)
|
||||
|
||||
if bytesval is None:
|
||||
if cls.length == 0:
|
||||
bytesval = b''
|
||||
else:
|
||||
bytesval = b'\x00' * cls.length
|
||||
if len(bytesval) != cls.length:
|
||||
raise TypeError("bytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
|
||||
return super().__new__(cls, bytesval)
|
||||
|
||||
def serialize(self):
|
||||
from .ssz_impl import serialize
|
||||
return serialize(self, self.__class__)
|
||||
|
||||
def hash_tree_root(self):
|
||||
from .ssz_impl import hash_tree_root
|
||||
return hash_tree_root(self, self.__class__)
|
||||
|
||||
Reference in New Issue
Block a user