unfinished join

This commit is contained in:
mhchia
2024-06-14 18:42:48 +08:00
parent d41615451f
commit 59ff7a7f9e
7 changed files with 222 additions and 42 deletions

View File

@@ -1,26 +1,62 @@
import torch
from typing import Type
from zkstats.computation import IState, computation_to_model_mpc
from zkstats.computation import IState, computation_to_model_mpc, MagicNumber
from .utils import compile_and_run_mpspdz
def computation(state: IState, args: list[torch.Tensor]):
x = args[0]
# y = args[1]
# z = args[2]
return state.mean(x)
columns_0 = [args[0], args[1]]
columns_1 = [args[2], args[3]]
x_key = columns_0[0]
y_key = columns_1[0]
num_rows_x = x_key.size(0)
num_cols_y = len(columns_1)
# Create a tensor for each new y columns
new_y = [torch.where(x_key == 0, 0, 0) for _ in columns_1]
for i in range(num_rows_x):
# i = 0, one_hot = [1, 0, 0, 0]
# i = 1, one_hot = [0, 1, 0, 0]
one_hot = torch.arange(num_rows_x) == i
# new_y[i] = 1
# [1, 4] -> [0, 1]
mask = y_key == x_key[i]
# is_mask_nonzero = torch.sum(mask) != 0
for k in range(num_cols_y):
# [0, 1] * [5, 4] -> [0, 4]
# sum([0, 4]) -> 4
matched_value = torch.sum(mask * columns_1[k])
# [1, 0, 0, 0] * 4 -> [4, 0, 0, 0]
entry = one_hot * matched_value
new_y[k] = entry + new_y[k]
# new_y[k][i] = matched_value + new_y[k][i]
return list(columns_0) + [torch.where(col == 0, MagicNumber, col) for col in new_y]
# two tensor stuffs
def test_computation(tmp_path):
data_1 = torch.tensor(
[32, 8, 8],
x_0 = torch.tensor(
[1, 2, 3],
dtype = torch.float32,
).reshape(-1, 1)
x_1 = torch.tensor(
[180, 160, 183],
dtype = torch.float32,
).reshape(-1, 1)
y_0 = torch.tensor(
[1, 2, 4],
dtype = torch.float32,
).reshape(-1, 1)
y_1 = torch.tensor(
[50, 40, 75],
dtype = torch.float32,
).reshape(-1, 1)
data = (x_0, x_1, y_0, y_1)
state, Model = computation_to_model_mpc(computation)
compile_and_run_mpspdz(Model, tuple([data_1]), tmp_path)
res = compile_and_run_mpspdz(Model, data, tmp_path)
print(f"!@# res={res}")

View File

@@ -12,12 +12,12 @@ from zkstats.arithc_to_bristol import parse_arithc_json
from zkstats.backends.mpspdz import generate_mpspdz_circuit, generate_mpspdz_inputs_for_party, run_mpspdz_circuit, tensors_to_circom_mpspdz_inputs
CIRCOM_2_ARITHC_PROJECT_ROOT = Path('/path/to/circom-2-arithc-project-root')
MP_SPDZ_PROJECT_ROOT = Path('/path/to/mp-spdz-project-root')
CIRCOM_2_ARITHC_PROJECT_ROOT = Path('/Users/mhchia/projects/work/pse/circom-2-arithc')
MP_SPDZ_PROJECT_ROOT = Path('/Users/mhchia/projects/work/pse/MP-SPDZ')
# For generalized multiple tensor input
def compile_and_run_mpspdz(model_type: Type[nn.Module], data: tuple[torch.Tensor], tmp_path: Path):
def compile_and_run_mpspdz(model_type: Type[nn.Module], data: tuple[torch.Tensor, ...], tmp_path: Path):
# output_path = tmp_path
# Don't use tmp_path for now for easier debugging
# So you should see all generated files in `output_path`
@@ -161,7 +161,7 @@ def run_torch_model(model_type: Type[nn.Module], data: tuple[torch.Tensor]) -> t
return output_torch.reshape(-1)
def torch_model_to_onnx(model_type: Type[nn.Module], data: tuple[torch.Tensor], output_onnx_path: Path):
def torch_model_to_onnx(model_type: Type[nn.Module], data: tuple[torch.Tensor, ...], output_onnx_path: Path):
model = model_type()
input_names = []
# dynamic_axes = {}

View File

@@ -135,6 +135,19 @@ class IState(ABC):
"""
...
@abstractmethod
def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], join_index: int) -> list[torch.Tensor]:
"""
Join two columns based on the join index. Two columns The behavior should conform to
[pandas.DataFrame.join](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.join.html) in pandas library.
:param columns_0: A list of tensors
:param columns_1: A list of tensors
:param join_index: The index to join on
:return: A list of tensors
"""
...
class State(IState):
"""
@@ -202,6 +215,9 @@ class State(IState):
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
return torch.where(_filter, x, x-x+MagicNumber)
def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], join_index: int) -> list[torch.Tensor]:
raise NotImplementedError
def _call_op(self, x: list[torch.Tensor], op_type: Type[Operation]) -> Union[torch.Tensor, tuple[IsResultPrecise, torch.Tensor]]:
if self.current_op_index is None:
# for prover
@@ -414,6 +430,74 @@ class MPCState(IState):
def where(self, _filter: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
x_key = columns_0[index_0]
y_key = columns_1[index_1]
num_rows_x = x_key.size(0)
num_cols_y = len(columns_1)
# Create a tensor for each new y columns
new_y = [torch.where(x_key == 0, 0, 0) for _ in columns_1]
for i in range(num_rows_x):
# i = 0, one_hot = [1, 0, 0, 0]
# i = 1, one_hot = [0, 1, 0, 0]
one_hot = torch.arange(num_rows_x) == i
# new_y[i] = 1
# [1, 4] -> [0, 1]
mask = y_key == x_key[i]
# is_mask_nonzero = torch.sum(mask) != 0
for k in range(num_cols_y):
# [0, 1] * [5, 4] -> [0, 4]
# sum([0, 4]) -> 4
matched_value = torch.sum(mask * columns_1[k])
# [1, 0, 0, 0] * 4 -> [4, 0, 0, 0]
entry = one_hot * matched_value
new_y[k] = entry + new_y[k]
return list(columns_0) + [torch.where(col == 0, MagicNumber, col) for col in new_y]
# # NOTE: naive way
# def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
# x_key = columns_0[index_0]
# y_key = columns_1[index_1]
# num_rows_x = x_key.size(0)
# num_cols_y = len(columns_1)
# # Create a tensor to hold the new y values
# new_y = [torch.where(x_key, x_key, MagicNumber) for _ in columns_1]
# for i in range(num_rows_x):
# # Create a mask for matching rows
# mask = y_key == x_key[i]
# for k in range(num_cols_y):
# new_value = torch.sum(mask * columns_1[k])
# # Update new_y only where mask is not zero
# is_mask_nonzero = torch.sum(mask) != 0
# new_y[k][i] = torch.where(is_mask_nonzero, new_value, new_y[k][i])
# return list(columns_0) + new_y
# # NOTE: matmul way
# def join(self, columns_0: list[torch.Tensor], columns_1: list[torch.Tensor], index_0: int, index_1: int) -> list[torch.Tensor]:
# x_key = columns_0[index_0]
# y_key = columns_1[index_1]
# num_cols_y = len(columns_1)
# # Create a tensor to hold the new y values
# new_y = [torch.full_like(x_key, MagicNumber) for _ in columns_1]
# # Expand dimensions for broadcasting
# x_key_expanded = x_key.unsqueeze(1) # Shape: [num_rows_x, 1]
# y_key_expanded = y_key.unsqueeze(0) # Shape: [1, num_rows_y]
# # Create a match matrix where x_key == y_key
# match_matrix = (x_key_expanded == y_key_expanded).float() # Shape: [num_rows_x, num_rows_y]
# # Compute new_y for each column in y
# for k in range(num_cols_y):
# matched_y = torch.matmul(match_matrix, columns_1[k].unsqueeze(1)).squeeze(1) # Shape: [num_rows_x]
# new_y[k] = torch.where(matched_y != 0, matched_y, new_y[k])
# return list(columns_0) + new_y
def computation_to_model_mpc(computation: TComputation) -> tuple[State, Type[IModel]]:
state = MPCState()

View File

@@ -5,6 +5,7 @@ import typing
import numpy as np
import keras
import torch
# <KerasTensor shape=(), dtype=float32, sparse=False, name=keras_tensor_10>
@@ -23,10 +24,35 @@ class Input:
name: typing.Optional[str]
# If it's a constant, value is the value of the constant. Else, it's None
value: typing.Optional[float]
# is it keras_tensor in form of no shape i.e. shape = ()
# is it keras_tensor in form of no shape i.e. shape = ()
is_keras_constant: bool
def dict_to_tensor(data):
if data['class_name'] != '__numpy__':
raise ValueError("Unsupported class_name")
value = data['config']['value']
dtype = data['config']['dtype']
# Map the dtype string to a PyTorch dtype
dtype_map = {
'float32': torch.float32,
'float64': torch.float64,
'int32': torch.int32,
'int64': torch.int64
}
if dtype not in dtype_map:
raise ValueError("Unsupported dtype")
tensor_dtype = dtype_map[dtype]
# Convert the list to a PyTorch tensor with the specified dtype
tensor = torch.tensor(value, dtype=tensor_dtype)
return tensor
# read each layer in a model and convert it to a class called Layer
@dataclass
class Layer:
@@ -48,7 +74,7 @@ class Layer:
self.config = _config
self.inputs = []
list_inputs = _config['node_inputs']
index = 0
for ele_name in list_inputs:
@@ -66,21 +92,29 @@ class Layer:
# if it's keras tensor resulting in constant, get the shape from non-constant input
if input_shape == ():
# if there are more than 1 inputs like `TFAdd`, we need to get the shape of the other input
if len(_inputs)==2 and len(_inputs[1-index].shape)>=1:
input_shape = (_inputs[1-index]).shape
else:
input_shape =(1,)
# if len(_inputs)==2 and len(_inputs[1-index].shape)>=1:
# input_shape = (_inputs[1-index]).shape
# else:
input_shape =(1,)
is_keras_constant = True
index += 1
# it's constant. assume it's a float
# FIXME: a constant can be a tensor with multiple dimensions, but for now we assume
# it's constant.
else:
name = None
value = float(config_ele)
if len(_inputs)>0 and len(_inputs[0].shape)>=1:
input_shape = (_inputs[0]).shape
# '/Constant_2_output_0': {'class_name': '__numpy__', 'config': {'value': [1.0, 0.0, 0.0], 'dtype': 'float32'}}
if isinstance(config_ele, dict) and config_ele["class_name"]=='__numpy__':
value = config_ele['config']['value']
value_in_tensor = dict_to_tensor(config_ele)
input_shape = value_in_tensor.shape
# '/Constant_output_0': 0
else:
input_shape =(1,)
value = float(config_ele)
# if len(_inputs)>0 and len(_inputs[0].shape)>=1:
# input_shape = (_inputs[0]).shape
# else:
input_shape = (1,)
self.inputs.append(
Input(
is_constant=not is_non_constant,

View File

@@ -29,8 +29,9 @@ from zkstats.onnx2circom.onnx2keras.layers import (
TFOr,
TFWhere,
TFAbs,
TFGather,
TFConcat
TFGather,
TFConcat,
TFIdentity,
# TFArgMax,
# TFArgMin,
)
@@ -65,8 +66,9 @@ SUPPORTED_OPS = [
TFCast,
TFWhere,
TFAbs,
TFGather,
TFConcat
TFGather,
TFConcat,
TFIdentity,
# TFErf,
]
@@ -100,9 +102,9 @@ def get_component_args_values(layer: Layer) -> typing.Dict[str, typing.Any]:
return {'e': 2, 'nInputs': num_elements_in_input_0}
if is_in_ops(layer.op, [TFReduceSum, TFReduceMean, TFReduceMax, TFReduceMin]):
return {'nInputs': num_elements_in_input_0}
if is_in_ops(layer.op, [TFNot, TFCast, TFWhere, TFAbs, TFGather]):
if is_in_ops(layer.op, [TFNot, TFCast, TFWhere, TFAbs, TFGather, TFIdentity]):
return {'nElements': num_elements_in_input_0}
# 2 inputs operations
# 2 inputs operations
if is_in_ops(layer.op, [TFAdd, TFSub, TFMul, TFDiv, TFEqual, TFGreater, TFLess, TFAnd, TFOr]):
input_1 = inputs[1]
input_1_shape = input_1.shape
@@ -198,13 +200,15 @@ def transpile(templates: dict[str, Template], filename: str, output_dir: str = '
# Handle right hand side when it's keras tensor
input_name = _input.name
input_shape = _input.shape
# This input can only be a constant, a model input, or an output from another component
if _input.is_constant:
# If this input is a constant, use the value of the constant directly as the right hand side
from_component_name = None
# FIXME: if _input.value is a float, we need to scale it by 10^dec
# Scale the float value by 10^dec
scaled = int(_input.value * 10 ** dec)
from_component_signal_name = str(scaled)
rhs_dim = 0
rhs_dim = len(input_shape)
elif model.is_model_input(input_name) is True:
# If this input is from `input_layer`, use the original tensor name for it
from_component_name = None

View File

@@ -281,7 +281,7 @@ template TFGather(nElements) {
}
out <== out_till[nElements];
}
template TFConcat(nElements_0, nElements_1){
@@ -294,4 +294,12 @@ template TFConcat(nElements_0, nElements_1){
for (var i = 0; i< nElements_1; i++){
out[nElements_0+i] <== in_1[i]+0;
}
}
}
template TFIdentity(nElements) {
signal input in[nElements];
signal output out[nElements];
for (var i = 0; i<nElements; i++){
out[i] <== in[i] + 0;
}
}

View File

@@ -2,7 +2,7 @@ import math
import logging
import numpy as np
import tensorflow as tf
from tensorflow import keras
import keras
from ..utils.op_registry import OPERATOR
from .dimension_utils import intfloat_to_list
@@ -171,7 +171,7 @@ class TFUpsample():
def __call__(self, inputs):
return tf.image.resize(inputs, self.scale, method=self.method)
@OPERATOR.register_operator("Constant")
class TFConstant():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
@@ -242,12 +242,26 @@ class TFGemm():
return self.dense(inputs)
@OPERATOR.register_operator("Identity")
class TFIdentity():
def __init__(self, *args, **kwargs):
class TFIdentity(keras.layers.Layer):
def __init__(self,tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
super().__init__()
self.tensor_grap = tensor_grap
self.node_weights = node_weights
self.node_inputs = node_inputs
self.node_attribute = node_attribute
def __call__(self, inputs):
return inputs
def call(self, input, *args, **kwargs):
return input
def get_config(self):
config = super().get_config()
config.update({
"tensor_grap":self.tensor_grap,
'node_weights':self.node_weights,
'node_inputs':self.node_inputs,
'node_attribute':self.node_attribute,
})
return config
@OPERATOR.register_operator("Dropout")
class TFDropout():
@@ -310,14 +324,14 @@ class TFCast(keras.layers.Layer):
# inputs = tf.cast(inputs, dtype=self.tf_cast_map[self.cast_to])
return inputs
def get_config(self):
config = super().get_config()
config.update({
"tensor_grap":self.tensor_grap,
'node_weights':self.node_weights,
'node_inputs':self.node_inputs,
'node_attribute':self.node_attribute,
'node_attribute':self.node_attribute,
"cast_to": self.cast_to
})
return config