feat: make both p error and global p error optional

This commit is contained in:
Umut
2022-11-18 14:32:05 +01:00
parent df8d34af9d
commit 1472c8f020
4 changed files with 34 additions and 8 deletions

View File

@@ -27,7 +27,7 @@ class Configuration:
auto_parallelize: bool
jit: bool
p_error: Optional[float]
global_p_error: float
global_p_error: Optional[float]
insecure_key_cache_location: Optional[str]
auto_adjust_rounders: bool
@@ -73,7 +73,7 @@ class Configuration:
auto_parallelize: bool = False,
jit: bool = False,
p_error: Optional[float] = None,
global_p_error: float = (1 / 100_000),
global_p_error: Optional[float] = (1 / 100_000),
auto_adjust_rounders: bool = False,
):
self.verbose = verbose
@@ -112,6 +112,8 @@ class Configuration:
configuration that is forked from self and updated using kwargs
"""
# pylint: disable=too-many-branches
result = deepcopy(self)
hints = get_type_hints(Configuration)
@@ -133,6 +135,11 @@ class Configuration:
is_correctly_typed = False
expected = "Optional[float]"
elif name == "global_p_error":
if not (value is None or isinstance(value, float)):
is_correctly_typed = False
expected = "Optional[float]"
elif name in ["show_graph", "show_mlir", "show_optimizer"]:
if not (value is None or isinstance(value, bool)):
is_correctly_typed = False
@@ -156,3 +163,5 @@ class Configuration:
# pylint: enable=protected-access
return result
# pylint: enable=too-many-branches

View File

@@ -94,10 +94,21 @@ class Server:
options.set_loop_parallelize(configuration.loop_parallelize)
options.set_dataflow_parallelize(configuration.dataflow_parallelize)
options.set_auto_parallelize(configuration.auto_parallelize)
if configuration.p_error is not None:
options.set_p_error(configuration.p_error)
else:
global_p_error_is_set = configuration.global_p_error is not None
p_error_is_set = configuration.p_error is not None
if global_p_error_is_set and p_error_is_set: # pragma: no cover
options.set_global_p_error(configuration.global_p_error)
options.set_p_error(configuration.p_error)
elif global_p_error_is_set: # pragma: no cover
options.set_global_p_error(configuration.global_p_error)
options.set_p_error(1.0)
elif p_error_is_set: # pragma: no cover
options.set_global_p_error(1.0)
options.set_p_error(configuration.p_error)
show_optimizer = (
configuration.show_optimizer

View File

@@ -74,10 +74,10 @@ Additional kwarg to `compile` function have higher precedence. So if you set an
* Whether to adjust rounders automatically.
* **p_error**: Optional[float] = None
* Error probability for individual table lookups. Overwrites **global_p_error** if set.
* Error probability for individual table lookups. If set, all table lookups will have the probability of non-exact result smaller than the set value.
* **global_p_error**: float = (1 / 100_000)
* Global error probability for the whole circuit.
* **global_p_error**: Optional[float] = (1 / 100_000)
* Global error probability for the whole circuit. If set, the whole circuit will have the probability of non-exact result smaller than the set value.
* **jit**: bool = False
* Whether to use JIT compilation.

View File

@@ -80,6 +80,12 @@ def test_configuration_fork():
"Unexpected type for keyword argument 'p_error' "
"(expected 'Optional[float]', got 'str')",
),
pytest.param(
{"global_p_error": "mamma mia"},
TypeError,
"Unexpected type for keyword argument 'global_p_error' "
"(expected 'Optional[float]', got 'str')",
),
pytest.param(
{"show_optimizer": "please"},
TypeError,