mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-15 15:15:06 -05:00
296 lines
9.1 KiB
Python
296 lines
9.1 KiB
Python
from multiprocessing import Pool
|
|
from functools import partial
|
|
|
|
from sage.all import ceil, floor
|
|
|
|
from .io import Logging
|
|
|
|
|
|
class local_minimum_base:
    """
    An iterator context for finding a local minimum using binary search.

    We use the immediate neighborhood of a point to decide the next direction to go into (gradient
    descent style), so the algorithm is not plain binary search (see ``update()`` function.)

    The caller drives the search: every point handed out by iteration must be answered with a call
    to ``update(result)``, which decides the next point. The best point/result seen are available
    as ``x`` and ``y``.

    .. note :: We combine an iterator and a context to give the caller access to the result.
    """

    def __init__(
        self,
        start,
        stop,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point (inclusive)
        :param stop: end point (exclusive)
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``; it is called as
            ``smallerf(candidate, best)``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: verbosity level passed to ``Logging.log``
        :raises ValueError: if ``stop < start``
        """

        if stop < start:
            raise ValueError(f"Incorrect bounds {start} > {stop}.")

        self._suppress_bounds_warning = suppress_bounds_warning
        self._log_level = log_level
        # Current search window; both ends are inclusive, hence ``stop - 1``.
        self._start = start
        self._stop = stop - 1
        self._initial_bounds = (start, stop - 1)
        self._smallerf = smallerf
        # abs(self._direction) == 2: binary search step
        # abs(self._direction) == 1: gradient descent direction
        self._direction = -1  # going down
        self._last_x = None
        # The search starts at the upper end of the window.
        self._next_x = self._stop
        self._best = (None, None)
        # Every point already evaluated; proposing one of them again ends the search.
        self._all_x = set()

    def __enter__(self):
        """Return the search context itself."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Nothing to clean up; exceptions raised inside the ``with`` block propagate."""
        pass

    def __iter__(self):
        """Return the search context itself (it is its own iterator)."""
        return self

    def __next__(self):
        """
        Hand out the next point to evaluate.

        Iteration stops when ``update`` requested termination, when the proposed point was already
        evaluated, or when it lies outside the initial bounds. When stopping, a warning is logged
        if the best point lies on one of the initial bounds (unless suppressed).
        """
        abort = False
        if self._next_x is None:
            abort = True  # we're told to abort
        elif self._next_x in self._all_x:
            abort = True  # we're looping
        elif self._next_x < self._initial_bounds[0] or self._initial_bounds[1] < self._next_x:
            abort = True  # we're stepping out of bounds

        if not abort:
            self._last_x = self._next_x
            # Until ``update`` proposes a new point, the next call stops the iteration.
            self._next_x = None
            return self._last_x

        if self._best[0] in self._initial_bounds and not self._suppress_bounds_warning:
            # We warn the user if the optimal solution is at the edge and thus possibly not optimal.
            Logging.log(
                "bins",
                self._log_level,
                f'warning: "optimal" solution {self._best[0]} matches a bound ∈ {self._initial_bounds}.',
            )

        raise StopIteration

    @property
    def x(self):
        """The best point found so far (``None`` before the first ``update``)."""
        return self._best[0]

    @property
    def y(self):
        """The result recorded for the best point (``None`` before the first ``update``)."""
        return self._best[1]

    def update(self, res):
        """
        Record ``res`` as the result for the point last handed out and pick the next point.

        A result of ``False`` is never considered an improvement.

        TESTS:

        We keep cache old inputs in ``_all_x`` to prevent infinite loops::

            >>> from estimator.util import binary_search
            >>> from estimator.cost import Cost
            >>> f = lambda x, log_level=1: Cost(rop=1) if x >= 19 else Cost(rop=2)
            >>> binary_search(f, 10, 30, "x")
            rop: 1

        """

        Logging.log("bins", self._log_level, f"({self._last_x}, {repr(res)})")

        self._all_x.add(self._last_x)

        # We got nothing yet
        first = self._best[0] is None
        if first:
            self._best = self._last_x, res

        # We found something better. The first result counts as an improvement without consulting
        # ``smallerf``: a strict comparator (e.g. ``x < best``) would otherwise compare the first
        # result with itself, fail, and terminate the search after a single point.
        if res is not False and (first or self._smallerf(res, self._best[1])):
            # store it
            self._best = self._last_x, res

            # if it's a result of a long jump figure out the next direction
            if abs(self._direction) != 1:
                self._direction = -1
                self._next_x = self._last_x - 1
            # going down worked, so let's keep on doing that.
            elif self._direction == -1:
                self._direction = -2
                self._stop = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            # going up worked, so let's keep on doing that.
            elif self._direction == 1:
                self._direction = 2
                self._start = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)
        else:
            # going downwards didn't help, let's try up (the point above the one we came from)
            if self._direction == -1:
                self._direction = 1
                self._next_x = self._last_x + 2
            # going up didn't help either, so we stop
            elif self._direction == 1:
                self._next_x = None
            # it got no better in a long jump, half the search space and try again
            elif self._direction == -2:
                self._start = self._last_x
                self._next_x = ceil((self._start + self._stop) / 2)
            elif self._direction == 2:
                self._stop = self._last_x
                self._next_x = floor((self._start + self._stop) / 2)

        # We are repeating ourselves, time to stop
        if self._next_x == self._last_x:
            self._next_x = None
