diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py index 21dba1034..0ef89f97f 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_impl.py @@ -14,11 +14,22 @@ def is_basic_type(typ): def serialize_basic(value, typ): if is_uint_type(typ): return value.to_bytes(uint_byte_size(typ), 'little') - if is_bool_type(typ): + elif is_bool_type(typ): if value: return b'\x01' else: return b'\x00' + else: + raise Exception("Type not supported: {}".format(typ)) + +def deserialize_basic(value, typ): + if is_uint_type(typ): + return typ(int.from_bytes(value, 'little')) + elif is_bool_type(typ): + assert value in (b'\x00', b'\x01') + return True if value == b'\x01' else False + else: + raise Exception("Type not supported: {}".format(typ)) def is_fixed_size(typ): @@ -112,7 +123,7 @@ def get_typed_values(obj, typ=None): return obj.get_typed_values() elif is_list_kind(typ) or is_vector_kind(typ): elem_type = read_elem_type(typ) - return zip(obj, [elem_type] * len(obj)) + return list(zip(obj, [elem_type] * len(obj))) else: raise Exception("Invalid type") diff --git a/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py b/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py index b30ebfb4c..6a243f4f3 100644 --- a/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py +++ b/test_libs/pyspec/eth2spec/utils/ssz/ssz_partials.py @@ -36,7 +36,7 @@ def filler(starting_position, chunk_count): while at % (skip*2) == 0: skip *= 2 value = hash(value + value) - o[starting_position + at] = value + o[(starting_position + at) // skip] = value at += skip return o @@ -47,19 +47,211 @@ def merkle_tree_of_chunks(chunks, root): return o @infer_input_type -def ssz_all(obj, typ=None, root=1): - if is_list_type(typ): +def ssz_leaves(obj, typ=None, root=1): + if is_list_kind(typ): o = {root * 2 + 1: len(obj).to_bytes(32, 'little')} base = root * 2 else: o = {} base = root - if is_bottom_layer_type(typ): - data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_typ(typ)) - return {**o, **merkle_tree_of_chunks(chunkify(data), base)} + if is_bottom_layer_kind(typ): + data = serialize_basic(obj, typ) if is_basic_type(typ) else pack(obj, read_elem_type(typ)) + q = {**o, **merkle_tree_of_chunks(chunkify(data), base)} + #print(obj, root, typ, base, list(q.keys())) + return(q) else: fields = get_typed_values(obj, typ=typ) sub_base = base * next_power_of_two(len(fields)) for i, (elem, elem_type) in enumerate(fields): - o = {**o, **ssz_all(elem, typ=elem_type, root=sub_base+i)} - return {**o, **filter(sub_base, len(fields))} + o = {**o, **ssz_leaves(elem, typ=elem_type, root=sub_base+i)} + q = {**o, **filler(sub_base, len(fields))} + #print(obj, root, typ, base, list(q.keys())) + return(q) + +def fill(objects): + objects = {k:v for k,v in objects.items()} + keys = sorted(objects.keys())[::-1] + pos = 0 + while pos < len(keys): + k = keys[pos] + if k in objects and k^1 in objects and k//2 not in objects: + objects[k//2] = hash(objects[k&-2] + objects[k|1]) + keys.append(k // 2) + pos += 1 + return objects + +@infer_input_type +def ssz_full(obj, typ=None): + return fill(ssz_leaves(obj, typ=typ)) + +def get_basic_type_size(typ): + if is_uint_type(typ): + return uint_byte_size(typ) + elif is_bool_type(typ): + return 1 + else: + raise Exception("Type not basic: {}".format(typ)) + +def get_bottom_layer_element_position(typ, base, length, index): + """ + Returns the generalized index and the byte range of the index'th value + in the list with the given base generalized index and given length + """ + assert index < (1 if is_basic_type(typ) else length) + elem_typ = typ if is_basic_type(typ) else read_elem_type(typ) + elem_size = get_basic_type_size(elem_typ) + chunk_index = index * elem_size // 32 + chunk_count = (1 if is_basic_type(typ) else length) * elem_size // 32 + generalized_index = base * next_power_of_two(chunk_count) + chunk_index + start = elem_size * index % 32 + return generalized_index, start, start+elem_size + +@infer_input_type +def get_generalized_indices(obj, path, typ=None, root=1): + if len(path) == 0: + return [root] if is_basic_type(typ) else list(ssz_leaves(obj, typ=typ, root=root).keys()) + if path[0] == '__len__': + return [root * 2 + 1] if is_list_type(typ) else [] + base = root * 2 if is_list_kind(typ) else root + if is_bottom_layer_kind(typ): + length = 1 if is_basic_type(typ) else len(obj) + index, _, _ = get_bottom_layer_element_position(typ, base, length, path[0]) + return [index] + else: + if is_container_type(typ): + fields = typ.get_field_names() + field_count, index = len(fields), fields.index(path[0]) + elem_type = typ.get_field_types()[index] + child = obj.get_field_values()[index] + else: + field_count, index, elem_type, child = len(obj), path[0], read_elem_type(typ), obj[path[0]] + return get_generalized_indices( + child, + path[1:], + typ=elem_type, + root=base * next_power_of_two(field_count) + index + ) + +def get_branch_indices(tree_index): + o = [tree_index, tree_index ^ 1] + while o[-1] > 1: + o.append((o[-1] // 2) ^ 1) + return o[:-1] + +def remove_redundant_indices(obj): + return {k:v for k,v in obj.items() if not (k*2 in obj and k*2+1 in obj)} + +def merge(*args): + o = {} + for arg in args: + o = {**o, **arg} + return fill(o) + +@infer_input_type +def get_nodes_along_path(obj, path, typ=None): + indices = get_generalized_indices(obj, path, typ=typ) + return remove_redundant_indices(merge(*({i:obj.objects[i] for i in get_branch_indices(index)} for index in indices))) + +class OutOfRangeException(Exception): + pass + +class SSZPartial(): + def __init__(self, typ, objects, root=1): + assert not is_basic_type(typ) + self.objects = objects + self.typ = typ + self.root = root + if is_container_type(self.typ): + for field in self.typ.get_field_names(): + try: + setattr(self, field, self.getter(field)) + except OutOfRangeException: + pass + + def getter(self, index): + base = self.root * 2 if is_list_kind(self.typ) else self.root + if is_bottom_layer_kind(self.typ): + tree_index, start, end = get_bottom_layer_element_position( + self.typ, base, len(self), index + ) + if tree_index not in self.objects: + raise OutOfRangeException("Do not have required data") + else: + return deserialize_basic( + self.objects[tree_index][start:end], + self.typ if is_basic_type(self.typ) else read_elem_type(self.typ) + ) + else: + if is_container_type(self.typ): + fields = self.typ.get_field_names() + field_count, index = len(fields), fields.index(index) + elem_type = self.typ.get_field_types()[index] + else: + field_count, index, elem_type = len(self), index, read_elem_type(self.typ) + tree_index = base * next_power_of_two(field_count) + index + if tree_index not in self.objects: + raise OutOfRangeException("Do not have required data") + if is_basic_type(elem_type): + return deserialize_basic(self.objects[tree_index][:get_basic_type_size(elem_type)], elem_type) + else: + return ssz_partial(elem_type, self.objects, root=tree_index) + + def __getitem__(self, index): + return self.getter(index) + + def __iter__(self): + return (self[i] for i in range(len(self))) + + def __len__(self): + if is_list_kind(self.typ): + if self.root*2+1 not in self.objects: + raise OutOfRangeException("Do not have required data: {}".format(self.root*2+1)) + return int.from_bytes(self.objects[self.root*2+1], 'little') + elif is_vector_kind(self.typ): + return self.typ.length + elif is_container_type(self.typ): + return len(self.typ.get_fields()) + else: + raise Exception("Unsupported type: {}".format(self.typ)) + + def full_value(self): + if is_bytes_type(self.typ) or is_bytesn_type(self.typ): + return bytes([self.getter(i) for i in range(len(self))]) + elif is_list_kind(self.typ): + return [self[i] for i in range(len(self))] + elif is_vector_kind(self.typ): + return self.typ(*(self[i] for i in range(len(self)))) + elif is_container_type(self.typ): + full_value = lambda x: x.full_value() if hasattr(x, 'full_value') else x + return self.typ(**{field: full_value(self.getter(field)) for field in self.typ.get_field_names()}) + elif is_basic_type(self.typ): + return self.getter(0) + else: + raise Exception("Unsupported type: {}".format(self.typ)) + + def hash_tree_root(self): + o = {**self.objects} + keys = sorted(o.keys())[::-1] + pos = 0 + while pos < len(keys): + k = keys[pos] + if k in o and k^1 in o and k//2 not in o: + o[k//2] = hash(o[k&-2] + o[k|1]) + keys.append(k // 2) + pos += 1 + return o[self.root] + + def __str__(self): + return str(self.full_value()) + +def ssz_partial(typ, objects, root=1): + ssz_type = ( + Container if is_container_type(typ) else + typ if (is_vector_type(typ) or is_bytesn_type(typ)) else object + ) + class Partial(SSZPartial, ssz_type): + pass + if is_container_type(typ): + Partial.__annotations__ = typ.__annotations__ + o = Partial(typ, objects, root=root) + return o