Modifications from Vitalik, to enable SSZ Partials

This commit is contained in:
protolambda
2019-06-20 20:20:07 +02:00
parent 6f46c1d837
commit 6b82e3faa5
2 changed files with 24 additions and 32 deletions

View File

@@ -107,7 +107,7 @@ def is_bottom_layer_kind(typ: SSZType):
def item_length(typ: SSZType) -> int:
if issubclass(typ, BasicType):
if issubclass(typ, BasicValue):
return typ.byte_len
else:
return 32

View File

@@ -31,6 +31,7 @@ class BasicValue(int, SSZValue, metaclass=BasicType):
class Bit(BasicValue): # can't subclass bool.
byte_len = 1
def __new__(cls, value, *args, **kwargs):
if value < 0 or value > 1:
@@ -88,7 +89,7 @@ class uint256(uint):
byte_len = 32
def coerce_type_maybe(v, typ: SSZType):
def coerce_type_maybe(v, typ: SSZType, strict: bool = False):
v_typ = type(v)
# shortcut if it's already the type we are looking for
if v_typ == typ:
@@ -97,10 +98,14 @@ def coerce_type_maybe(v, typ: SSZType):
return typ(v)
elif isinstance(v, (list, tuple)):
return typ(*v)
elif isinstance(v, bytes):
return typ(v)
elif isinstance(v, GeneratorType):
return typ(v)
else:
# just return as-is, Value-checkers will take care of it not being coerced.
if strict and not isinstance(v, typ):
raise ValueError("Type coercion of {} to {} failed".format(v, typ))
return v
@@ -116,7 +121,7 @@ class Container(Series, metaclass=SSZType):
def __init__(self, **kwargs):
cls = self.__class__
for f, t in cls.get_fields():
for f, t in cls.get_fields().items():
if f not in kwargs:
setattr(self, f, t.default())
else:
@@ -148,16 +153,12 @@ class Container(Series, metaclass=SSZType):
f" field: {name} type: {field_typ} value: {value} value type: {type(value)}")
super().__setattr__(name, value)
def get_field_values(self) -> Tuple[SSZValue, ...]:
cls = self.__class__
return tuple(getattr(self, field) for field in cls.get_field_names())
def __repr__(self):
return repr({field: getattr(self, field) for field in self.get_field_names()})
return repr({field: getattr(self, field) for field in self.get_fields()})
def __str__(self):
output = [f'{self.__class__.__name__}']
for field in self.get_field_names():
for field in self.get_fields():
output.append(f' {field}: {getattr(self, field)}')
return "\n".join(output)
@@ -168,23 +169,10 @@ class Container(Series, metaclass=SSZType):
return hash(self.hash_tree_root())
@classmethod
def get_fields(cls) -> Tuple[Tuple[str, SSZType], ...]:
def get_fields(cls) -> Dict[str, SSZType]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return tuple((f, t) for f, t in cls.__annotations__.items())
@classmethod
def get_field_names(cls) -> Tuple[str, ...]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
return tuple(cls.__annotations__.keys())
@classmethod
def get_field_types(cls) -> Tuple[SSZType, ...]:
if not hasattr(cls, '__annotations__'): # no container fields
return ()
# values of annotations are the types corresponding to the fields, not instance values.
return tuple(cls.__annotations__.values())
return {}
return dict(cls.__annotations__)
@classmethod
def default(cls):
@@ -195,7 +183,7 @@ class Container(Series, metaclass=SSZType):
return all(t.is_fixed_size() for t in cls.get_field_types())
def __iter__(self) -> Iterator[SSZValue]:
return iter(self.get_field_values())
return iter([getattr(self, field) for field in self.get_fields()])
class ParamsBase(Series):
@@ -297,12 +285,16 @@ class Elements(ParamsBase, metaclass=ElementsType):
if k > len(self.items):
raise IndexError(f"cannot set item in type {self.__class__}"
f" at out of bounds index {k} (to {v}, bound: {len(self.items)})")
typ = self.__class__.elem_type
v = coerce_type_maybe(v, typ)
if not isinstance(v, typ):
raise ValueError(f"Cannot set item in type {self.__class__},"
f" mismatched element type: {v} of {type(v)}, expected {typ}")
self.items[k] = v
self.items[k] = coerce_type_maybe(v, self.__class__.elem_type, strict=True)
def append(self, v):
self.items.append(coerce_type_maybe(v, self.__class__.elem_type, strict=True))
def pop(self):
if len(self.items) == 0:
raise IndexError("Pop from empty list")
else:
return self.items.pop()
def __len__(self):
return len(self.items)