update SSZ implementation

This commit is contained in:
protolambda
2019-05-25 00:05:03 +02:00
parent 08faa86d70
commit ed4416ba34
3 changed files with 291 additions and 192 deletions

View File

@@ -1,5 +1,27 @@
from eth2spec.utils.merkle_minimal import merkleize_chunks
from .ssz_switch import *
from .ssz_typing import *
# SSZ Defaults
# -----------------------------
def get_zero_value(typ):
if is_uint(typ):
return 0
if issubclass(typ, bool):
return False
if issubclass(typ, list):
return []
if issubclass(typ, Vector):
return typ()
if issubclass(typ, BytesN):
return typ()
if issubclass(typ, bytes):
return b''
if issubclass(typ, SSZContainer):
return typ(**{f: get_zero_value(t) for f, t in typ.get_fields().items()}),
# SSZ Helpers
# -----------------------------
@@ -14,23 +36,10 @@ def chunkify(byte_string):
return [byte_string[i:i + 32] for i in range(0, len(byte_string), 32)]
BYTES_PER_LENGTH_OFFSET = 4
# SSZ Implementation
# SSZ Serialization
# -----------------------------
get_zero_value = ssz_type_switch({
ssz_bool: lambda: False,
ssz_uint: lambda: 0,
ssz_list: lambda byte_form: b'' if byte_form else [],
ssz_vector: lambda length, elem_typ, byte_form:
(b'\x00' * length if length > 0 else b'') if byte_form else
[get_zero_value(elem_typ) for _ in range(length)],
ssz_container: lambda typ, field_names, field_types:
typ(**{f_name: get_zero_value(f_typ) for f_name, f_typ in zip(field_names, field_types)}),
})
BYTES_PER_LENGTH_OFFSET = 4
serialize = ssz_switch({
ssz_bool: lambda value: b'\x01' if value else b'\x00',
@@ -40,13 +49,6 @@ serialize = ssz_switch({
ssz_container: lambda value, get_field_values, field_types: encode_series(get_field_values(value), field_types),
})
ssz_basic_type = (ssz_bool, ssz_uint)
is_basic_type = ssz_type_switch({
ssz_basic_type: lambda: True,
ssz_default: lambda: False,
})
is_fixed_size = ssz_type_switch({
ssz_basic_type: lambda: True,
ssz_vector: lambda elem_typ: is_fixed_size(elem_typ),
@@ -55,6 +57,27 @@ is_fixed_size = ssz_type_switch({
})
# SSZ Hash-tree-root
# -----------------------------
def serialize_basic(value, typ):
if is_uint(typ):
return value.to_bytes(typ.byte_len, 'little')
if issubclass(typ, bool):
if value:
return b'\x01'
else:
return b'\x00'
def pack(values, subtype):
return b''.join([serialize_basic(value, subtype) for value in values])
def is_basic_type(typ):
return is_uint(typ) or issubclass(typ, bool)
def hash_tree_root_list(value, elem_typ):
if is_basic_type(elem_typ):
return merkleize_chunks(chunkify(pack(value, elem_typ)))
@@ -77,10 +100,9 @@ hash_tree_root = ssz_switch({
ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)),
})
signing_root = ssz_switch({
ssz_container: lambda value, get_field_values, field_types: hash_tree_root_container(zip(get_field_values(value), field_types)[:-1]),
ssz_default: lambda value, typ: hash_tree_root(value, typ),
})
# todo: signing root
def signing_root(value, typ):
pass
def encode_series(values, types):
@@ -115,3 +137,21 @@ def encode_series(values, types):
# Return the concatenation of the fixed-size parts (offsets interleaved) with the variable-size parts
return b''.join(fixed_parts + variable_parts)
# Implementation notes:
# - SSZContainer,Vector/BytesN.hash_tree_root/serialize functions are for ease, implementation here
# - uint types have a 'byte_len' attribute
# - uint types are not classes. They use NewType(), for performance.
# This forces us to check type equivalence by exact reference.
# There's no class. The type data comes from an annotation/argument from the context of the value.
# - Vector is not valid to create instances with. Give it a elem-type and length: Vector[FooBar, 123]
# - *The class of* a Vector instance has a `elem_type` (type, may not be a class, see uint) and `length` (int)
# - BytesN is not valid to create instances with. Give it a length: BytesN[123]
# - *The class of* a BytesN instance has a `length` (int)
# Where possible, it is preferable to create helpers that just act on the type, and don't unnecessarily use a value
# E.g. is_basic_type(). This way, we can use them in type-only contexts and have no duplicate logic.
# For every class-instance, you can get the type with my_object.__class__
# For uints, and other NewType related, you have to rely on type information. It cannot be retrieved from the value.
# Note: we may just want to box integers instead. And then we can do bounds checking too. But it is SLOW and MEMORY INTENSIVE.
#

View File

@@ -1,105 +0,0 @@
from typing import Dict, Any
from .ssz_typing import *
# SSZ Switch statement runner factory
# -----------------------------
def ssz_switch(sw: Dict[Any, Any], arg_names=None):
"""
Creates an SSZ switch statement: a function, that when executed, checks every switch-statement
"""
if arg_names is None:
arg_names = ["value", "typ"]
# Runner, the function that executes the switch when called.
# Accepts a arguments based on the arg_names declared in the ssz_switch.
def run_switch(*args):
# value may be None
value = None
try:
value = args[arg_names.index("value")]
except ValueError:
pass # no value argument
# typ may be None when value is not None
typ = None
try:
typ = args[arg_names.index("typ")]
except ValueError:
# no typ argument expected
pass
except IndexError:
# typ argument expected, but not passed. Try to get it from the class info
typ = value.__class__
if hasattr(typ, '__forward_arg__'):
typ = typ.__forward_arg__
# Now, go over all switch cases
for matchers, worker in sw.items():
if not isinstance(matchers, tuple):
matchers = (matchers,)
# for each matcher of the case key
for m in matchers:
data = m(typ)
# if we have data, the matcher matched, and we can return the result
if data is not None:
# Supply value and type by default, and any data presented by the matcher.
kwargs = {"value": value, "typ": typ, **data}
# Filter out unwanted arguments
filtered_kwargs = {k: kwargs[k] for k in worker.__code__.co_varnames}
# run the switch case and return result
return worker(**filtered_kwargs)
raise Exception("cannot find matcher for type: %s (value: %s)" % (typ, value))
return run_switch
def ssz_type_switch(sw: Dict[Any, Any]):
return ssz_switch(sw, ["typ"])
# SSZ Switch matchers
# -----------------------------
def ssz_bool(typ):
if typ == bool:
return {}
def ssz_uint(typ):
# Note: only the type reference exists,
# but it really resolves to 'int' during run-time for zero computational/memory overhead.
# Hence, we check equality to the type references (which are really just 'NewType' instances),
# and don't use any sub-classing like we normally would.
if typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64\
or typ == uint128 or typ == uint256 or typ == byte:
return {"byte_len": typ.byte_len}
def ssz_list(typ):
if hasattr(typ, '__bases__') and List in typ.__bases__:
return {"elem_typ": read_list_elem_typ(typ), "byte_form": False}
if typ == bytes:
return {"elem_typ": uint8, "byte_form": True}
def ssz_vector(typ):
if hasattr(typ, '__bases__'):
if Vector in typ.__bases__:
return {"elem_typ": read_vec_elem_typ(typ), "length": read_vec_len(typ), "byte_form": False}
if BytesN in typ.__bases__:
return {"elem_typ": uint8, "length": read_bytesN_len(typ), "byte_form": True}
def ssz_container(typ):
if hasattr(typ, '__bases__') and SSZContainer in typ.__bases__:
def get_field_values(value):
return [getattr(value, field) for field in typ.__annotations__.keys()]
field_names = list(typ.__annotations__.keys())
field_types = list(typ.__annotations__.values())
return {"get_field_values": get_field_values, "field_names": field_names, "field_types": field_types}
def ssz_default(typ):
return {}

View File

@@ -1,72 +1,18 @@
from typing import Generic, List, TypeVar, Type, Iterable, NewType
from typing import List, Iterable, TypeVar, Type, NewType
from typing import Union
from inspect import isclass
# SSZ base length, to limit length generic type param of vector/bytesN
SSZLenAny = type('SSZLenAny', (), {})
def SSZLen(length: int):
"""
SSZ length factory. Creates a type corresponding to a given length. To be used as parameter in type generics.
"""
assert length >= 0
typ = type('SSZLen_%d' % length, (SSZLenAny,), {})
typ.length = length
return typ
# SSZ element type
T = TypeVar('T')
# SSZ vector/bytesN length
L = TypeVar('L', bound=SSZLenAny)
# SSZ vector
# -----------------------------
class Vector(Generic[T, L]):
def __init__(self, *args: Iterable[T]):
self.items = list(args)
def __getitem__(self, key):
return self.items[key]
def __setitem__(self, key, value):
self.items[key] = value
def __iter__(self):
return iter(self.items)
def __len__(self):
return len(self.items)
def read_vec_elem_typ(vec_typ: Type[Vector[T,L]]) -> T:
assert vec_typ.__args__ is not None
return vec_typ.__args__[0]
def read_vec_len(vec_typ: Type[Vector[T,L]]) -> int:
assert vec_typ.__args__ is not None
return vec_typ.__args__[1].length
# SSZ list
# -----------------------------
def read_list_elem_typ(list_typ: Type[List[T]]) -> T:
assert list_typ.__args__ is not None
return list_typ.__args__[0]
# SSZ bytesN
# -----------------------------
class BytesN(Generic[L]):
pass
def read_bytesN_len(bytesN_typ: Type[BytesN[L]]) -> int:
assert bytesN_typ.__args__ is not None
return bytesN_typ.__args__[0].length
# SSZ integer types, with 0 computational overhead (NewType)
# -----------------------------
@@ -94,8 +40,9 @@ byte = NewType('byte', uint8)
class SSZContainer(object):
def __init__(self, **kwargs):
cls = self.__class__
from .ssz_impl import get_zero_value
for f, t in self.__annotations__.items():
for f, t in cls.get_fields().items():
if f not in kwargs:
setattr(self, f, get_zero_value(t))
else:
@@ -113,3 +60,220 @@ class SSZContainer(object):
from .ssz_impl import signing_root
return signing_root(self, self.__class__)
def get_field_values(self):
cls = self.__class__
return [getattr(self, field) for field in cls.get_field_names()]
@classmethod
def get_fields(cls):
return dict(cls.__annotations__)
@classmethod
def get_field_names(cls):
return list(cls.__annotations__.keys())
@classmethod
def get_field_types(cls):
# values of annotations are the types corresponding to the fields, not instance values.
return list(cls.__annotations__.values())
def is_uint(typ):
# Note: only the type reference exists,
# but it really resolves to 'int' during run-time for zero computational/memory overhead.
# Hence, we check equality to the type references (which are really just 'NewType' instances),
# and don't use any sub-classing like we normally would.
return typ == uint8 or typ == uint16 or typ == uint32 or typ == uint64 \
or typ == uint128 or typ == uint256 or typ == byte
# SSZ vector
# -----------------------------
def _is_vector_instance_of(a, b):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector (b) is not an instance of Vector[X, Y] (a)
return False
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
# Vector[X, Y] (b) is an instance of Vector (a)
return True
# Vector[X, Y] (a) is an instance of Vector[X, Y] (b)
return a.elem_type == b.elem_type and a.length == b.length
def _is_equal_vector_type(a, b):
if not hasattr(a, 'elem_type') or not hasattr(a, 'length'):
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector == Vector
return True
# Vector != Vector[X, Y]
return False
if not hasattr(b, 'elem_type') or not hasattr(b, 'length'):
# Vector[X, Y] != Vector
return False
# Vector[X, Y] == Vector[X, Y]
return a.elem_type == b.elem_type and a.length == b.length
class VectorMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'elem_type' in attrs and 'length' in attrs:
setattr(out, 'elem_type', attrs['elem_type'])
setattr(out, 'length', attrs['length'])
return out
def __getitem__(self, params):
return self.__class__(self.__name__, (Vector,), {'elem_type': params[0], 'length': params[1]})
def __subclasscheck__(self, sub):
return _is_vector_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_vector_instance_of(self, other.__class__)
def __eq__(self, other):
return _is_equal_vector_type(self, other)
def __ne__(self, other):
return not _is_equal_vector_type(self, other)
class Vector(metaclass=VectorMeta):
def __init__(self, *args: Iterable[T]):
cls = self.__class__
if not hasattr(cls, 'elem_type'):
raise TypeError("Type Vector without elem_type data cannot be instantiated")
if not hasattr(cls, 'length'):
raise TypeError("Type Vector without length data cannot be instantiated")
if len(args) != cls.length:
if len(args) == 0:
from .ssz_impl import get_zero_value
args = [get_zero_value(cls.elem_type) for _ in range(cls.length)]
else:
raise TypeError("Typed vector with length %d cannot hold %d items" % (cls.length, len(args)))
self.items = list(args)
# cannot check non-class objects
if isclass(cls.elem_type):
for i, item in enumerate(self.items):
if not isinstance(item, cls.elem_type):
raise TypeError("Typed vector cannot hold differently typed value"
" at index %d. Got type: %s, expected type: %s" % (i, type(item), cls.elem_type))
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)
def __getitem__(self, key):
return self.items[key]
def __setitem__(self, key, value):
self.items[key] = value
def __iter__(self):
return iter(self.items)
def __len__(self):
return len(self.items)
def _is_bytes_n_instance_of(a, b):
if not hasattr(b, 'length'):
# BytesN (b) is not an instance of BytesN[X] (a)
return False
if not hasattr(a, 'length'):
# BytesN[X] (b) is an instance of BytesN (a)
return True
# BytesN[X] (a) is an instance of BytesN[X] (b)
return a.length == b.length
def _is_equal_bytes_n_type(a, b):
if not hasattr(a, 'length'):
if not hasattr(b, 'length'):
# BytesN == BytesN
return True
# BytesN != BytesN[X]
return False
if not hasattr(b, 'length'):
# BytesN[X] != BytesN
return False
# BytesN[X] == BytesN[X]
return a.length == b.length
class BytesNMeta(type):
def __new__(cls, class_name, parents, attrs):
out = type.__new__(cls, class_name, parents, attrs)
if 'length' in attrs:
setattr(out, 'length', attrs['length'])
return out
def __getitem__(self, n):
return self.__class__(self.__name__, (BytesN,), {'length': n})
def __subclasscheck__(self, sub):
return _is_bytes_n_instance_of(self, sub)
def __instancecheck__(self, other):
return _is_bytes_n_instance_of(self, other.__class__)
def __eq__(self, other):
return _is_equal_bytes_n_type(self, other)
def __ne__(self, other):
return not _is_equal_bytes_n_type(self, other)
def parse_bytes(val):
if val is None:
return None
if isinstance(val, str):
# TODO: import from eth-utils instead, and do: hexstr_if_str(to_bytes, val)
return None
if isinstance(val, bytes):
return val
if isinstance(val, int):
return bytes([val])
return None
class BytesN(bytes, metaclass=BytesNMeta):
def __new__(cls, *args):
if not hasattr(cls, 'length'):
return
bytesval = None
if len(args) == 1:
val: Union[bytes, int, str] = args[0]
bytesval = parse_bytes(val)
elif len(args) > 1:
# TODO: each int is 1 byte, check size, create bytesval
bytesval = bytes(args)
if bytesval is None:
if cls.length == 0:
bytesval = b''
else:
bytesval = b'\x00' * cls.length
if len(bytesval) != cls.length:
raise TypeError("bytesN[%d] cannot be initialized with value of %d bytes" % (cls.length, len(bytesval)))
return super().__new__(cls, bytesval)
def serialize(self):
from .ssz_impl import serialize
return serialize(self, self.__class__)
def hash_tree_root(self):
from .ssz_impl import hash_tree_root
return hash_tree_root(self, self.__class__)