mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-09 13:38:02 -05:00
unfinished join
@@ -1,26 +1,62 @@
import torch
from typing import Type

from zkstats.computation import IState, computation_to_model_mpc
from zkstats.computation import IState, computation_to_model_mpc, MagicNumber

from .utils import compile_and_run_mpspdz


def computation(state: IState, args: list[torch.Tensor]):
    x = args[0]
    # y = args[1]
    # z = args[2]
    return state.mean(x)
    columns_0 = [args[0], args[1]]
    columns_1 = [args[2], args[3]]
    x_key = columns_0[0]
    y_key = columns_1[0]
    num_rows_x = x_key.size(0)
    num_cols_y = len(columns_1)

    # Create a zero tensor for each new y column
    new_y = [torch.where(x_key == 0, 0, 0) for _ in columns_1]
    for i in range(num_rows_x):
        # i = 0, one_hot = [1, 0, 0, 0]
        # i = 1, one_hot = [0, 1, 0, 0]
        one_hot = torch.arange(num_rows_x) == i
        # new_y[i] = 1
        # [1, 4] -> [0, 1]
        mask = y_key == x_key[i]
        # is_mask_nonzero = torch.sum(mask) != 0
        for k in range(num_cols_y):
            # [0, 1] * [5, 4] -> [0, 4]
            # sum([0, 4]) -> 4
            matched_value = torch.sum(mask * columns_1[k])
            # [1, 0, 0, 0] * 4 -> [4, 0, 0, 0]
            entry = one_hot * matched_value
            new_y[k] = entry + new_y[k]
            # new_y[k][i] = matched_value + new_y[k][i]
    return list(columns_0) + [torch.where(col == 0, MagicNumber, col) for col in new_y]
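# Hedged worked example (derived from the test data below, not verified by a run):
# with x_key = [1, 2, 3] and y_key = [1, 2, 4], the joined y columns come out as
# [1, 2, 0] and [50, 40, 0], and the final torch.where maps the unmatched 0
# entries (and any genuinely-zero matched values) to MagicNumber.
# Caveat: the test reshapes every column to shape (N, 1) while one_hot has shape
# (N,), so `entry + new_y[k]` would broadcast to (N, N); flattening the columns
# (or reshaping one_hot) seems needed, which may be why the commit is titled
# "unfinished join".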


# test with multiple input tensors
def test_computation(tmp_path):
    data_1 = torch.tensor(
        [32, 8, 8],
    x_0 = torch.tensor(
        [1, 2, 3],
        dtype=torch.float32,
    ).reshape(-1, 1)
    x_1 = torch.tensor(
        [180, 160, 183],
        dtype=torch.float32,
    ).reshape(-1, 1)
    y_0 = torch.tensor(
        [1, 2, 4],
        dtype=torch.float32,
    ).reshape(-1, 1)
    y_1 = torch.tensor(
        [50, 40, 75],
        dtype=torch.float32,
    ).reshape(-1, 1)

    data = (x_0, x_1, y_0, y_1)

    state, Model = computation_to_model_mpc(computation)

    compile_and_run_mpspdz(Model, tuple([data_1]), tmp_path)
    res = compile_and_run_mpspdz(Model, data, tmp_path)
    print(f"!@# res={res}")
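    # Hedged expectation (from the worked example above, not a recorded run):
    # res should contain x_0 and x_1 unchanged, followed by the joined columns
    # [1, 2, MagicNumber] and [50, 40, MagicNumber].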


@@ -12,12 +12,12 @@ from zkstats.arithc_to_bristol import parse_arithc_json
from zkstats.backends.mpspdz import generate_mpspdz_circuit, generate_mpspdz_inputs_for_party, run_mpspdz_circuit, tensors_to_circom_mpspdz_inputs


CIRCOM_2_ARITHC_PROJECT_ROOT = Path('/path/to/circom-2-arithc-project-root')
MP_SPDZ_PROJECT_ROOT = Path('/path/to/mp-spdz-project-root')
CIRCOM_2_ARITHC_PROJECT_ROOT = Path('/Users/mhchia/projects/work/pse/circom-2-arithc')
MP_SPDZ_PROJECT_ROOT = Path('/Users/mhchia/projects/work/pse/MP-SPDZ')

# For generalized multiple tensor input

def compile_and_run_mpspdz(model_type: Type[nn.Module], data: tuple[torch.Tensor], tmp_path: Path):
def compile_and_run_mpspdz(model_type: Type[nn.Module], data: tuple[torch.Tensor, ...], tmp_path: Path):
    # output_path = tmp_path
    # Don't use tmp_path for now for easier debugging
    # So you should see all generated files in `output_path`

@@ -161,7 +161,7 @@ def run_torch_model(model_type: Type[nn.Module], data: tuple[torch.Tensor]) -> t
    return output_torch.reshape(-1)


def torch_model_to_onnx(model_type: Type[nn.Module], data: tuple[torch.Tensor], output_onnx_path: Path):
def torch_model_to_onnx(model_type: Type[nn.Module], data: tuple[torch.Tensor, ...], output_onnx_path: Path):
    model = model_type()
    input_names = []
    # dynamic_axes = {}

@@ -135,6 +135,19 @@ class IState(ABC):
        """
        ...

    @abstractmethod
    def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], join_index: int) -> list[torch.Tensor]:
        """
        Join two sets of columns on the join index. The behavior should conform to
        [pandas.DataFrame.join](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.join.html) in the pandas library.

        :param columns_0: A list of tensors
        :param columns_1: A list of tensors
        :param join_index: The index to join on
        :return: A list of tensors
        """
        ...
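    # Hedged reference sketch (assumption, not part of the library): the pandas
    # analogue of join([x], [y_key, y_val], 0) would be roughly
    #   pd.DataFrame({'k': x}).join(
    #       pd.DataFrame({'k': y_key, 'v': y_val}).set_index('k'), on='k')
    # with missing matches encoded as MagicNumber rather than NaN.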


class State(IState):
    """
@@ -202,6 +215,9 @@ class State(IState):
    def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return torch.where(_filter, x, x-x+MagicNumber)

    def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], join_index: int) -> list[torch.Tensor]:
        raise NotImplementedError

    def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
        if self.current_op_index is None:
            # for prover

@@ -414,6 +430,74 @@ class MPCState(IState):
    def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
        x_key = columns_0[index_0]
        y_key = columns_1[index_1]
        num_rows_x = x_key.size(0)
        num_cols_y = len(columns_1)

        # Create a zero tensor for each new y column
        new_y = [torch.where(x_key == 0, 0, 0) for _ in columns_1]
        for i in range(num_rows_x):
            # i = 0, one_hot = [1, 0, 0, 0]
            # i = 1, one_hot = [0, 1, 0, 0]
            one_hot = torch.arange(num_rows_x) == i
            # new_y[i] = 1
            # [1, 4] -> [0, 1]
            mask = y_key == x_key[i]
            # is_mask_nonzero = torch.sum(mask) != 0
            for k in range(num_cols_y):
                # [0, 1] * [5, 4] -> [0, 4]
                # sum([0, 4]) -> 4
                matched_value = torch.sum(mask * columns_1[k])
                # [1, 0, 0, 0] * 4 -> [4, 0, 0, 0]
                entry = one_hot * matched_value
                new_y[k] = entry + new_y[k]
        return list(columns_0) + [torch.where(col == 0, MagicNumber, col) for col in new_y]

    # # NOTE: naive way
    # def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
    #     x_key = columns_0[index_0]
    #     y_key = columns_1[index_1]
    #     num_rows_x = x_key.size(0)
    #     num_cols_y = len(columns_1)

    #     # Create a tensor to hold the new y values
    #     new_y = [torch.where(x_key, x_key, MagicNumber) for _ in columns_1]

    #     for i in range(num_rows_x):
    #         # Create a mask for matching rows
    #         mask = y_key == x_key[i]
    #         for k in range(num_cols_y):
    #             new_value = torch.sum(mask * columns_1[k])
    #             # Update new_y only where mask is not zero
    #             is_mask_nonzero = torch.sum(mask) != 0
    #             new_y[k][i] = torch.where(is_mask_nonzero, new_value, new_y[k][i])

    #     return list(columns_0) + new_y

    # # NOTE: matmul way
    # def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
    #     x_key = columns_0[index_0]
    #     y_key = columns_1[index_1]
    #     num_cols_y = len(columns_1)

    #     # Create a tensor to hold the new y values
    #     new_y = [torch.full_like(x_key, MagicNumber) for _ in columns_1]

    #     # Expand dimensions for broadcasting
    #     x_key_expanded = x_key.unsqueeze(1)  # Shape: [num_rows_x, 1]
    #     y_key_expanded = y_key.unsqueeze(0)  # Shape: [1, num_rows_y]

    #     # Create a match matrix where x_key == y_key
    #     match_matrix = (x_key_expanded == y_key_expanded).float()  # Shape: [num_rows_x, num_rows_y]

    #     # Compute new_y for each column in y
    #     for k in range(num_cols_y):
    #         matched_y = torch.matmul(match_matrix, columns_1[k].unsqueeze(1)).squeeze(1)  # Shape: [num_rows_x]
    #         new_y[k] = torch.where(matched_y != 0, matched_y, new_y[k])
    #     return list(columns_0) + new_y
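    # Hedged numeric sketch (same toy data as the test, not from the source):
    # with x_key = [1, 2, 3] and y_key = [1, 2, 4], match_matrix is
    # [[1, 0, 0], [0, 1, 0], [0, 0, 0]], so match_matrix @ [50, 40, 75] gives
    # [50, 40, 0]; one matmul per y column replaces the whole row loop of the
    # one-hot version above.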


def computation_to_model_mpc(computation: TComputation) -> tuple[State, Type[IModel]]:
    state = MPCState()

@@ -5,6 +5,7 @@ import typing

import numpy as np
import keras
import torch


# <KerasTensor shape=(), dtype=float32, sparse=False, name=keras_tensor_10>
@@ -23,10 +24,35 @@ class Input:
    name: typing.Optional[str]
    # If it's a constant, `value` is the value of the constant; otherwise it's None
    value: typing.Optional[float]
    # whether it's a keras tensor with an empty shape, i.e. shape == ()
    is_keras_constant: bool


def dict_to_tensor(data):
    if data['class_name'] != '__numpy__':
        raise ValueError("Unsupported class_name")

    value = data['config']['value']
    dtype = data['config']['dtype']

    # Map the dtype string to a PyTorch dtype
    dtype_map = {
        'float32': torch.float32,
        'float64': torch.float64,
        'int32': torch.int32,
        'int64': torch.int64,
    }

    if dtype not in dtype_map:
        raise ValueError("Unsupported dtype")

    tensor_dtype = dtype_map[dtype]

    # Convert the list to a PyTorch tensor with the specified dtype
    tensor = torch.tensor(value, dtype=tensor_dtype)
    return tensor
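# Hedged usage sketch (input mirrors the '__numpy__' dict shown further below):
# dict_to_tensor({'class_name': '__numpy__',
#                 'config': {'value': [1.0, 0.0, 0.0], 'dtype': 'float32'}})
# returns tensor([1., 0., 0.]) with dtype torch.float32.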


# read each layer in a model and convert it to a class called Layer
@dataclass
class Layer:
@@ -48,7 +74,7 @@ class Layer:
        self.config = _config
        self.inputs = []
        list_inputs = _config['node_inputs']

        index = 0
        for ele_name in list_inputs:
@@ -66,21 +92,29 @@ class Layer:
                # if it's a keras tensor resulting in a constant, get the shape from the non-constant input
                if input_shape == ():
                    # if there is more than one input, as in `TFAdd`, we need to get the shape of the other input
                    if len(_inputs) == 2 and len(_inputs[1-index].shape) >= 1:
                        input_shape = (_inputs[1-index]).shape
                    else:
                        input_shape = (1,)
                    # if len(_inputs) == 2 and len(_inputs[1-index].shape) >= 1:
                    #     input_shape = (_inputs[1-index]).shape
                    # else:
                    input_shape = (1,)
                    is_keras_constant = True
                index += 1
            # it's a constant; assume it's a float
            # FIXME: a constant can be a tensor with multiple dimensions, but for
            # now we assume it's a scalar float
            else:
                name = None
                value = float(config_ele)
                if len(_inputs) > 0 and len(_inputs[0].shape) >= 1:
                    input_shape = (_inputs[0]).shape
                # '/Constant_2_output_0': {'class_name': '__numpy__', 'config': {'value': [1.0, 0.0, 0.0], 'dtype': 'float32'}}
                if isinstance(config_ele, dict) and config_ele["class_name"] == '__numpy__':
                    value = config_ele['config']['value']
                    value_in_tensor = dict_to_tensor(config_ele)
                    input_shape = value_in_tensor.shape
                # '/Constant_output_0': 0
                else:
                    input_shape = (1,)
                    value = float(config_ele)
                # if len(_inputs) > 0 and len(_inputs[0].shape) >= 1:
                #     input_shape = (_inputs[0]).shape
                # else:
                input_shape = (1,)

            self.inputs.append(
                Input(
                    is_constant=not is_non_constant,

@@ -29,8 +29,9 @@ from zkstats.onnx2circom.onnx2keras.layers import (
    TFOr,
    TFWhere,
    TFAbs,
    TFGather,
    TFConcat
    TFGather,
    TFConcat,
    TFIdentity,
    # TFArgMax,
    # TFArgMin,
)

@@ -65,8 +66,9 @@ SUPPORTED_OPS = [
    TFCast,
    TFWhere,
    TFAbs,
    TFGather,
    TFConcat
    TFGather,
    TFConcat,
    TFIdentity,
    # TFErf,
]

@@ -100,9 +102,9 @@ def get_component_args_values(layer: Layer) -> typing.Dict[str, typing.Any]:
        return {'e': 2, 'nInputs': num_elements_in_input_0}
    if is_in_ops(layer.op, [TFReduceSum, TFReduceMean, TFReduceMax, TFReduceMin]):
        return {'nInputs': num_elements_in_input_0}
    if is_in_ops(layer.op, [TFNot, TFCast, TFWhere, TFAbs, TFGather]):
    if is_in_ops(layer.op, [TFNot, TFCast, TFWhere, TFAbs, TFGather, TFIdentity]):
        return {'nElements': num_elements_in_input_0}
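    # Hedged example (hypothetical layer): for an Identity node over a
    # length-3 tensor this returns {'nElements': 3}, which would instantiate
    # the circom template below as TFIdentity(3).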
    # 2-input operations
    if is_in_ops(layer.op, [TFAdd, TFSub, TFMul, TFDiv, TFEqual, TFGreater, TFLess, TFAnd, TFOr]):
        input_1 = inputs[1]
        input_1_shape = input_1.shape

@@ -198,13 +200,15 @@ def transpile(templates: dict[str, Template], filename: str, output_dir: str = '
            # Handle the right hand side when it's a keras tensor
            input_name = _input.name
            input_shape = _input.shape
            # This input can only be a constant, a model input, or an output from another component
            if _input.is_constant:
                # If this input is a constant, use the value of the constant directly as the right hand side
                from_component_name = None
                # FIXME: if _input.value is a float, we need to scale it by 10^dec
                # Scale the float value by 10^dec
                scaled = int(_input.value * 10 ** dec)
                from_component_signal_name = str(scaled)
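                # Hedged example (the dec value is hypothetical): with dec = 4,
                # a constant 0.5 becomes int(0.5 * 10**4) = 5000, embedded in
                # the circuit as the literal signal name "5000".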
                rhs_dim = 0
                rhs_dim = len(input_shape)
            elif model.is_model_input(input_name) is True:
                # If this input is from `input_layer`, use the original tensor name for it
                from_component_name = None

@@ -281,7 +281,7 @@ template TFGather(nElements) {
    }

    out <== out_till[nElements];


}

template TFConcat(nElements_0, nElements_1){
@@ -294,4 +294,12 @@ template TFConcat(nElements_0, nElements_1){
    for (var i = 0; i < nElements_1; i++){
        out[nElements_0 + i] <== in_1[i] + 0;
    }
}
}

template TFIdentity(nElements) {
    signal input in[nElements];
    signal output out[nElements];
    for (var i = 0; i < nElements; i++){
        out[i] <== in[i] + 0;
    }
}
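// Hedged note (not stated in the source): TFIdentity just copies each input
// wire to the output; the `+ 0` mirrors TFConcat above, presumably so the
// transpiler emits an explicit arithmetic gate per element instead of a bare
// wire alias.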

@@ -2,7 +2,7 @@ import math
import logging
import numpy as np
import tensorflow as tf
from tensorflow import keras
import keras

from ..utils.op_registry import OPERATOR
from .dimension_utils import intfloat_to_list

@@ -171,7 +171,7 @@ class TFUpsample():
    def __call__(self, inputs):
        return tf.image.resize(inputs, self.scale, method=self.method)


@OPERATOR.register_operator("Constant")
class TFConstant():
    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):

@@ -242,12 +242,26 @@ class TFGemm():
        return self.dense(inputs)

@OPERATOR.register_operator("Identity")
class TFIdentity():
    def __init__(self, *args, **kwargs):
class TFIdentity(keras.layers.Layer):
    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        super().__init__()
        self.tensor_grap = tensor_grap
        self.node_weights = node_weights
        self.node_inputs = node_inputs
        self.node_attribute = node_attribute

    def __call__(self, inputs):
        return inputs
    def call(self, input, *args, **kwargs):
        return input

    def get_config(self):
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            'node_attribute': self.node_attribute,
        })
        return config
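    # Hedged note (inferred from this diff, not stated in it): registering
    # TFIdentity under "Identity" lets ONNX Identity nodes survive the
    # onnx -> keras -> circom pipeline and land on the TFIdentity template
    # above; get_config makes the layer serializable, mirroring TFCast below.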

@OPERATOR.register_operator("Dropout")
class TFDropout():
@@ -310,14 +324,14 @@ class TFCast(keras.layers.Layer):
        # inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])

        return inputs


    def get_config(self):
        config = super().get_config()
        config.update({
            "tensor_grap": self.tensor_grap,
            'node_weights': self.node_weights,
            'node_inputs': self.node_inputs,
            'node_attribute': self.node_attribute,
            'node_attribute': self.node_attribute,
            "cast_to": self.cast_to
        })
        return config