From 6b82e3faa56707b0c1fd11cd82e2f8eb36037370 Mon Sep 17 00:00:00 2001 From: protolambda Date: Thu, 20 Jun 2019 20:20:07 +0200 Subject: [PATCH] Modifications from Vitalik, to enable SSZ Partials --- .../pyspec/eth2spec/utils/ssz/ssz_impl.py | 2 +- .../pyspec/eth2spec/utils/ssz/ssz_typing.py | 54 ++++++++----------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index fd17e29f9..a9c36649b 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -107,7 +107,7 @@ def is_bottom_layer_kind(typ: SSZType): def item_length(typ: SSZType) -> int: - if issubclass(typ, BasicType): + if issubclass(typ, BasicValue): return typ.byte_len else: return 32 diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py index 082e3ed30..981f30d9b 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_typing.py @@ -31,6 +31,7 @@ class BasicValue(int, SSZValue, metaclass=BasicType): class Bit(BasicValue): # can't subclass bool. + byte_len = 1 def __new__(cls, value, *args, **kwargs): if value < 0 or value > 1: @@ -88,7 +89,7 @@ class uint256(uint): byte_len = 32 -def coerce_type_maybe(v, typ: SSZType): +def coerce_type_maybe(v, typ: SSZType, strict: bool = False): v_typ = type(v) # shortcut if it's already the type we are looking for if v_typ == typ: @@ -97,10 +98,14 @@ def coerce_type_maybe(v, typ: SSZType): return typ(v) elif isinstance(v, (list, tuple)): return typ(*v) + elif isinstance(v, bytes): + return typ(v) elif isinstance(v, GeneratorType): return typ(v) else: # just return as-is, Value-checkers will take care of it not being coerced. + if strict and not isinstance(v, typ): + raise ValueError("Type coercion of {} to {} failed".format(v, typ)) return v @@ -116,7 +121,7 @@ class Container(Series, metaclass=SSZType): def __init__(self, **kwargs): cls = self.__class__ - for f, t in cls.get_fields(): + for f, t in cls.get_fields().items(): if f not in kwargs: setattr(self, f, t.default()) else: @@ -148,16 +153,12 @@ class Container(Series, metaclass=SSZType): f" field: {name} type: {field_typ} value: {value} value type: {type(value)}") super().__setattr__(name, value) - def get_field_values(self) -> Tuple[SSZValue, ...]: - cls = self.__class__ - return tuple(getattr(self, field) for field in cls.get_field_names()) - def __repr__(self): - return repr({field: getattr(self, field) for field in self.get_field_names()}) + return repr({field: getattr(self, field) for field in self.get_fields()}) def __str__(self): output = [f'{self.__class__.__name__}'] - for field in self.get_field_names(): + for field in self.get_fields(): output.append(f' {field}: {getattr(self, field)}') return "\n".join(output) @@ -168,23 +169,10 @@ class Container(Series, metaclass=SSZType): return hash(self.hash_tree_root()) @classmethod - def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]: + def get_fields(cls) -> Dict[str, SSZType]: if not hasattr(cls, '__annotations__'): # no container fields - return () - return tuple((f, t) for f, t in cls.__annotations__.items()) - - @classmethod - def get_field_names(cls) -> Tuple[str, ...]: - if not hasattr(cls, '__annotations__'): # no container fields - return () - return tuple(cls.__annotations__.keys()) - - @classmethod - def get_field_types(cls) -> Tuple[SSZType, ...]: - if not hasattr(cls, '__annotations__'): # no container fields - return () - # values of annotations are the types corresponding to the fields, not instance values. - return tuple(cls.__annotations__.values()) + return {} + return dict(cls.__annotations__) @classmethod def default(cls): @@ -195,7 +183,7 @@ class Container(Series, metaclass=SSZType): return all(t.is_fixed_size() for t in cls.get_field_types()) def __iter__(self) -> Iterator[SSZValue]: - return iter(self.get_field_values()) + return iter([getattr(self, field) for field in self.get_fields()]) class ParamsBase(Series): @@ -297,12 +285,16 @@ class Elements(ParamsBase, metaclass=ElementsType): if k > len(self.items): raise IndexError(f"cannot set item in type {self.__class__}" f" at out of bounds index {k} (to {v}, bound: {len(self.items)})") - typ = self.__class__.elem_type - v = coerce_type_maybe(v, typ) - if not isinstance(v, typ): - raise ValueError(f"Cannot set item in type {self.__class__}," - f" mismatched element type: {v} of {type(v)}, expected {typ}") - self.items[k] = v + self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True) + + def append(self, v): + self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True)) + + def pop(self): + if len(self.items) == 0: + raise IndexError("Pop from empty list") + else: + return self.items.pop() def __len__(self): return len(self.items)