|
|
|
|
|
|
class local_minimum(local_minimum_base):
    """
    An iterator context for finding a local minimum using binary search.

    We use the neighborhood of a point to decide the next direction to go into (gradient descent
    style), so the algorithm is not plain binary search (see ``update()`` function.)

    We also zoom out by a factor ``precision``, find an approximate local minimum and then
    search the neighbourhood for the smallest value.

    .. note :: We combine an iterator and a context to give the caller access to the result.

    """

    def __init__(
        self,
        start,
        stop,
        precision=1,
        smallerf=lambda x, best: x <= best,
        suppress_bounds_warning=False,
        log_level=5,
    ):
        """
        Create a fresh local minimum search context.

        :param start: starting point
        :param stop: end point (exclusive)
        :param precision: only consider every ``precision``-th value in the main loop
        :param smallerf: a function to decide if ``lhs`` is smaller than ``rhs``.
        :param suppress_bounds_warning: do not warn if a boundary is picked as optimal
        :param log_level: verbosity level passed to ``Logging.log``

        """
        self._precision = precision
        self._orig_bounds = (start, stop)
        # Point (original units) currently handed out by ``neighborhood``; None outside that scan.
        self._neighbor_x = None
        # Best point (original units) found by a neighborhood scan, if it improved on the coarse one.
        self._best_orig_x = None
        # The coarse search runs on the scaled grid {ceil(start/p), ..., floor(stop/p)}.
        start = ceil(start / precision)
        stop = floor(stop / precision)
        local_minimum_base.__init__(self, start, stop, smallerf, suppress_bounds_warning, log_level)

    def __next__(self):
        """Hand out the next coarse point, converted back to original units."""
        x = local_minimum_base.__next__(self)
        return x * self._precision

    @property
    def x(self):
        """Best point found so far in original units, or ``None`` if nothing was evaluated yet."""
        if self._best_orig_x is not None:
            return self._best_orig_x
        if self._best[0] is None:
            return None
        return self._best[0] * self._precision

    def update(self, res):
        """
        Record ``res`` as the result for the point most recently handed out.

        During the coarse search this defers to ``local_minimum_base.update``. While iterating
        ``neighborhood`` the point is in original units, so its result is attributed to that point
        directly instead of to the last coarse point.
        """
        x = self._neighbor_x
        if x is None:
            return local_minimum_base.update(self, res)

        Logging.log("bins", self._log_level, f"({x}, {repr(res)})")
        if res is not False and self._smallerf(res, self._best[1]):
            self._best = self._best[0], res
            self._best_orig_x = x

    @property
    def neighborhood(self):
        """
        An iterator over the neighborhood of the currently best value.

        Yields the integers in ``[x - precision, x + precision)`` clipped to the original
        ``[start, stop)`` range, where ``x`` is ``self.x`` when iteration begins. Yields nothing if
        no point has been evaluated yet.
        """
        center = self.x
        if center is None:
            return

        start, stop = self._orig_bounds
        lo = max(start, center - self._precision)
        hi = min(stop, center + self._precision)
        try:
            for x in range(lo, hi):
                # Remember the point so that ``update`` attributes its result correctly.
                self._neighbor_x = x
                yield x
        finally:
            self._neighbor_x = None
