clarify torch.log() issue

JernKunpittaya
2024-05-29 14:19:08 +07:00
parent 7e956cb3f5
commit 488fa61b8d
3 changed files with 46 additions and 56 deletions


@@ -1,27 +0,0 @@
import torch
import torch.nn as nn
from .utils import compile_and_check
# two tensor stuffs
def test_comparison(tmp_path):
data_1 = torch.tensor(
[32, 8, 8],
dtype = torch.float32,
).reshape(1, -1, 1)
data_2 = torch.tensor(
[3, 8, 9],
dtype = torch.float32,
).reshape(1, -1, 1)
class Model(nn.Module):
def forward(self, x, y):
return torch.logical_or(x<=y, x<y)
return torch.logical_and(x<=y, x<y)
return torch.logical_not(x<=y)
return x>y
return x>=y
return x<y
return x<=y
return x==y
compile_and_check(Model, (data_1, data_2), tmp_path)
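In the deleted test above, only the first return ever executes; the remaining returns are alternatives that were toggled by hand. As a hedged sketch (not part of this commit), the same coverage could be expressed with pytest.mark.parametrize, the style the surviving tests already use, assuming the same compile_and_check helper:

import pytest
import torch
import torch.nn as nn

from .utils import compile_and_check

@pytest.mark.parametrize(
    "op",
    [
        pytest.param(lambda x, y: x > y, id="x > y"),
        pytest.param(lambda x, y: x >= y, id="x >= y"),
        pytest.param(lambda x, y: x < y, id="x < y"),
        pytest.param(lambda x, y: x <= y, id="x <= y"),
        pytest.param(lambda x, y: x == y, id="x == y"),
        pytest.param(lambda x, y: torch.logical_or(x <= y, x < y), id="logical_or"),
        pytest.param(lambda x, y: torch.logical_and(x <= y, x < y), id="logical_and"),
        pytest.param(lambda x, y: torch.logical_not(x <= y), id="logical_not"),
    ],
)
def test_comparison(op, tmp_path):
    data_1 = torch.tensor([32, 8, 8], dtype=torch.float32).reshape(1, -1, 1)
    data_2 = torch.tensor([3, 8, 9], dtype=torch.float32).reshape(1, -1, 1)

    class Model(nn.Module):
        def forward(self, x, y):
            # each parametrized case exercises exactly one comparison op
            return op(x, y)

    compile_and_check(Model, (data_1, data_2), tmp_path)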


@@ -3,23 +3,30 @@ import torch.nn as nn
 from .utils import compile_and_run_mpspdz, run_torch_model

 def test_onnx_to_circom(tmp_path):
     data = torch.tensor(
         [32, 8, 8],
         dtype=torch.float32,
     ).reshape(1, -1, 1)
-    class Model(nn.Module):
+    class ModelMPSPDZ(nn.Module):
         def forward(self, x):
             m = torch.mean(x)  # 16
             s = torch.sum(x)  # 48
-            l = torch.log(x)  # 5,3,3
+            l = torch.log(x)  # 5, 3, 3
             return m * s + l  # 773, 771, 771

+    class ModelTorch(nn.Module):
+        def forward(self, x):
+            m = torch.mean(x)  # 16
+            s = torch.sum(x)  # 48
+            l = torch.log2(x)  # 5, 3, 3
+            return m * s + l  # 773, 771, 771
+
     # Run the model directly with torch
-    output_torch = run_torch_model(Model, data)
+    # We can't do that here, since our torch.log() is base-2 while the actual torch.log() is base-e.
+    # This will be resolved once we support scaling for floating-point constants.
+    output_torch = run_torch_model(ModelTorch, data)
     # Compile and run the model with MP-SPDZ
-    outputs_mpspdz = compile_and_run_mpspdz(Model, data, tmp_path)
+    outputs_mpspdz = compile_and_run_mpspdz(ModelMPSPDZ, data, tmp_path)
     # The model only has one output tensor
     assert len(outputs_mpspdz) == 1, f"Expecting only one output tensor, but got {len(outputs_mpspdz)} tensors."
     # Compare the output tensor with the expected output. Should be close
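As a quick numerical check of the inline comments above (plain PyTorch, outside this diff): log base 2 of [32, 8, 8] gives the 5, 3, 3 that the MP-SPDZ backend produces, while the real torch.log is base e:

import torch

x = torch.tensor([32.0, 8.0, 8.0])
print(torch.log2(x))  # tensor([5., 3., 3.])  -> what the base-2 backend computes
print(torch.log(x))   # tensor([3.4657, 2.0794, 2.0794])  -> what torch.log actually returns
m = torch.mean(x)     # 16.0
s = torch.sum(x)      # 48.0
print(m * s + torch.log2(x))  # tensor([773., 771., 771.]), matching the expected outputs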


@@ -45,40 +45,51 @@ def test_two_inputs(func, tmp_path):
     # Compare the output tensor with the expected output. Difference should be within 0.001
     assert torch.allclose(outputs_mpspdz[0], output_torch, rtol=0.001), f"Output tensor is not close to the expected output tensor. {outputs_mpspdz[0]=}, {output_torch=}"

-def log(x, base=None):
-    # if base is None, we use the natural logarithm
-    if base is None:
-        return torch.log(x)
-    # else, convert to `base` by `log_base(x) = log_k(x) / log_k(base)`
-    return torch.log(x) / torch.log(torch.tensor(float(base)))
+# Not used until we support scaling for floating-point numbers
+# def log(x, base=None):
+#     # if base is None, we use the natural logarithm
+#     if base is None:
+#         return torch.log(x)
+#     # else, convert to `base` by `log_base(x) = log_k(x) / log_k(base)`
+#     return torch.log(x) / torch.log(torch.tensor(float(base)))
+
+# @pytest.mark.parametrize(
+#     "func",
+#     [
+#         # pytest.param(lambda x, base: x + log(x, base), id="x + log(x)"),
+#         # pytest.param(lambda x, base: log(x, base) + x, id="log(x) + x"),
+#         # pytest.param(lambda x, base: x - log(x, base), id="x - log(x)"),
+#         # pytest.param(lambda x, base: log(x, base) - x, id="log(x) - x"),
+#         # pytest.param(lambda x, base: torch.mean(x) + log(x, base), id="mean(x) + log(x)"),
+#         # pytest.param(lambda x, base: log(x, base) + torch.mean(x), id="log(x) + mean(x)"),
+#         # pytest.param(lambda x, base: torch.mean(x) - log(x, base), id="mean(x) - log(x)"),
+#     ]
+# )

+# FIXME: Our circom currently interprets torch.log as base 2, while torch interprets it as base e. To keep the two
+# coherent, we force func_torch to use torch.log2. We can use log base e in circom once we support scaling.
 @pytest.mark.parametrize(
-    "func",
+    "func_mpspdz, func_torch",
     [
-        pytest.param(lambda x, base: x + log(x, base), id="x + log(x)"),
-        pytest.param(lambda x, base: log(x, base) + x, id="log(x) + x"),
-        pytest.param(lambda x, base: x - log(x, base), id="x - log(x)"),
-        pytest.param(lambda x, base: log(x, base) - x, id="log(x) - x"),
-        pytest.param(lambda x, base: torch.mean(x) + log(x, base), id="mean(x) + log(x)"),
-        pytest.param(lambda x, base: log(x, base) + torch.mean(x), id="log(x) + mean(x)"),
-        pytest.param(lambda x, base: torch.mean(x) - log(x, base), id="mean(x) - log(x)"),
+        pytest.param(lambda x: x + torch.log(x), lambda x: x + torch.log2(x), id="x + log(x)"),
+        pytest.param(lambda x: torch.log(x) + x, lambda x: torch.log2(x) + x, id="log(x) + x"),
+        pytest.param(lambda x: x - torch.log(x), lambda x: x - torch.log2(x), id="x - log(x)"),
+        pytest.param(lambda x: torch.log(x) - x, lambda x: torch.log2(x) - x, id="log(x) - x"),
+        pytest.param(lambda x: torch.mean(x) + torch.log(x), lambda x: torch.mean(x) + torch.log2(x), id="mean(x) + log(x)"),
+        pytest.param(lambda x: torch.log(x) + torch.mean(x), lambda x: torch.log2(x) + torch.mean(x), id="log(x) + mean(x)"),
+        pytest.param(lambda x: torch.mean(x) - torch.log(x), lambda x: torch.mean(x) - torch.log2(x), id="mean(x) - log(x)"),
     ]
 )
-def test_two_inputs_with_logs(func, tmp_path):
-    e = 2.7183
+def test_two_inputs_with_logs(func_mpspdz, func_torch, tmp_path):
     data = torch.tensor(
         [32, 8, 8],
         dtype=torch.float32,
     ).reshape(1, -1, 1)

     class ModelMPSPDZ(nn.Module):
         def forward(self, x):
-            # FIXME: We should remove `log` and use `torch.log` directly once we support base=e.
-            # For now we must convert the base to `e`, since MP-SPDZ currently uses base 2:
-            # ln(x) = log_2(x) / log_2(e)
-            return func(x, base=e)
+            return func_mpspdz(x)

     outputs_tensor_mpsdpz = compile_and_run_mpspdz(ModelMPSPDZ, data, tmp_path)
     # The model only has one output tensor
@@ -87,8 +98,7 @@ def test_two_inputs_with_logs(func, tmp_path):
     class ModelTorch(nn.Module):
         def forward(self, x):
-            # base = None means we don't need the conversion at all, since torch uses e by default
-            return func(x, base=None)
+            return func_torch(x)

     output_torch = run_torch_model(ModelTorch, data)
     assert output_mpspdz.shape == output_torch.shape, f"Output tensor shape is not the same. {output_mpspdz.shape=}, {output_torch.shape=}"
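The commented-out log helper and the removed base=e path both rest on the change-of-base identity ln(x) = log_2(x) / log_2(e). A minimal sanity check in plain PyTorch, outside this diff:

import math

import torch

x = torch.tensor([32.0, 8.0, 8.0])
ln_direct = torch.log(x)                         # base-e, what torch computes natively
ln_via_log2 = torch.log2(x) / math.log2(math.e)  # log_2(e) ≈ 1.4427, so ln(x) = log_2(x) / log_2(e)
assert torch.allclose(ln_direct, ln_via_log2)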