mirror of
https://github.com/MPCStats/zk-stats-lib.git
synced 2026-01-08 21:18:04 -05:00
clarify torch.log() issue
This commit is contained in:
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import compile_and_check
|
||||
|
||||
# two tensor stuffs
|
||||
def test_comparison(tmp_path):
|
||||
data_1 = torch.tensor(
|
||||
[32, 8, 8],
|
||||
dtype = torch.float32,
|
||||
).reshape(1, -1, 1)
|
||||
data_2 = torch.tensor(
|
||||
[3, 8, 9],
|
||||
dtype = torch.float32,
|
||||
).reshape(1, -1, 1)
|
||||
class Model(nn.Module):
|
||||
def forward(self, x, y):
|
||||
return torch.logical_or(x<=y, x<y)
|
||||
return torch.logical_and(x<=y, x<y)
|
||||
return torch.logical_not(x<=y)
|
||||
return x>y
|
||||
return x>=y
|
||||
return x<y
|
||||
return x<=y
|
||||
return x==y
|
||||
|
||||
compile_and_check(Model, (data_1, data_2), tmp_path)
|
||||
@@ -3,23 +3,30 @@ import torch.nn as nn
|
||||
|
||||
from .utils import compile_and_run_mpspdz, run_torch_model
|
||||
|
||||
|
||||
def test_onnx_to_circom(tmp_path):
|
||||
data = torch.tensor(
|
||||
[32, 8, 8],
|
||||
dtype = torch.float32,
|
||||
).reshape(1, -1, 1)
|
||||
class Model(nn.Module):
|
||||
class ModelMPSPDZ(nn.Module):
|
||||
def forward(self, x):
|
||||
m = torch.mean(x) # 16
|
||||
s = torch.sum(x) # 48
|
||||
l = torch.log(x) # 5,3,3
|
||||
l = torch.log(x) # 5, 3, 3
|
||||
return m*s+l #773, 771, 771
|
||||
|
||||
class ModelTorch(nn.Module):
|
||||
def forward(self, x):
|
||||
m = torch.mean(x) # 16
|
||||
s = torch.sum(x) # 48
|
||||
l = torch.log2(x) # 5, 3, 3
|
||||
return m*s+l #773, 771, 771
|
||||
# Run the model directly with torch
|
||||
output_torch = run_torch_model(Model, data)
|
||||
# Here cant do that since our torch.log() is 2-based, while actual torch.log() is e-based
|
||||
# Will resolve once we support scaling to support floatin constant
|
||||
output_torch = run_torch_model(ModelTorch, data)
|
||||
# Compile and run the model with MP-SPDZ
|
||||
outputs_mpspdz = compile_and_run_mpspdz(Model, data, tmp_path)
|
||||
outputs_mpspdz = compile_and_run_mpspdz(ModelMPSPDZ, data, tmp_path)
|
||||
# The model only has one output tensor
|
||||
assert len(outputs_mpspdz) == 1, f"Expecting only one output tensor, but got {len(outputs_mpspdz)} tensors."
|
||||
# Compare the output tensor with the expected output. Should be close
|
||||
|
||||
@@ -45,40 +45,51 @@ def test_two_inputs(func, tmp_path):
|
||||
# Compare the output tensor with the expected output. Different should be within 0.001
|
||||
assert torch.allclose(outputs_mpspdz[0], output_torch, rtol=0.001), f"Output tensor is not close to the expected output tensor. {outputs_mpspdz[0]=}, {output_torch=}"
|
||||
|
||||
|
||||
def log(x, base=None):
|
||||
# if base is None, we use natural logarithm
|
||||
if base is None:
|
||||
return torch.log(x)
|
||||
# else, convert to `base` by `log_base(x) = log_k(x) / log_k(base)`
|
||||
return torch.log(x) / torch.log(torch.tensor(float(base)))
|
||||
# Not use until we support scaling for floating number
|
||||
# def log(x, base=None):
|
||||
# # if base is None, we use natural logarithm
|
||||
# if base is None:
|
||||
# return torch.log(x)
|
||||
# # else, convert to `base` by `log_base(x) = log_k(x) / log_k(base)`
|
||||
# return torch.log(x) / torch.log(torch.tensor(float(base)))
|
||||
|
||||
|
||||
# @pytest.mark.parametrize(
|
||||
# "func",
|
||||
# [
|
||||
# # pytest.param(lambda x, base: x + log(x, base), id="x + log(x)"),
|
||||
# # pytest.param(lambda x, base: log(x, base) + x, id="log(x) + x"),
|
||||
# # pytest.param(lambda x, base: x - log(x, base), id="x - log(x)"),
|
||||
# # pytest.param(lambda x, base: log(x, base) - x, id="log(x) - x"),
|
||||
# # pytest.param(lambda x, base: torch.mean(x) + log(x, base), id="mean(x) + log(x)"),
|
||||
# # pytest.param(lambda x, base: log(x, base) + torch.mean(x), id="log(x) + mean(x)"),
|
||||
# # pytest.param(lambda x, base: torch.mean(x) - log(x, base), id="mean(x) - log(x)"),
|
||||
# ]
|
||||
# )
|
||||
|
||||
# FIXME: Now our circom interprets torch.log as base 2, while torch interprets as base e, to make things coherent, we enforce
|
||||
# func_torch to be torch.log2. We can use log base e in circom once we support scaling.
|
||||
@pytest.mark.parametrize(
|
||||
"func",
|
||||
"func_mpspdz, func_torch",
|
||||
[
|
||||
pytest.param(lambda x, base: x + log(x, base), id="x + log(x)"),
|
||||
pytest.param(lambda x, base: log(x, base) + x, id="log(x) + x"),
|
||||
pytest.param(lambda x, base: x - log(x, base), id="x - log(x)"),
|
||||
pytest.param(lambda x, base: log(x, base) - x, id="log(x) - x"),
|
||||
pytest.param(lambda x, base: torch.mean(x) + log(x, base), id="mean(x) + log(x)"),
|
||||
pytest.param(lambda x, base: log(x, base) + torch.mean(x), id="log(x) + mean(x)"),
|
||||
pytest.param(lambda x, base: torch.mean(x) - log(x, base), id="mean(x) - log(x)"),
|
||||
pytest.param(lambda x: x + torch.log(x), lambda x: x + torch.log2(x), id="x + log(x)"),
|
||||
pytest.param(lambda x: torch.log(x) + x, lambda x: torch.log2(x) + x, id="log(x) + x"),
|
||||
pytest.param(lambda x: x - torch.log(x), lambda x: x - torch.log2(x), id="x - log(x)"),
|
||||
pytest.param(lambda x: torch.log(x) - x, lambda x: torch.log2(x) - x, id="log(x) - x"),
|
||||
pytest.param(lambda x: torch.mean(x) + torch.log(x), lambda x: torch.mean(x) + torch.log2(x), id="mean(x) + log(x)"),
|
||||
pytest.param(lambda x: torch.log(x) + torch.mean(x), lambda x: torch.log2(x) + torch.mean(x), id="log(x) + mean(x)"),
|
||||
pytest.param(lambda x: torch.mean(x) - torch.log(x), lambda x: torch.mean(x) - torch.log2(x), id="mean(x) - log(x)"),
|
||||
]
|
||||
)
|
||||
def test_two_inputs_with_logs(func, tmp_path):
|
||||
e = 2.7183
|
||||
def test_two_inputs_with_logs(func_mpspdz,func_torch, tmp_path):
|
||||
data = torch.tensor(
|
||||
[32, 8, 8],
|
||||
dtype = torch.float32,
|
||||
).reshape(1, -1, 1)
|
||||
|
||||
class ModelMPSPDZ(nn.Module):
|
||||
def forward(self, x):
|
||||
# FIXME: We should remove `log` and use `torch.log` directly when
|
||||
# we support base=e.
|
||||
# Now we need to convert base to `e` since we currently use `base=2` in mp-spdz
|
||||
# to calculate ln(x) = log_2(x) / log_2(e)
|
||||
return func(x, base=e)
|
||||
return func_mpspdz(x)
|
||||
|
||||
outputs_tensor_mpsdpz = compile_and_run_mpspdz(ModelMPSPDZ, data, tmp_path)
|
||||
# The model only has one output tensor
|
||||
@@ -87,8 +98,7 @@ def test_two_inputs_with_logs(func, tmp_path):
|
||||
|
||||
class ModelTorch(nn.Module):
|
||||
def forward(self, x):
|
||||
# base = None means we don't need the conversion at all since torch uses e by default
|
||||
return func(x, base=None)
|
||||
return func_torch(x)
|
||||
|
||||
output_torch = run_torch_model(ModelTorch, data)
|
||||
assert output_mpspdz.shape == output_torch.shape, f"Output tensor shape is not the same. {output_mpspdz.shape=}, {output_torch.shape=}"
|
||||
|
||||
Reference in New Issue
Block a user