diff --git a/hdk/common/__init__.py b/hdk/common/__init__.py index ace1dd1c6..cd563aaf9 100644 --- a/hdk/common/__init__.py +++ b/hdk/common/__init__.py @@ -1,2 +1,2 @@ """HDK's module for shared data structures and code""" -from . import data_types +from . import data_types, representation diff --git a/hdk/common/representation/__init__.py b/hdk/common/representation/__init__.py new file mode 100644 index 000000000..7bdfcc4df --- /dev/null +++ b/hdk/common/representation/__init__.py @@ -0,0 +1,2 @@ +"""HDK's representation module to represent source programs""" +from . import intermediate diff --git a/hdk/common/representation/intermediate.py b/hdk/common/representation/intermediate.py new file mode 100644 index 000000000..7b1cb5433 --- /dev/null +++ b/hdk/common/representation/intermediate.py @@ -0,0 +1,53 @@ +"""File containing HDK's intermdiate representation of source programs operations""" + +from abc import ABC +from copy import deepcopy +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from ..data_types import BaseValue + + +class IntermediateNode(ABC): + """Abstract Base Class to derive from to represent source program operations""" + + inputs: List[BaseValue] + outputs: List[BaseValue] + op_args: Optional[Tuple[Any, ...]] + op_kwargs: Optional[Dict[str, Any]] + + def __init__( + self, + inputs: Iterable[BaseValue], + op_args: Optional[Tuple[Any, ...]] = None, + op_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.inputs = list(inputs) + self.op_args = op_args + self.op_kwargs = op_kwargs + + +class Add(IntermediateNode): + """Addition between two values""" + + def __init__( + self, + inputs: Iterable[BaseValue], + ) -> None: + super().__init__(inputs) + assert len(self.inputs) == 2 + + # For now copy the first input type for the output type + # We don't perform checks or enforce consistency here for now, so this is OK + self.outputs = [deepcopy(self.inputs[0])] + + +class Input(IntermediateNode): + """Node representing an input of the numpy program""" + + def __init__( + self, + inputs: Iterable[BaseValue], + ) -> None: + super().__init__(inputs) + assert len(self.inputs) == 1 + self.outputs = [deepcopy(self.inputs[0])]