|
|
|
|
|
|
def binary_search(
    f, start, stop, param, step=1, smallerf=lambda x, best: x <= best, log_level=5, *args, **kwds
):
    """
    Searches for the best value in the interval [start, stop] depending on the given comparison function.

    :param f: function to optimize; called as ``f(*args, **kwds)`` with ``kwds[param]`` set to
        the candidate value
    :param start: start of range to search (inclusive)
    :param stop: stop of range to search (inclusive)
    :param param: the parameter to modify when calling `f`
    :param step: initially only consider every ``step``-th value, then scan around the best one
    :param smallerf: comparison is performed by evaluating ``smallerf(current, best)``
    :param log_level: verbosity level for the search's log output
    :returns: the best result of ``f`` found

    """

    def evaluate(x):
        # Call ``f`` with a fresh copy of the keyword arguments where ``param`` is set to ``x``.
        kwds_ = dict(kwds)
        kwds_[param] = x
        return f(*args, **kwds_)

    # ``local_minimum`` treats its upper bound as exclusive; pass ``stop + 1`` so ``stop`` is included.
    with local_minimum(start, stop + 1, step, smallerf=smallerf, log_level=log_level) as it:
        for x in it:
            it.update(evaluate(x))

        # Refine around the coarse optimum found on the ``step`` grid.
        for x in it.neighborhood:
            it.update(evaluate(x))

        return it.y
|
|
|
|
|
|
def _batch_estimatef(f, x, log_level=0, f_repr=None):
    """
    Evaluate ``f(x)`` and log the callable, its input and its output on the "batch" channel.

    Kept at module level because ``batch_estimate`` hands it to ``Pool.starmap``.

    :param f: callable to evaluate
    :param x: the single positional argument passed to ``f``
    :param log_level: verbosity level passed to ``Logging.log``
    :param f_repr: label for ``f`` in the log; ``repr(f)`` when omitted
    :returns: ``f(x)``
    """
    result = f(x)
    label = repr(f) if f_repr is None else f_repr
    Logging.log("batch", log_level, f"f: {label}")
    Logging.log("batch", log_level, f"x: {x}")
    Logging.log("batch", log_level, f"f(x): {result!r}")
    return result
|
|
|
|
|
|
def f_name(f):
    """
    Return a human-readable label for the callable ``f``.

    :param f: any object, typically a function
    :returns: ``f.__name__`` when the attribute exists, otherwise ``repr(f)``
    """
    missing = object()
    name = getattr(f, "__name__", missing)
    if name is missing:
        return repr(f)
    return name
|
|
|
|
|
|
def batch_estimate(params, algorithm, jobs=1, log_level=0, **kwds):
    """
    Run every algorithm on every parameter set.

    :param params: an ``LWEParameters`` instance or an iterable of them
    :param algorithm: a callable or an iterable of callables, each called as ``f(x, **kwds)``
    :param jobs: number of worker processes; ``1`` runs everything in the current process
    :param log_level: verbosity level for the per-task log output
    :param kwds: extra keyword arguments forwarded to every algorithm
    :returns: nested dict ``{parameter set: {f_name(algorithm): result}}``
    """
    from .lwe_parameters import LWEParameters

    if isinstance(params, LWEParameters):
        params = (params,)

    try:
        iter(algorithm)
    except TypeError:
        algorithm = (algorithm,)
    # Materialise: the algorithms are traversed once per parameter set, and a one-shot
    # iterator (e.g. a generator) would be exhausted after the first parameter set.
    algorithm = tuple(algorithm)

    tasks = [
        (partial(f, **kwds), x, log_level, f_name(f))
        for x in params
        for f in algorithm
    ]

    if jobs == 1:
        results = [_batch_estimatef(*task) for task in tasks]
    else:
        # The context manager shuts the worker processes down once the results are collected.
        with Pool(jobs) as pool:
            results = pool.starmap(_batch_estimatef, tasks)

    # Group results by parameter set, then by algorithm label (a later duplicate label wins).
    ret = {}
    for (_, x, _, label), y in zip(tasks, results):
        ret.setdefault(x, {})[label] = y

    return ret
|