Protocols pyspec support + execution payload tests cleanup

This commit is contained in:
protolambda
2021-05-12 02:40:23 +02:00
parent f58ba8f5b2
commit 0390ab819a
5 changed files with 91 additions and 17 deletions

View File

@@ -5,6 +5,7 @@ from distutils.util import convert_path
import os
import re
import string
import textwrap
from typing import Dict, NamedTuple, List, Sequence, Optional
from abc import ABC, abstractmethod
import ast
@@ -48,8 +49,14 @@ def floorlog2(x: int) -> uint64:
'''
class ProtocolDefinition(NamedTuple):
# just function definitions currently. May expand with configuration vars in future.
functions: Dict[str, str]
class SpecObject(NamedTuple):
functions: Dict[str, str]
protocols: Dict[str, ProtocolDefinition]
custom_types: Dict[str, str]
constants: Dict[str, str]
ssz_dep_constants: Dict[str, str] # the constants that depend on ssz_objects
@@ -73,6 +80,18 @@ def _get_function_name_from_source(source: str) -> str:
return fn.name
def _get_self_type_from_source(source: str) -> Optional[str]:
fn = ast.parse(source).body[0]
args = fn.args.args
if len(args) == 0:
return None
if args[0].arg != 'self':
return None
if args[0].annotation is None:
return None
return args[0].annotation.id
def _get_class_info_from_source(source: str) -> (str, Optional[str]):
class_def = ast.parse(source).body[0]
base = class_def.bases[0]
@@ -107,6 +126,7 @@ def _get_eth2_spec_comment(child: LinkRefDef) -> Optional[str]:
def get_spec(file_name: str) -> SpecObject:
functions: Dict[str, str] = {}
protocols: Dict[str, ProtocolDefinition] = {}
constants: Dict[str, str] = {}
ssz_dep_constants: Dict[str, str] = {}
ssz_objects: Dict[str, str] = {}
@@ -132,7 +152,14 @@ def get_spec(file_name: str) -> SpecObject:
source = _get_source_from_code_block(child)
if source.startswith("def"):
current_name = _get_function_name_from_source(source)
functions[current_name] = "\n".join(line.rstrip() for line in source.splitlines())
self_type_name = _get_self_type_from_source(source)
function_def = "\n".join(line.rstrip() for line in source.splitlines())
if self_type_name is None:
functions[current_name] = function_def
else:
if self_type_name not in protocols:
protocols[self_type_name] = ProtocolDefinition(functions={})
protocols[self_type_name].functions[current_name] = function_def
elif source.startswith("@dataclass"):
dataclasses[current_name] = "\n".join(line.rstrip() for line in source.splitlines())
elif source.startswith("class"):
@@ -170,6 +197,7 @@ def get_spec(file_name: str) -> SpecObject:
return SpecObject(
functions=functions,
protocols=protocols,
custom_types=custom_types,
constants=constants,
ssz_dep_constants=ssz_dep_constants,
@@ -422,7 +450,8 @@ class MergeSpecBuilder(Phase0SpecBuilder):
@classmethod
def imports(cls):
return super().imports() + '\n' + '''
return super().imports() + '''
from typing import Protocol
from eth2spec.phase0 import spec as phase0
from eth2spec.utils.ssz.ssz_typing import Bytes20, ByteList, ByteVector, uint256
from importlib import reload
@@ -451,13 +480,23 @@ def get_execution_state(execution_state_root: Bytes32) -> ExecutionState:
def get_pow_chain_head() -> PowBlock:
pass
verify_execution_state_transition_ret_value = True
def verify_execution_state_transition(execution_payload: ExecutionPayload) -> bool:
return verify_execution_state_transition_ret_value
class NoopExecutionEngine(ExecutionEngine):
def new_block(self, execution_payload: ExecutionPayload) -> bool:
return True
def set_head(self, block_hash: Hash32) -> bool:
return True
def finalize_block(self, block_hash: Hash32) -> bool:
return True
def assemble_block(self, block_hash: Hash32, timestamp: uint64) -> ExecutionPayload:
raise NotImplementedError("no default block production")
def produce_execution_payload(parent_hash: Hash32, timestamp: uint64) -> ExecutionPayload:
pass"""
EXECUTION_ENGINE = NoopExecutionEngine()"""
@classmethod
@@ -495,6 +534,15 @@ def objects_to_spec(spec_object: SpecObject, builder: SpecBuilder, ordered_class
]
)
)
def format_protocol(protocol_name: str, protocol_def: ProtocolDefinition) -> str:
protocol = f"class {protocol_name}(Protocol):"
for fn_source in protocol_def.functions.values():
fn_source = fn_source.replace("self: "+protocol_name, "self")
protocol += "\n\n" + textwrap.indent(fn_source, " ")
return protocol
protocols_spec = '\n\n\n'.join(format_protocol(k, v) for k, v in spec_object.protocols.items())
for k in list(spec_object.functions):
if "ceillog2" in k or "floorlog2" in k:
del spec_object.functions[k]
@@ -520,6 +568,7 @@ def objects_to_spec(spec_object: SpecObject, builder: SpecBuilder, ordered_class
+ '\n\n' + constants_spec
+ '\n\n' + CONFIG_LOADER
+ '\n\n' + ordered_class_objects_spec
+ ('\n\n\n' + protocols_spec if protocols_spec != '' else '')
+ '\n\n\n' + functions_spec
+ '\n\n' + builder.sundry_functions()
# Since some constants are hardcoded in setup.py, the following assertions verify that the hardcoded constants are
@@ -531,6 +580,17 @@ def objects_to_spec(spec_object: SpecObject, builder: SpecBuilder, ordered_class
return spec
def combine_protocols(old_protocols: Dict[str, ProtocolDefinition],
new_protocols: Dict[str, ProtocolDefinition]) -> Dict[str, ProtocolDefinition]:
for key, value in new_protocols.items():
if key not in old_protocols:
old_protocols[key] = value
else:
functions = combine_functions(old_protocols[key].functions, value.functions)
old_protocols[key] = ProtocolDefinition(functions=functions)
return old_protocols
def combine_functions(old_functions: Dict[str, str], new_functions: Dict[str, str]) -> Dict[str, str]:
for key, value in new_functions.items():
old_functions[key] = value
@@ -589,8 +649,9 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
"""
Takes in two spec variants (as tuples of their objects) and combines them using the appropriate combiner function.
"""
functions0, custom_types0, constants0, ssz_dep_constants0, ssz_objects0, dataclasses0 = spec0
functions1, custom_types1, constants1, ssz_dep_constants1, ssz_objects1, dataclasses1 = spec1
functions0, protocols0, custom_types0, constants0, ssz_dep_constants0, ssz_objects0, dataclasses0 = spec0
functions1, protocols1, custom_types1, constants1, ssz_dep_constants1, ssz_objects1, dataclasses1 = spec1
protocols = combine_protocols(protocols0, protocols1)
functions = combine_functions(functions0, functions1)
custom_types = combine_constants(custom_types0, custom_types1)
constants = combine_constants(constants0, constants1)
@@ -599,6 +660,7 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
dataclasses = combine_functions(dataclasses0, dataclasses1)
return SpecObject(
functions=functions,
protocols=protocols,
custom_types=custom_types,
constants=constants,
ssz_dep_constants=ssz_dep_constants,

View File

@@ -216,7 +216,9 @@ def process_block(state: BeaconState, block: BeaconBlock) -> None:
##### `process_execution_payload`
```python
def process_execution_payload(state: BeaconState, execution_payload: ExecutionPayload, execution_engine: ExecutionEngine) -> None:
def process_execution_payload(state: BeaconState,
execution_payload: ExecutionPayload,
execution_engine: ExecutionEngine) -> None:
"""
Note: This function is designed to be able to be run in parallel with the other `process_block` sub-functions
"""

View File

@@ -90,5 +90,5 @@ def get_execution_payload(state: BeaconState, execution_engine: ExecutionEngine)
# Post-merge, normal payload
execution_parent_hash = state.latest_execution_payload_header.block_hash
timestamp = compute_time_at_slot(state, state.slot)
return produce_execution_payload(execution_parent_hash, timestamp)
return execution_engine.assemble_block(execution_parent_hash, timestamp)
```

View File

@@ -24,6 +24,7 @@ def build_empty_execution_payload(spec, state):
return payload
def get_execution_payload_header(spec, execution_payload):
return spec.ExecutionPayloadHeader(
block_hash=execution_payload.block_hash,
@@ -39,15 +40,18 @@ def get_execution_payload_header(spec, execution_payload):
transactions_root=spec.hash_tree_root(execution_payload.transactions)
)
def build_state_with_incomplete_transition(spec, state):
return build_state_with_execution_payload_header(spec, state, spec.ExecutionPayloadHeader())
def build_state_with_complete_transition(spec, state):
pre_state_payload = build_empty_execution_payload(spec, state)
payload_header = get_execution_payload_header(spec, pre_state_payload)
return build_state_with_execution_payload_header(spec, state, payload_header)
def build_state_with_execution_payload_header(spec, state, execution_payload_header):
pre_state = state.copy()
pre_state.latest_execution_payload_header = execution_payload_header

View File

@@ -22,23 +22,29 @@ def run_execution_payload_processing(spec, state, execution_payload, valid=True,
yield 'execution', {'execution_valid': execution_valid}
yield 'execution_payload', execution_payload
called_new_block = False
spec.verify_execution_state_transition_ret_value = execution_valid
class TestEngine(spec.NoopExecutionEngine):
def new_block(self, payload) -> bool:
nonlocal called_new_block, execution_valid
called_new_block = True
assert payload == execution_payload
return execution_valid
if not valid:
expect_assertion_error(lambda: spec.process_execution_payload(state, execution_payload))
expect_assertion_error(lambda: spec.process_execution_payload(state, execution_payload, TestEngine()))
yield 'post', None
spec.verify_execution_state_transition_ret_value = True
return
spec.process_execution_payload(state, execution_payload)
spec.process_execution_payload(state, execution_payload, TestEngine())
# Make sure we called the engine
assert called_new_block
yield 'post', state
assert state.latest_execution_payload_header == get_execution_payload_header(spec, execution_payload)
spec.verify_execution_state_transition_ret_value = True
@with_merge_and_later
@spec_state_test