mirror of
https://github.com/socathie/circomlib-ml.git
synced 2026-01-08 21:48:06 -05:00
217 lines
6.7 KiB
Plaintext
217 lines
6.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from mpmath import mp\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"# from https://github.com/DKenefake/OptimalPoly\n",
|
|
"\n",
|
|
"def bisection_search(f, low:float, high:float):\n",
|
|
" \"\"\"\n",
|
|
" A root finding method that does not rely on derivatives\n",
|
|
"\n",
|
|
" :param f: a function f: X -> R\n",
|
|
" :param low: the lower bracket\n",
|
|
" :param high: the upper bracket\n",
|
|
" :return: the location of the root, i.e. f(mid) ~ 0\n",
|
|
" \"\"\"\n",
|
|
" # swap the brackets if needed so that f(low) <= f(high)\n",
|
|
" if f(high) < f(low):\n",
|
|
" low, high = high, low\n",
|
|
"\n",
|
|
" # find mid point\n",
|
|
" mid = .5 * (low + high)\n",
|
|
"\n",
|
|
" while True:\n",
|
|
"\n",
|
|
" # bracket up\n",
|
|
" if f(mid) < 0:\n",
|
|
" low = mid\n",
|
|
" # bracket down\n",
|
|
" else:\n",
|
|
" high = mid\n",
|
|
"\n",
|
|
" # update mid point\n",
|
|
" mid = .5 * (high + low)\n",
|
|
"\n",
|
|
" # break if condition met\n",
|
|
" if abs(high - low) < 10 ** (-(mp.dps / 2)):\n",
|
|
" break\n",
|
|
"\n",
|
|
" return mid\n",
|
|
"\n",
|
|
"\n",
|
|
"def concave_max(f, low:float, high:float):\n",
|
|
" \"\"\"\n",
|
|
" Forms a lambda for the approximate derivative and finds the root\n",
|
|
"\n",
|
|
" :param f: a function f: X -> R\n",
|
|
" :param low: the lower bracket\n",
|
|
" :param high: the upper bracket\n",
|
|
" :return: the location of the root f'(mid) ~ 0\n",
|
|
" \"\"\"\n",
|
|
" # create an approximate derivative expression\n",
|
|
" scale = high - low\n",
|
|
"\n",
|
|
" h = mp.mpf('0.' + ''.join(['0' for i in range(int(mp.dps / 1.5))]) + '1') * scale\n",
|
|
" df = lambda x: (f(x + h) - f(x - h)) / (2.0 * h)\n",
|
|
"\n",
|
|
" return bisection_search(df, low, high)\n",
|
|
"\n",
|
|
"def chev_points(n:int, lower:float = -1, upper:float = 1):\n",
|
|
" \"\"\"\n",
|
|
" Generates a set of Chebyshev points spaced in the range [lower, upper]\n",
|
|
" :param n: number of points\n",
|
|
" :param lower: lower limit\n",
|
|
" :param upper: upper limit\n",
|
|
" :return: a list of multiprecision Chebyshev points that are in the range [lower, upper]\n",
|
|
" \"\"\"\n",
|
|
" # generate Chebyshev points on [-1, 1], then rescale them to [lower, upper]\n",
|
|
" index = np.arange(1, n+1)\n",
|
|
" range_ = abs(upper - lower)\n",
|
|
" return [(.5*(mp.cos((2*i-1)/(2*n)*mp.pi)+1))*range_ + lower for i in index]\n",
|
|
"\n",
|
|
"\n",
|
|
"def remez(func, n_degree:int, lower:float=-1, upper:float=1, max_iter:int = 10):\n",
|
|
" \"\"\"\n",
|
|
" :param func: a function (or lambda) f: X -> R\n",
|
|
" :param n_degree: the degree of the polynomial to approximate the function f\n",
|
|
" :param lower: lower range of the approximation\n",
|
|
" :param upper: upper range of the approximation\n",
|
|
" :param max_iter: maximum number of refinement iterations to run\n",
|
|
" :return: the polynomial coefficients in the power basis (lowest degree first), and the mean absolute residual at the located extrema as an approximate maximum error\n",
|
|
" \"\"\"\n",
|
|
" # initialize the node points\n",
|
|
"\n",
|
|
" x_points = chev_points(n_degree + 2, lower, upper)\n",
|
|
"\n",
|
|
" A = mp.matrix(n_degree + 2)\n",
|
|
" coeffs = np.zeros(n_degree + 2)\n",
|
|
"\n",
|
|
" # fill the last column of A with alternating signs (the E column)\n",
|
|
" mean_error = float('inf')\n",
|
|
"\n",
|
|
" for i in range(n_degree + 2):\n",
|
|
" A[i, n_degree + 1] = (-1) ** (i + 1)\n",
|
|
"\n",
|
|
" for i in range(max_iter):\n",
|
|
"\n",
|
|
" # build the system\n",
|
|
" vander = np.polynomial.chebyshev.chebvander(x_points, n_degree)\n",
|
|
"\n",
|
|
" for i in range(n_degree + 2):\n",
|
|
" for j in range(n_degree + 1):\n",
|
|
" A[i, j] = vander[i, j]\n",
|
|
"\n",
|
|
" b = mp.matrix([func(x) for x in x_points])\n",
|
|
" l = mp.lu_solve(A, b)\n",
|
|
"\n",
|
|
" coeffs = l[:-1]\n",
|
|
"\n",
|
|
" # build the residual expression\n",
|
|
" r_i = lambda x: (func(x) - np.polynomial.chebyshev.chebval(x, coeffs))\n",
|
|
"\n",
|
|
" interval_list = list(zip(x_points, x_points[1:]))\n",
|
|
" # interval_list = [[x_points[i], x_points[i+1]] for i in range(len(x_points)-1)]\n",
|
|
"\n",
|
|
" intervals = [upper]\n",
|
|
" intervals.extend([bisection_search(r_i, *i) for i in interval_list])\n",
|
|
" intervals.append(lower)\n",
|
|
"\n",
|
|
" extermum_interval = [[intervals[i], intervals[i + 1]] for i in range(len(intervals) - 1)]\n",
|
|
"\n",
|
|
" extremums = [concave_max(r_i, *i) for i in extermum_interval]\n",
|
|
"\n",
|
|
" extremums[0] = mp.mpf(upper)\n",
|
|
" extremums[-1] = mp.mpf(lower)\n",
|
|
"\n",
|
|
" errors = [abs(r_i(i)) for i in extremums]\n",
|
|
" mean_error = np.mean(errors)\n",
|
|
"\n",
|
|
" if np.max([abs(error - mean_error) for error in errors]) < 0.000001 * mean_error:\n",
|
|
" break\n",
|
|
"\n",
|
|
" x_points = extremums\n",
|
|
"\n",
|
|
" return [float(i) for i in np.polynomial.chebyshev.cheb2poly(coeffs)], float(mean_error)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['0.502073021', '0.198695283', '-0.001570683', '-0.004001354']"
|
|
]
|
|
},
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"function = lambda x: mp.sigmoid(x)\n",
|
|
"poly_coeffs, max_error = remez(function, 3, -5, 5)\n",
|
|
"[np.format_float_positional(c, precision=9) for c in poly_coeffs]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"['0.006769816', '0.554670504', '-0.009411195', '-0.014187547']"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"function = lambda x: mp.tanh(x)\n",
|
|
"poly_coeffs, max_error = remez(function, 3, -5, 5)\n",
|
|
"[np.format_float_positional(c, precision=9) for c in poly_coeffs]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "sklearn",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|