Initial commit

Vikram Saraph committed 2025-01-09 11:26:56 -05:00 via GitHub
commit 7fbe5cef5c
36 changed files with 5983 additions and 0 deletions

LICENSE Normal file

@@ -0,0 +1,31 @@
(c) 2021-2023 The Johns Hopkins University Applied Physics
Laboratory LLC (JHU/APL).
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
OF THE POSSIBILITY OF SUCH DAMAGE.

LICENSES/Optuna/LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2018 Preferred Networks, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

LICENSES/PyTorch/LICENSE Normal file

@@ -0,0 +1,80 @@
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

README.md Normal file

@@ -0,0 +1,76 @@
# SHIELD: Secure Homomorphic Inference for Encrypted Learning on Data
SHIELD is a library for evaluating pre-trained convolutional neural networks on homomorphically encrypted images. It includes code for training models that are suitable for homomorphic evaluation. Implemented neural network operations include convolution, average pooling, GELU, and linear layers.
This code was used to run the experiments supporting the following paper: [High-Resolution Convolutional Neural Networks on Homomorphically Encrypted Data via Sharding Ciphertexts](https://arxiv.org/abs/2306.09189). However, the operators defined in this project are generic enough to build arbitrary convolutional neural networks as specified in the paper.
## Requirements
This project's dependencies are managed by Poetry, so installing [Poetry](https://python-poetry.org/) is a requirement. OpenFHE Python bindings are used to interface with OpenFHE, so the wheel file for these bindings will also need to be built. See the OpenFHE Python bindings repository for further instructions.
Once the bindings are built, ensure that the `pyproject.toml` file contains a correct path to the bindings. Then, to install the Python environment for this project, run `poetry install`. For running unit tests and the small neural network described below, 32GB of RAM is recommended. For the hardware requirements needed to reproduce results for the larger ResNet architectures, see the paper for details.
Code was developed and tested on Ubuntu 20.04. While it should run on Windows platforms as well, this has not been explicitly tested.
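For illustration, a path dependency on the bindings wheel in `pyproject.toml` might look like the sketch below. The package name matches the `pyOpenFHE` module this project imports, but the wheel path and filename are hypothetical and depend on your build:
```
[tool.poetry.dependencies]
python = "^3.8"
# hypothetical path; point this at the wheel produced by your bindings build
pyOpenFHE = { path = "../openfhe-python/dist/pyOpenFHE-0.1.0-cp38-cp38-linux_x86_64.whl" }
```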
## Features
SHIELD implements the following neural network operators:
- Convolution
- Average pooling
- Batch normalization (fused with convolution operators for performance)
- Linear
- GELU (Gaussian Error Linear Unit, a smooth alternative to ReLU)
- Upsample
For performance reasons, the core of these algorithms is mostly implemented in C++ in the companion OpenFHE Python bindings project, with this project providing a minimal but more user-friendly Python interface for using them.
The following neural network architectures are implemented using homomorphic implementations of the above operators: a neural network consisting of three convolution blocks (mainly for integration testing), and variations on ResNet, including ResNet9 and ResNet50. In addition, code for training models suitable for homomorphic evaluation using these architectures is included. The training code includes the kurtosis regularization required for homomorphic inference. See the referenced paper for more details on the algorithms implemented, as well as performance metrics for homomorphic inference using these neural networks. A sketch of how the operators compose follows.
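The snippet below chains the `CNNContext` methods from `palisade_he_cnn/src/cnn_context`; the layer shapes and crypto parameters here are illustrative rather than tuned, and the real settings live in the inference scripts described below:
```
import torch
from palisade_he_cnn.src.cnn_context import create_cnn_context
from palisade_he_cnn.src.he_cnn.utils import get_keys

# generate a CKKS crypto context and keys (illustrative parameters)
cc, keys = get_keys(mult_depth=35, scale_factor_bits=59,
                    batch_size=32 * 32 * 32, bootstrapping=True)

# channel count and spatial dimensions must be powers of two
image = torch.randn(4, 32, 32)
conv = torch.nn.Conv2d(4, 8, kernel_size=3, padding=1)
bn = torch.nn.BatchNorm2d(8).eval()

ctxt = create_cnn_context(image, cc, keys.publicKey)
ctxt = ctxt.apply_conv(conv, bn)    # convolution fused with batch norm
ctxt = ctxt.apply_gelu(bound=15.0)  # polynomial GELU approximation
ctxt = ctxt.apply_pool()            # 2x2 average pooling
print(ctxt.decrypt_to_tensor(cc, keys).shape)
```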
## Running the code
### Unit tests
Tests are run with `pytest`:
```
poetry run python palisade_he_cnn/test.py
```
### A small neural network
`small_model.py` includes code defining a 3-layer convolutional neural network, as well as code to train a model instantiated from this network on MNIST. The training code can be run with:
```
poetry run python palisade_he_cnn/src/small_model.py
```
This will save model weights to `small_model.pt`. To run homomorphic inference with these weights, move the weights to `palisade_he_cnn/src/weights/` and then run:
```
poetry run python palisade_he_cnn/src/small_model_inference.py
```
This script builds an equivalent homomorphic architecture, extracts the weights from the plaintext model, and runs inference on MNIST. It prints inference times to the terminal. For convenience, example weights are already included in `palisade_he_cnn/src/weights`.
### Larger neural networks
Scripts to train larger models are included in `palisade_he_cnn/training`. Scripts that run inference with these models are in `palisade_he_cnn/inference`. Due to the significant resources required to train and run homomorphic inference with these larger models, the weights used in the paper will be added to this repository in the future.
## Citation and Acknowledgements
Please cite this work as follows:
```
@misc{maloney2024highresolutionconvolutionalneuralnetworks,
title={High-Resolution Convolutional Neural Networks on Homomorphically Encrypted Data via Sharding Ciphertexts},
author={Vivian Maloney and Richard F. Obrecht and Vikram Saraph and Prathibha Rama and Kate Tallaksen},
year={2024},
eprint={2306.09189},
archivePrefix={arXiv},
primaryClass={cs.CR},
url={https://arxiv.org/abs/2306.09189},
}
```
In addition to the authors of the supporting manuscript (Vivian Maloney, Freddy Obrecht, Vikram Saraph, Prathibha Rama, and Kate Tallaksen), Lindsay Spriggs and Court Climer also contributed to this work by testing the software and integrating it with internal infrastructure.


@@ -0,0 +1,132 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import argparse
import copy
import json
from pathlib import Path
from time import time
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
from palisade_he_cnn.src.he_cnn.utils import *
from palisade_he_cnn.src.utils import pad_conv_input_channels, PadChannel
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--idx", default="0")
args = vars(parser.parse_args())
img_idx = int(args["idx"])
print("img_idx", img_idx)
# create HE cc and keys
mult_depth = 35
scale_factor_bits = 59
batch_size = 32 * 32 * 32
# if using bootstrapping, you must increase scale_factor_bits to 59
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
stats = ((0.4914, 0.4822, 0.4465), # mean
(0.247, 0.243, 0.261)) # std
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(*stats,inplace=True),
PadChannel(npad=1),
transforms.Resize(32)
])
validset = torchvision.datasets.CIFAR10(root="./data", download=True, transform=transform)
validloader = torch.utils.data.DataLoader(validset, batch_size=1, shuffle=True)
# top level model
resnet_model = torch.load("palisade_he_cnn/src/weights/resnet50_cifar_gelu_kurt.pt")
resnet_model.eval()
print(resnet_model)
##############################################################################
conv1 = resnet_model.conv1
bn1 = resnet_model.bn1
padded_conv1 = pad_conv_input_channels(conv1)
embedder = copy.deepcopy(torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
for i, (padded_test_data, test_label) in enumerate(validloader):
if i == img_idx:
break
unpadded_test_data = padded_test_data[:,:3]
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
##############################################################################
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
start = time()
# embedding layer
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
cnn_context = cnn_context.apply_gelu(bound=15.0)
unencrypted = ptxt_embedded
compare_accuracy(keys, cnn_context, unencrypted, "embedding", num_digits=7)
###############################################################################
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
for j, bottleneck in enumerate(layer):
        bootstrap = not (i == 0 and j == 0)
name = f"bottleneck #{i+1}-{j}"
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=bootstrap, bootstrap_params={"meta" : True})
unencrypted = bottleneck(unencrypted)
compare_accuracy(keys, cnn_context, unencrypted, name, num_digits=7)
###############################################################################
linear = resnet_model.fc
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
inference_time = time() - start
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
ptxt_logits = linear(flattened)
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
print(f"[+] decrypted logits = {decrypted_logits}")
print(f"[+] plaintext logits = {ptxt_logits}")
###############################################################################
dataset = "cifar10"
model_type = "resnet50_metaBTS"
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
filename.parent.mkdir(exist_ok=True, parents=True)
data = dict(TIMING_DICT)
data["decrypted logits"] = decrypted_logits.tolist()
data["unencrypted logits"] = ptxt_logits.tolist()
data["inference time"] = inference_time
# avoid double-counting the strided conv operations
data['Pool'] = data['Pool'][1:2]
with open(filename, "w") as f:
json.dump(data, f, indent=4)


@@ -0,0 +1,159 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
# srun -p hybrid -n 128 --mem=300G --pty bash -i
# srun -p himem -n 128 --mem=300G --pty bash -i
# export OMP_DISPLAY_ENV=TRUE
# export OMP_NUM_THREADS=32
import torch
import numpy as np
from time import time
import copy
import json
import argparse
from pathlib import Path
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from palisade_he_cnn.src.he_cnn.utils import compare_accuracy, get_keys
from palisade_he_cnn.src.utils import pad_conv_input_channels
from palisade_he_cnn.training.utils.utils import PadChannel
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--idx", default="0")
args = vars(parser.parse_args())
img_idx = int(args["idx"])
print("img_idx", img_idx)
IMAGENET_CHANNEL_MEAN = (0.485, 0.456, 0.406)
IMAGENET_CHANNEL_STD = (0.229, 0.224, 0.225)
stats = (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD)
IMAGENET_DIR = Path("/aoscluster/he-cnn/vivian/imagenet/datasets/ILSVRC/Data/CLS-LOC")
resize_size = 136
crop_size = 128
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(*stats,inplace=True),
PadChannel(npad=1),
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size)
])
validset = ImageFolder(IMAGENET_DIR / "val", transform=transform)
validloader = DataLoader(validset,
batch_size = 1,
pin_memory = True,
num_workers = 1,
shuffle=True)
# top level model
resnet_model = torch.load("weights/resnet50_imagenet128_gelu_kurt.pt")
resnet_model.eval()
print(resnet_model)
##############################################################################
conv1 = resnet_model.conv1
bn1 = resnet_model.bn1
padded_conv1 = pad_conv_input_channels(conv1)
embedder = copy.deepcopy(torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
for i, (padded_test_data, test_label) in enumerate(validloader):
if i == img_idx:
break
unpadded_test_data = padded_test_data[:,:3]
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
unencrypted = ptxt_embedded
##############################################################################
# create HE cc and keys
mult_depth = 34
scale_factor_bits = 59
batch_size = 32 * 32 * 32
# if using bootstrapping, you must increase scale_factor_bits to 59
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
##############################################################################
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
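# Multiplying a CKKS ciphertext by 1.0 consumes one level per rescale, so this
# loop is a cheap way to drop towers until the ciphertext is at the level
# expected at the start of inference (18 towers remaining).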
while cnn_context.shards[0].getTowersRemaining() > 18:
for i in range(cnn_context.num_shards):
cnn_context.shards[i] *= 1.0
start = time()
# embedding layer
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
cnn_context = cnn_context.apply_gelu(bound=50.0, degree=200)
cnn_context = cnn_context.apply_pool(conv=True)
compare_accuracy(keys, cnn_context, unencrypted, "embedding")
###############################################################################
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
for j, bottleneck in enumerate(layer):
name = f"bottleneck #{i+1}-{j}"
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=True, gelu_params={"bound" : 15.0, "degree": 59})
unencrypted = bottleneck(unencrypted)
compare_accuracy(keys, cnn_context, unencrypted, name)
###############################################################################
linear = resnet_model.fc
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
inference_time = time() - start
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
ptxt_logits = linear(flattened)
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
print(f"[+] decrypted logits = {decrypted_logits}")
print(f"[+] plaintext logits = {ptxt_logits}")
###############################################################################
dataset = "imagenet"
model_type = "resnet50_128"
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
filename.parent.mkdir(exist_ok=True, parents=True)
data = dict(TIMING_DICT)
data["decrypted logits"] = decrypted_logits.tolist()
data["unencrypted logits"] = ptxt_logits.tolist()
data["inference time"] = inference_time
# avoid double-counting the strided conv operations
data['Pool'] = data['Pool'][:1]
with open(filename, "w") as f:
json.dump(data, f, indent=4)


@@ -0,0 +1,155 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
# srun -p hybrid -n 128 --mem=300G --pty bash -i
# srun -p himem -n 128 --mem=300G --pty bash -i
# export OMP_DISPLAY_ENV=TRUE
# export OMP_NUM_THREADS=32
import argparse
import copy
import json
from pathlib import Path
from time import time
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
from palisade_he_cnn.src.cnn_context.utils import *
from palisade_he_cnn.src.he_cnn.utils import get_keys, compare_accuracy
from palisade_he_cnn.src.utils import pad_conv_input_channels
from palisade_he_cnn.training.utils.utils import PadChannel
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--idx", default="0")
args = vars(parser.parse_args())
img_idx = int(args["idx"])
print("img_idx", img_idx)
IMAGENET_CHANNEL_MEAN = (0.485, 0.456, 0.406)
IMAGENET_CHANNEL_STD = (0.229, 0.224, 0.225)
stats = (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD)
IMAGENET_DIR = Path("/aoscluster/he-cnn/vivian/imagenet/datasets/ILSVRC/Data/CLS-LOC")
resize_size = 264
crop_size = 256
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(*stats, inplace=True),
PadChannel(npad=1),
transforms.Resize(resize_size),
transforms.CenterCrop(crop_size)
])
validset = ImageFolder(IMAGENET_DIR / "val", transform=transform)
validloader = DataLoader(validset,
batch_size=1,
pin_memory=True,
num_workers=1,
shuffle=True)
# top level model
resnet_model = torch.load("weights/resnet50_imagenet256_gelu_kurt.pt")
resnet_model.eval()
print(resnet_model)
##############################################################################
conv1 = resnet_model.conv1
bn1 = resnet_model.bn1
padded_conv1 = pad_conv_input_channels(conv1)
embedder = copy.deepcopy(
torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
for i, (padded_test_data, test_label) in enumerate(validloader):
if i == img_idx:
break
unpadded_test_data = padded_test_data[:, :3]
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
unencrypted = ptxt_embedded
##############################################################################
# create HE cc and keys
mult_depth = 34
scale_factor_bits = 59
batch_size = 32 * 32 * 32
# if using bootstrapping, you must increase scale_factor_bits to 59
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
##############################################################################
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
while cnn_context.shards[0].getTowersRemaining() > 18:
for i in range(cnn_context.num_shards):
cnn_context.shards[i] *= 1.0
start = time()
# embedding layer
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
cnn_context = cnn_context.apply_gelu(bound=50.0, degree=200)
cnn_context = cnn_context.apply_pool(conv=True)
compare_accuracy(keys, cnn_context, unencrypted, "embedding")
###############################################################################
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
for j, bottleneck in enumerate(layer):
name = f"bottleneck #{i + 1}-{j}"
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=True,
gelu_params={"bound": 15.0, "degree": 59})
unencrypted = bottleneck(unencrypted)
compare_accuracy(keys, cnn_context, unencrypted, name)
###############################################################################
linear = resnet_model.fc
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
inference_time = time() - start
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
ptxt_logits = linear(flattened)
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
print(f"[+] decrypted logits = {decrypted_logits}")
print(f"[+] plaintext logits = {ptxt_logits}")
###############################################################################
dataset = "imagenet"
model_type = "resnet50_256"
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
filename.parent.mkdir(exist_ok=True, parents=True)
data = dict(TIMING_DICT)
data["decrypted logits"] = decrypted_logits.tolist()
data["unencrypted logits"] = ptxt_logits.tolist()
data["inference time"] = inference_time
# avoid double-counting the strided conv operations
data['Pool'] = data['Pool'][:1]
with open(filename, "w") as f:
json.dump(data, f, indent=4)


@@ -0,0 +1,17 @@
#!/bin/bash
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
# srun -p hybrid -n 128 --mem=800G --pty bash -i
# ./run_vivian.sh 2>&1 >> logs/imagenet/resnet50_256_log_v2.txt
export OMP_DISPLAY_ENV=TRUE
export OMP_NUM_THREADS=64
export OMP_PROC_BIND=TRUE
for i in `seq 0 50` ; do
python resnet50_cifar_inference.py -i $i
# python resnet50_imagenet128_inference.py -i $i
# python resnet50_imagenet256_inference.py -i $i
done

File diff suppressed because one or more lines are too long


@@ -0,0 +1,340 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import numpy as np
import math
import torch
from time import time
from collections import defaultdict
import palisade_he_cnn.src.he_cnn.utils as utils
import palisade_he_cnn.src.he_cnn.conv as conv
import palisade_he_cnn.src.he_cnn.pool as pool
import palisade_he_cnn.src.he_cnn.linear as linear
from pyOpenFHE import CKKS as pal
TIMING_DICT = defaultdict(list)
DUPLICATED, IMAGE_SHARDED, CHANNEL_SHARDED = range(3)
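# A ciphertext image is stored in one of three layouts, depending on how the
# image size compares to the CKKS batch (shard) size:
#   DUPLICATED      - the whole image fits in one shard and is repeated to fill it
#   IMAGE_SHARDED   - each shard holds one or more whole channels
#   CHANNEL_SHARDED - a single channel spans multiple shards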
def reset_timing_dict():
global TIMING_DICT
TIMING_DICT.clear()
# an image is a PyTorch 3-tensor
def create_cnn_context(image, cc, publicKey, verbose=False):
# create these to encrypt
shard_size = cc.getBatchSize()
if len(image.shape) != 3:
raise ValueError("Input image must be a PyTorch 3-tensor")
# do we want to address rectangular images at some point...?
if image.shape[1] != image.shape[2]:
raise ValueError("Non-square channels not currently supported")
if not utils.is_power_of_2(image.shape[0]):
raise ValueError("Number of channels must be a power-of-two")
if not utils.is_power_of_2(image.shape[1]):
raise ValueError("Image dimensions must be a power-of-two")
mtx_size = image.shape[1]
num_channels = image.shape[0]
total_size = mtx_size * mtx_size * num_channels
num_shards = math.ceil(total_size / shard_size)
if total_size <= shard_size:
duplication_factor = shard_size // total_size
else:
duplication_factor = 1
duplicated_image = np.repeat(image.numpy(), duplication_factor, axis=0).flatten()
shards = []
for s in range(num_shards):
shard = cc.encrypt(publicKey, duplicated_image[shard_size * s: shard_size * (s + 1)])
shards.append(shard)
# return the cc and keys as well for decryption at the end
cnn_context = CNNContext(shards, mtx_size, num_channels, permutation=None, verbose=verbose)
return cnn_context
def timing_decorator_factory(prefix=""):
def timing_decorator(func):
def wrapper_function(*args, **kwargs):
global TIMING_DICT
start = time()
res = func(*args, **kwargs)
layer_time = time() - start
self = args[0]
TIMING_DICT[prefix.strip()].append(layer_time)
if self.verbose:
print(prefix + f"Layer took {layer_time:.02f} seconds")
return res
return wrapper_function
return timing_decorator
class CNNContext:
r"""This class contains methods for applying network layers to an image."""
def __init__(self, shards, mtx_size, num_channels, permutation=None, verbose=False):
r"""Initializes the CNNContext object. We only needs shards and channel/matrix size to compute all other metadata."""
if permutation is None:
permutation = np.array(range(num_channels))
self.shards = shards
self.mtx_size = mtx_size
self.num_channels = num_channels
self.permutation = permutation
self.verbose = verbose
self.compute_metadata()
def compute_metadata(self):
# Shard information
self.num_shards = len(self.shards)
self.shard_size = self.shards[0].getBatchSize()
self.total_size = self.num_shards * self.shard_size
# Channel information
self.channel_size = self.mtx_size * self.mtx_size
# Duplication factor
self.duplication_factor = (self.total_size // self.channel_size) // self.num_channels
if self.duplication_factor > 1:
self.shard_type = DUPLICATED
elif self.channel_size <= self.shard_size:
self.shard_type = IMAGE_SHARDED
else:
self.shard_type = CHANNEL_SHARDED
# Channel and shard info
self.num_phys_chan_per_shard = self.shard_size // self.channel_size
self.num_phys_chan_total = self.num_shards * self.num_phys_chan_per_shard
self.num_log_chan_per_shard = self.num_phys_chan_per_shard // self.duplication_factor
self.num_log_chan_total = self.num_shards * self.num_log_chan_per_shard
def print_metadata(self):
# shard information
print(f"num_shards: {self.num_shards}")
print(f"shard_size: {self.shard_size}")
print(f"total_size: {self.total_size}")
# Channel information
print(f"channel_size: {self.channel_size}")
# Duplication factor
print(f"duplication_factor: {self.duplication_factor}")
print(f"shard_type: {self.shard_type}")
# Channel and shard info
print(f"num_phys_chan_per_shard: {self.num_phys_chan_per_shard}")
print(f"num_phys_chan_total: {self.num_phys_chan_total}")
print(f"num_log_chan_per_shard: {self.num_log_chan_per_shard}")
print(f"num_log_chan_total: {self.num_log_chan_total}")
def decrypt_to_tensor(self, cc, keys):
# decrypt the shards
decrypted_shards = [cc.decrypt(keys.secretKey, shard) for shard in self.shards]
decrypted_output = np.concatenate(decrypted_shards)
# reshape with possible duplication
duplicated_output = decrypted_output.reshape(
self.num_channels * self.duplication_factor,
self.mtx_size,
self.mtx_size
)
decrypted_deduplicated_output = duplicated_output[0 :: self.duplication_factor]
return torch.from_numpy(decrypted_deduplicated_output)
@timing_decorator_factory("Conv ")
def apply_conv(self, conv_layer, bn_layer=None, output_permutation=None, drop_levels=False):
pal.CNN.omp_set_nested(0)
# pal.CNN.omp_set_dynamic(0)
# Get filters, biases
filters, biases = utils.get_filters_and_biases_from_conv2d(conv_layer)
# Get batch norm info if one is passed in
if bn_layer:
scale, shift = utils.get_scale_and_shift_from_bn(bn_layer)
else:
scale = None
shift = None
num_out_channels = filters.shape[1]
if output_permutation is None:
output_permutation = np.array(range(num_out_channels))
elif len(output_permutation) != num_out_channels:
raise ValueError("output permutation is incorrect length")
# TODO this should be a Compress() call
if drop_levels:
L = self.shards[0].getTowersRemaining() - 4
for j in range(self.num_shards):
for i in range(L):
self.shards[j] *= 1.0
# Apply conv
new_shards = conv.conv2d(
ciphertext_shards=self.shards,
filters=filters,
mtx_size=self.mtx_size,
biases=biases,
permutation=self.permutation,
bn_scale=scale,
bn_shift=shift,
output_permutation=output_permutation
)
# Create new CNN Context
stride = conv_layer.stride
cnn_context = CNNContext(new_shards, self.mtx_size, num_out_channels, output_permutation, self.verbose)
if stride == (1, 1):
return cnn_context
elif stride == (2, 2):
return cnn_context.apply_pool(conv=False)
else:
raise ValueError("Unsupported stride: {stride}")
@timing_decorator_factory("Pool ")
def apply_pool(self, conv=True):
pal.CNN.omp_set_nested(0)
# pal.CNN.omp_set_dynamic(0)
# Apply pool
new_shards = pool.pool(self.shards, self.mtx_size, conv)
# Get permutation
new_permutation = pool.get_pool_permutation(self.shards, self.num_channels, self.mtx_size)
new_permutation = pool.compose_permutations(self.permutation, new_permutation)
new_permutation = np.array(new_permutation)
# Create new CNN Context
return CNNContext(new_shards, self.mtx_size // 2, self.num_channels, new_permutation, self.verbose)
@timing_decorator_factory("Fused adaptive pool and linear ")
def apply_fused_pool_linear(self, linear_layer):
        has_bias = linear_layer.bias is not None
return self.apply_linear(linear_layer, has_bias, pool_factor=self.mtx_size)
@timing_decorator_factory("Bottleneck block ")
def apply_bottleneck(self, bottleneck_block, debug=False, gelu_params={}, bootstrap_params={}, bootstrap=True):
# Bottleneck block's forward pass is here: https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
skip_connection = self
downsample_block = bottleneck_block.downsample
if downsample_block:
conv_downsample_layer = downsample_block[0]
bn_downsample_layer = downsample_block[1]
skip_connection = skip_connection.apply_conv(conv_downsample_layer, bn_downsample_layer)
conv1_layer = bottleneck_block.conv1
bn1_layer = bottleneck_block.bn1
cnn_context = self.apply_conv(conv1_layer, bn1_layer)
if not debug:
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
cnn_context = cnn_context.apply_gelu(**gelu_params)
conv2_layer = bottleneck_block.conv2
bn2_layer = bottleneck_block.bn2
cnn_context = cnn_context.apply_conv(conv2_layer, bn2_layer)
if not debug:
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
cnn_context = cnn_context.apply_gelu(**gelu_params)
conv3_layer = bottleneck_block.conv3
bn3_layer = bottleneck_block.bn3
cnn_context = cnn_context.apply_conv(conv3_layer, bn3_layer, output_permutation=skip_connection.permutation)
cnn_context = cnn_context.apply_residual(skip_connection)
if not debug:
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
cnn_context = cnn_context.apply_gelu(**gelu_params)
return cnn_context
    # Unlike the other apply_* methods, this returns the ciphertext produced by
    # linear() rather than a new CNNContext.
@timing_decorator_factory("Linear ")
def apply_linear(self, linear_layer, bias=True, scale=1.0, pool_factor=1):
pal.CNN.omp_set_nested(0)
pal.CNN.omp_set_dynamic(1)
linear_weights, linear_biases = utils.get_weights_and_biases_from_linear(linear_layer,
self.mtx_size,
bias,
pool_factor)
final_shard = linear.linear(self.shards, linear_weights, linear_biases, self.mtx_size, self.permutation, scale,
pool_factor)
return final_shard
@timing_decorator_factory("Square ")
def apply_square(self):
new_shards = [shard * shard for shard in self.shards]
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
@timing_decorator_factory("GELU ")
def apply_gelu(self, bound=10.0, degree=59):
"""
bound:
bound = an upper bound on the absolute value of the inputs.
the polynomial approximation is valid for [-bound, bound]
degree:
degree of Chebyshev polynomial
"""
if self.num_shards < 8:
pal.CNN.omp_set_nested(1)
pal.CNN.omp_set_dynamic(1)
else:
pal.CNN.omp_set_nested(0)
pal.CNN.omp_set_dynamic(1)
# TODO this can be absorbed into the BN
new_shards = [x * (1 / bound) for x in self.shards]
new_shards = pal.CNN.fhe_gelu(new_shards, degree, bound)
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
@timing_decorator_factory("Bootstrapping ")
def apply_bootstrapping(self, meta=False):
if self.num_shards < 8:
pal.CNN.omp_set_nested(1)
pal.CNN.omp_set_dynamic(1)
else:
pal.CNN.omp_set_nested(0)
pal.CNN.omp_set_dynamic(1)
cc = self.shards[0].getCryptoContext()
if meta:
new_shards = cc.evalMetaBootstrap(self.shards)
else:
new_shards = cc.evalBootstrap(self.shards)
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
@timing_decorator_factory("Residual ")
def apply_residual(self, C2):
if len(self.permutation) != len(C2.permutation):
raise ValueError("Incompatible number of channels")
if self.mtx_size != C2.mtx_size:
raise ValueError("Incompatible matrix size")
if any([i != j for i, j in zip(self.permutation, C2.permutation)]):
raise ValueError("Incompatible permutations")
new_shards = [i + j for i, j in zip(self.shards, C2.shards)]
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)


@@ -0,0 +1,166 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import math
from copy import copy
SIN_COEFFS = [
0,
9.99984594193494365437e-01,
0,
-1.66632595072086745320e-01,
0,
8.31238887417884598346e-03,
0,
-1.93162796407356830500e-04,
0,
2.17326217498596729611e-06,
]
COS_COEFFS = [
9.99971094606182687341e-01,
0,
-4.99837602272995734437e-01,
0,
4.15223086250910767516e-02,
0,
-1.34410769349285321733e-03,
0,
1.90652668840074246305e-05,
0,
]
# you technically don't need the . to specify float division in python3
LOG_COEFFS = [
0,
1,
-0.5,
1.0 / 3,
-1.0 / 4,
1.0 / 5,
-1.0 / 6,
1.0 / 7,
-1.0 / 8,
1.0 / 9,
-1.0 / 10,
]
EXP_COEFFS = [
1,
1,
0.5,
1.0 / 6,
1.0 / 24,
1.0 / 120,
1.0 / 720,
1.0 / 5040,
1.0 / 40320,
1.0 / 362880,
1.0 / 3628800,
]
SIGMOID_COEFFS = [
1.0 / 2,
1.0 / 4,
0,
-1.0 / 48,
0,
1.0 / 480,
0,
-17.0 / 80640,
0,
31.0 / 1451520,
0,
]
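# Computes [x, x^2, x^4, ..., x^(2^logDegree)] by repeated squaring, one
# ciphertext multiplication per entry.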
def powerOf2Extended(cipher, logDegree):
res = [copy(cipher)]
for i in range(logDegree):
t = res[-1]
res.append(t * t)
return res
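# Builds the full list [x^1, x^2, ..., x^degree] from the power-of-two ladder
# above, filling in the gaps by multiplying lower powers by x^(2^i).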
def powerExtended(cipher, degree):
res = []
logDegree = int(
math.log2(degree)
) # both python and C++ truncate when casting float->int
cpows = powerOf2Extended(cipher, logDegree)
idx = 0
for i in range(logDegree):
powi = pow(2, i)
res.append(cpows[i])
for j in range(powi - 1):
res.append(copy(res[j]))
res[-1] *= cpows[i]
res.append(cpows[logDegree])
degree2 = pow(2, logDegree)
for i in range(degree - degree2):
res.append(copy(res[i]))
res[-1] *= cpows[logDegree]
return res
def polynomial_series_function(cipher, coeffs, verbose=False):
"""
Cipher is a CKKSCiphertext, coeffs should be array-like (generally either native list or numpy array)
"""
degree = len(coeffs)
if verbose:
print("initial ciphertext level = {}".format(cipher.getTowersRemaining()))
cpows = powerExtended(cipher, degree) # array of ciphertexts
# cpows[0] == cipher, i.e. x^1
res = cpows[0] * coeffs[1] # this should be defined
res += coeffs[0]
for i in range(2, degree):
coeff = coeffs[i]
if abs(coeff) > 1e-27:
aixi = cpows[i - 1] * coeff
res += aixi
if verbose:
print("final ciphertext level = {}".format(res.getTowersRemaining()))
return res
"""
example:
to approximate the sine function, do:
polynomial_series_function(c1, SIN_COEFFS)
"""
def sqrt_helper(cipher, steps):
a = copy(cipher)
b = a - 1
for i in range(steps):
a *= 1 - (0.5 * b)
# there must be a better way to do this...
if i < steps - 1:
b = (b * b) * (0.25 * (b - 3))
return a
def sqrt(cipher, steps, upper_bound):
if upper_bound == 1:
return sqrt_helper(cipher, steps)
return sqrt_helper(cipher * (1 / upper_bound), steps) * math.sqrt(upper_bound)
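# relu(x) = (x + |x|) / 2, with |x| approximated as sqrt(x^2).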
def relu(cipher, steps, upper_bound):
x = cipher * cipher
res = cipher + sqrt(x, steps, upper_bound)
return 0.5 * res


@@ -0,0 +1,56 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import numpy as np
import math
from pyOpenFHE import CKKS as pal
conv2d_cpp = pal.CNN.conv2d
def conv2d(ciphertext_shards, filters, mtx_size, biases, permutation=None, bn_scale=None, bn_shift=None,
output_permutation=None):
# if we're combining with a batch norm, fold the batch norm scale factor into the filters
# with sharded convolutions, filters are not duplicated or permuted in any way.
scaled_filters = filters
if bn_scale is not None and bn_shift is not None:
scaled_filters = filters * bn_scale.reshape(1, -1, 1, 1)
# if we're combining with a batch norm, fold the batch norm shift factor into the biases
shifted_biases = biases
if bn_scale is not None and bn_shift is not None:
shifted_biases = biases * bn_scale + bn_shift
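    # i.e. BN(conv(x)) = scale * (W*x + b) + shift = (scale*W)*x + (scale*b + shift)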
# all of this should happen somewhere in the CNNContext class
shard_size = ciphertext_shards[0].getBatchSize()
num_out_channels = filters.shape[1]
channel_size = mtx_size * mtx_size
if channel_size < shard_size:
channels_per_shard = shard_size // (mtx_size * mtx_size)
output_dup_factor = math.ceil(channels_per_shard / num_out_channels)
else:
output_dup_factor = 1
num_in_channels = filters.shape[0]
if permutation is None:
permutation = np.array(range(num_in_channels))
if output_permutation is None:
output_permutation = np.array(range(num_out_channels))
if len(permutation) != num_in_channels:
raise ValueError("incorrect number of input channels")
if len(output_permutation) != num_out_channels:
raise ValueError("incorrect number of output channels")
scaled_filters = scaled_filters[:, output_permutation, :, :]
shifted_biases = shifted_biases[output_permutation]
# compute the convolution
conv_shards = conv2d_cpp(ciphertext_shards, scaled_filters, mtx_size, permutation)
repeated_shifted_biases = np.repeat(shifted_biases, mtx_size * mtx_size * output_dup_factor)
for s in range(len(conv_shards)):
conv_shards[s] += repeated_shifted_biases[s * shard_size: (s + 1) * shard_size]
return conv_shards


@@ -0,0 +1,28 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import numpy as np
from pyOpenFHE import CKKS as pal_ckks
linear_cpp = pal_ckks.CNN.linear
def linear(channel_shards, weights, biases, mtx_size, permutation=None, scale=1.0, pool_factor=1):
shard_size = channel_shards[0].getBatchSize()
num_shards = len(channel_shards)
num_inputs = weights.shape[1]
channel_size = mtx_size * mtx_size
duplication_factor = max(shard_size // num_inputs, 1)
num_physical_channels_per_shard = shard_size // channel_size
num_physical_channels = num_physical_channels_per_shard * num_shards
num_logical_channels = num_physical_channels // duplication_factor
if permutation is None:
permutation = np.array(range(num_logical_channels))
output = linear_cpp(channel_shards, weights * scale, mtx_size, permutation, pool_factor)
# FO: if np.all(biases==0), then we do not need to compute biases*scale
num_out_activs = biases.shape[0]
output += np.pad(biases * scale, [(0, shard_size - num_out_activs)])
return output


@@ -0,0 +1,72 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
from .utils import *
import math
from pyOpenFHE import CKKS as pal
pool = pal.CNN.pool
def divide_chunks(l, n):
# looping till length l
for i in range(0, len(l), n):
yield l[i:i + n]
def interleave_lists(lists):
return [val for tup in zip(*lists) for val in tup]
def invert_permutation(P):
inverse_permutation = [0] * len(P)
for i, v in enumerate(P):
inverse_permutation[v] = i
return inverse_permutation
def compose_permutations(P1, P2):
if len(P1) != len(P2):
raise ValueError("permutations must have equal size")
permutation = [P1[P2[i]] for i in range(len(P1))]
return permutation
"""
metadata includes:
- the new channel permutation
- the duplication factor
- the new number of shards
"""
def get_pool_permutation(shards, num_channels, mtx_size):
initial_num_shards = len(shards)
shard_size = shards[0].getBatchSize()
channel_size = mtx_size * mtx_size
initial_num_physical_channels_per_shard = math.ceil(shard_size / channel_size)
num_physical_channels = initial_num_shards * initial_num_physical_channels_per_shard
initial_dup_factor = math.ceil(num_physical_channels / num_channels)
# if we have channel sharding, then no permutation
if channel_size >= shard_size:
C = num_channels
P = list(range(C))
return P
if (initial_dup_factor > 1) and (initial_num_shards > 1):
raise ValueError("Should not have both duplication and shards at the same time")
# if we have duplication, then no permutation
if initial_dup_factor > 1:
C = initial_num_physical_channels_per_shard // initial_dup_factor
P = list(range(C))
return P
C = initial_num_physical_channels_per_shard * initial_num_shards
I = list(range(C))
I = list(divide_chunks(I, initial_num_physical_channels_per_shard))
I = list(divide_chunks(I, 4))
P = [interleave_lists(J) for J in I]
P = sum(P, start=[])
return np.array(P)


@@ -0,0 +1,105 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import math
import numpy as np
from pyOpenFHE import CKKS as pal
upsample_cpp = pal.CNN.upsample
def divide_chunks(l, n):
# looping till length l
for i in range(0, len(l), n):
yield l[i:i + n]
def interleave_lists(lists):
return [val for tup in zip(*lists) for val in tup]
def invert_permutation(P):
inverse_permutation = [0] * len(P)
for i, v in enumerate(P):
inverse_permutation[v] = i
return inverse_permutation
def compose_permutations(P1, P2):
if len(P1) != len(P2):
raise ValueError("permutations must have equal size")
permutation = [P1[P2[i]] for i in range(len(P1))]
return permutation
"""
metadata includes:
- the new channel permutation
- the duplication factor
- the new number of shards
"""
def get_upsample_permutation(shards, num_channels, mtx_size):
initial_num_shards = len(shards)
shard_size = shards[0].getBatchSize()
channel_size = mtx_size * mtx_size
initial_num_physical_channels_per_shard = math.ceil(shard_size / channel_size)
final_num_physical_channels_per_shard = math.ceil(shard_size / channel_size / 4)
num_physical_channels = initial_num_shards * initial_num_physical_channels_per_shard
initial_dup_factor = math.ceil(num_physical_channels / num_channels)
# if we start with channel sharding, then no permutation
if channel_size >= shard_size:
P = list(range(num_channels))
return P
if (initial_dup_factor > 1) and (initial_num_shards > 1):
raise ValueError("Should not have both duplication and shards at the same time")
# if we have duplication factor >= 4, then no permutation
if initial_dup_factor > 2:
P = list(range(num_channels))
return P
# if we have two-fold duplication
if initial_dup_factor == 2:
P = list(range(num_channels))
if num_channels == 1: return P
P = P[::2] + P[1::2]
return P
I = list(range(num_channels))
I = list(divide_chunks(I, initial_num_physical_channels_per_shard))
I = [list(divide_chunks(J, 4)) for J in I]
P = [interleave_lists(J) for J in I]
P = sum(P, start=[])
return np.array(P)
"""
This takes a permuted list of ciphertexts stored using channel sharding,
and it reorders them into the identity permutation.
mtx_size and permutation refer to the values after upsampling, not of the input shards
"""
def undo_channel_sharding_permutation(shards, num_channels, mtx_size, permutation):
num_shards = len(shards)
shard_size = shards[0].getBatchSize()
channel_size = mtx_size * mtx_size
if shard_size > channel_size:
raise ValueError("This function should only be called on a channel sharded image")
num_shards_per_channel = channel_size // shard_size
final_shards = [None for _ in range(num_shards)]
for i, x in enumerate(shards):
channel_idx = i // num_shards_per_channel
subshard_idx = i % num_shards_per_channel
correct_idx = permutation[channel_idx] * num_shards_per_channel + subshard_idx
final_shards[correct_idx] = x
return final_shards


@@ -0,0 +1,228 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import shutil
from pathlib import Path
import numpy as np
import pyOpenFHE as pal
from pyOpenFHE import CKKS as pal_ckks
serial = pal_ckks.serial
def is_power_of_2(x):
return x > 0 and x & (x - 1) == 0
def next_power_of_2(n):
p = 1
if n and not (n & (n - 1)):
return n
while p < n:
p <<= 1
return p
def load_cc_and_keys(batch_size, mult_depth=10, scale_factor_bits=40, bootstrapping=False):
f = "{}-{}-{}-{}".format(batch_size, mult_depth, scale_factor_bits, int(bootstrapping))
path = Path("serialized") / f
P = (path / "PublicKey.bin").as_posix()
publicKey = pal_ckks.serial.DeserializeFromFile_PublicKey(P, pal_ckks.serial.SerType.BINARY)
P = (path / "PrivateKey.bin").as_posix()
secretKey = pal_ckks.serial.DeserializeFromFile_PrivateKey(P, pal_ckks.serial.SerType.BINARY)
keys = pal_ckks.KeyPair(publicKey, secretKey)
cc = publicKey.getCryptoContext()
P = (path / "EvalMultKey.bin").as_posix()
pal_ckks.serial.DeserializeFromFile_EvalMultKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
P = (path / "EvalAutomorphismKey.bin").as_posix()
pal_ckks.serial.DeserializeFromFile_EvalAutomorphismKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
if bootstrapping:
cc.evalBootstrapSetup()
return cc, keys
def save_cc_and_keys(cc, keys, path):
P = (path / "PublicKey.bin").as_posix()
assert pal_ckks.serial.SerializeToFile(P, keys.publicKey, pal_ckks.serial.SerType.BINARY)
P = (path / "PrivateKey.bin").as_posix()
assert pal_ckks.serial.SerializeToFile(P, keys.secretKey, pal_ckks.serial.SerType.BINARY)
P = (path / "EvalMultKey.bin").as_posix()
assert pal_ckks.serial.SerializeToFile_EvalMultKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
P = (path / "EvalAutomorphismKey.bin").as_posix()
assert pal_ckks.serial.SerializeToFile_EvalAutomorphismKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
def create_cc_and_keys(batch_size, mult_depth=10, scale_factor_bits=40, bootstrapping=False, save=False):
# We make use of palisade HE by creating a crypto context object
# this specifies things like multiplicative depth
cc = pal_ckks.genCryptoContextCKKS(
mult_depth, # number of multiplications you can perform
        scale_factor_bits,  # kind of like the number of bits of precision
batch_size, # length of your vector, can be any power-of-2 up to 2^14
)
print(f"CKKS scheme is using ring dimension = {cc.getRingDimension()}, batch size = {cc.getBatchSize()}")
cc.enable(pal.enums.PKESchemeFeature.PKE)
cc.enable(pal.enums.PKESchemeFeature.KEYSWITCH)
cc.enable(pal.enums.PKESchemeFeature.LEVELEDSHE)
cc.enable(pal.enums.PKESchemeFeature.ADVANCEDSHE)
cc.enable(pal.enums.PKESchemeFeature.FHE)
# generate keys
keys = cc.keyGen()
cc.evalMultKeyGen(keys.secretKey)
cc.evalPowerOf2RotationKeyGen(keys.secretKey)
if bootstrapping:
cc.evalBootstrapSetup()
cc.evalBootstrapKeyGen(keys.secretKey)
if save:
f = "{}-{}-{}-{}".format(batch_size, mult_depth, scale_factor_bits, int(bootstrapping))
path = Path("serialized") / f
path.mkdir(parents=True, exist_ok=True)
save_cc_and_keys(cc, keys, path)
return cc, keys
def get_keys(mult_depth,
scale_factor_bits,
batch_size,
bootstrapping):
try:
cc, keys = load_cc_and_keys(batch_size,
mult_depth=mult_depth,
scale_factor_bits=scale_factor_bits,
bootstrapping=bootstrapping)
    except Exception:
cc, keys = create_cc_and_keys(batch_size,
mult_depth=mult_depth,
scale_factor_bits=scale_factor_bits,
bootstrapping=bootstrapping,
save=True)
return cc, keys
def get_filters_and_biases_from_conv2d(layer):
filters = layer.weight.detach().numpy()
if hasattr(layer, "bias") and layer.bias is not None:
biases = layer.bias.detach().numpy()
else:
# without bias
# same as number of output channels (each bias is broadcast over the channel)
biases = np.zeros((filters.shape[0],))
filters = filters.transpose(1, 0, 2, 3)
    pad_to = next_power_of_2(filters.shape[0])
    if filters.shape[0] < pad_to:
        filters = np.concatenate(
            [filters, np.zeros((pad_to - filters.shape[0],) + filters.shape[1:])]
        )
return filters, biases
def get_scale_and_shift_from_bn(layer):
mu = layer.running_mean.detach().numpy()
var = layer.running_var.detach().numpy()
gamma = (
layer.weight.detach().numpy()
) # https://discuss.pytorch.org/t/getting-parameters-of-torch-nn-batchnorm2d-during-training/38913/3
beta = layer.bias.detach().numpy()
eps = layer.eps
sigma = np.sqrt(var + eps) # std dev
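    # BN(x) = gamma * (x - mu) / sigma + beta = scale * x + shift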
# compute scale factor
scale = gamma / sigma
# compute shift factor
shift = -gamma * mu / sigma + beta
return scale, shift
# needs to know either number of channels or matrix size
def get_weights_and_biases_from_linear(layer, mtx_size, bias, pool_factor=1):
nout = layer.weight.size(0)
weights = layer.weight.detach().numpy()
num_channels = weights.shape[1] // (mtx_size * mtx_size)
weights = weights.reshape(nout, num_channels, mtx_size, mtx_size)
weights = weights.reshape(nout, -1)
weights = np.repeat(weights, pool_factor * pool_factor, axis=1)
if bias:
biases = layer.bias.detach().numpy()
else:
biases = np.zeros(nout)
return weights, biases
# Given a model and an input, get intermediate layer output
def get_intermediate_output(model, layer, inputs):
layer_name = "layer"
activation = {}
def get_activation(name):
def hook(model, input, output):
activation[name] = output.detach()
return hook
layer.register_forward_hook(
get_activation(layer_name)
)
_ = model(inputs)
return activation[layer_name]
def compare_accuracy(keys, cnn_context, unencrypted, name="block", num_digits=4):
A = decrypt_and_reshape(cnn_context, keys.secretKey, cnn_context.mtx_size)
B = unencrypted.detach().cpu().numpy()[0]
diff = np.abs(A - B[cnn_context.permutation])
print(f"error in {name}:\nmax = {np.max(diff):.0{num_digits}f}\nmean = {np.mean(diff):.0{num_digits}f}")
def decrypt_and_reshape(cnn_context, secret_key, mtx_size):
cc = secret_key.getCryptoContext()
decrypted_output = [cc.decrypt(secret_key, ctxt) for ctxt in cnn_context.shards]
decrypted_output = np.hstack(decrypted_output)
num_out_chan = int(round(len(decrypted_output) / (mtx_size * mtx_size)))
decrypted_output = decrypted_output.reshape((num_out_chan, mtx_size, mtx_size))
decrypted_output = decrypted_output[0:: cnn_context.duplication_factor]
return decrypted_output
def serialize(cc, keys, ctxt):
path = Path("serialized")
path.mkdir(parents=True, exist_ok=True)
shutil.rmtree(path)
path.mkdir(parents=True, exist_ok=True)
assert serial.SerializeToFile("serialized/CryptoContext.bin", ctxt, serial.SerType.BINARY)
assert serial.SerializeToFile("serialized/ciphertext.bin", ctxt, serial.SerType.BINARY)
assert serial.SerializeToFile("serialized/PublicKey.bin", keys.publicKey, serial.SerType.BINARY)
assert serial.SerializeToFile("serialized/PrivateKey.bin", keys.secretKey, serial.SerType.BINARY)
assert serial.SerializeToFile_EvalMultKey_CryptoContext(cc, "serialized/EvalMultKey.bin", serial.SerType.BINARY)
assert serial.SerializeToFile_EvalAutomorphismKey_CryptoContext(cc, "serialized/EvalAutomorphismKey.bin",
serial.SerType.BINARY)
if __name__ == "__main__":
cc, keys = get_keys(mult_depth=34, scale_factor_bits=59, batch_size=32 * 32 * 32, bootstrapping=True)
print(cc.getBatchSize())
shard = cc.encrypt(keys.publicKey, [0.0 for _ in range(32768)])
serialize(cc, keys, shard)


@@ -0,0 +1,265 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
from typing import Union, Tuple, List
class Square(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return torch.square(x)
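# Standardized central moment of degree `deg`. For a Gaussian distribution the
# fourth standardized moment (kurtosis) is 3, which is the regularization
# target used in get_bn_loss_metrics below.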
def moment(x: torch.Tensor, std: float, mean: float, deg: int = 4, eps: float = 1e-4) -> torch.Tensor:
N = x.shape[0]
return (1.0 / N) * torch.sum((x - mean) ** deg) / (std ** deg + eps)
def activation_helper(activation: str = 'gelu',
gelu_degree: int = 16):
if activation == 'relu':
return nn.ReLU()
elif activation == 'gelu':
return nn.GELU()
elif activation == 'polygelu':
raise ValueError("Not supported.")
elif activation == 'square':
return Square()
else:
return nn.ReLU()
def conv_block(in_ch: int,
out_ch: int,
activation: str = 'relu',
gelu_degree: int = 16,
pool: bool = False,
pool_method: str = 'avg',
kernel: int = 3,
stride: int = 1,
padding: Union[int, str] = 1):
layers = [nn.Conv2d(in_ch,
out_ch,
kernel_size=kernel,
stride=stride,
padding=padding),
nn.BatchNorm2d(out_ch),
activation_helper(activation, gelu_degree)
]
if pool:
layers.append(nn.MaxPool2d(2, 2) if pool_method == 'max' else nn.AvgPool2d(2, 2))
return nn.Sequential(*layers)
def get_small_model_dict(activation='gelu',
gelu_degree: int = 16,
pool_method: str = 'avg') -> nn.ModuleDict:
classifier = nn.Sequential(nn.Flatten(),
nn.Linear(8 * 8 * 128, 10))
return nn.ModuleDict(
{
"conv1": conv_block(in_ch=1,
out_ch=64,
kernel=4,
pool=True,
pool_method=pool_method,
padding='same',
activation=activation,
gelu_degree=gelu_degree),
"conv2": conv_block(in_ch=64,
out_ch=128,
kernel=4,
pool=True,
pool_method=pool_method,
padding='same',
activation=activation,
gelu_degree=gelu_degree),
"conv3": conv_block(in_ch=128,
out_ch=128,
kernel=4,
pool=False,
padding='same',
activation=activation,
gelu_degree=gelu_degree),
"classifier": classifier
}
)
class SmallModel(nn.Module):
def __init__(self, activation='gelu', gelu_degree: int = 16, pool_method: str = 'avg'):
super(SmallModel, self).__init__()
self.model_layers = get_small_model_dict(activation=activation, gelu_degree=gelu_degree,
pool_method=pool_method)
self.n_bn_classes = self.count_instances_of_a_class()
    def count_instances_of_a_class(self, cls: type = nn.BatchNorm2d) -> int:
n_classes = 0
for _, block in self.model_layers.items():
for layer in block:
# Handle the nested case
if isinstance(layer, nn.Sequential):
for sublayer in layer:
if isinstance(sublayer, cls):
n_classes += 1
# Handle the unnested case
else:
if isinstance(layer, cls):
n_classes += 1
return n_classes
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.bn_outputs = {} # key=layer name, v=list of torch.Tensors
self.outputs = {}
for name, block in self.model_layers.items():
block_output, block_bn_output = self.block_pass(block, x)
self.bn_outputs[name] = block_bn_output
# Residual Connection
if "res" in name:
x = x + block_output
# Normal
else:
x = block_output
return x
def block_pass(self, block: nn.Sequential, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
bn_output = []
# Iterate through a block, which may be nested (residual connections are nested)
for layer in block:
# Handle the nested case
if isinstance(layer, nn.Sequential):
for sublayer in layer:
x = sublayer(x)
if isinstance(sublayer, nn.BatchNorm2d):
bn_output.append(x)
            # Handle the unnested case
else:
x = layer(x)
if isinstance(layer, nn.BatchNorm2d):
bn_output.append(x)
self.outputs[layer] = x
return x, bn_output
# Must be called after forward method to set self.bn_outputs
def get_bn_loss_metrics(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
means, stds, skews, kurts = self.get_moments_by_layer()
# Aggregating
loss_means = F.mse_loss(means, torch.zeros(self.n_bn_classes))
loss_stds = F.mse_loss(stds, torch.ones(self.n_bn_classes))
# loss_skews = F.mse_loss(skews, torch.zeros(self.n_bn_classes))
loss_kurts = F.mse_loss(kurts, 3 * torch.ones(self.n_bn_classes))
return loss_means, loss_stds, loss_kurts
def get_moments_by_layer(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
means, stds = torch.zeros(self.n_bn_classes), torch.zeros(self.n_bn_classes)
skews, kurts = torch.zeros(self.n_bn_classes), torch.zeros(self.n_bn_classes)
layer_index = 0
for name, block in self.bn_outputs.items():
# Residual blocks are nested
for sublayer in range(0, len(block), 1):
dist = block[sublayer].flatten()
std, mean = torch.std_mean(dist)
skew = moment(dist, std, mean, deg=3)
kurt = moment(dist, std, mean, deg=4)
means[layer_index] = mean
stds[layer_index] = std
skews[layer_index] = skew
kurts[layer_index] = kurt
layer_index += 1
return means, stds, skews, kurts
def get_intermediate_layer_output(self, layer, inputs):
layer_name = "layer"
activation = {}
def get_activation(name):
def hook(self, input, output):
print("calling hook")
activation[name] = output.detach()
return hook
layer.register_forward_hook(
get_activation(layer_name)
)
_ = self(inputs)
return activation[layer_name]
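# Usage sketch (illustrative shapes): the BN loss terms are only valid after a
# forward pass has populated self.bn_outputs, e.g.
#   model = SmallModel()
#   _ = model(torch.randn(8, 1, 32, 32))
#   loss_means, loss_stds, loss_kurts = model.get_bn_loss_metrics()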
def train_small_model():
DATA_DIR = "../data"
    transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
transforms.Pad(2)
])
BATCH_SIZE = 512
train_kwargs = {'batch_size': BATCH_SIZE}
test_kwargs = {'batch_size': BATCH_SIZE}
dataset1 = datasets.MNIST(DATA_DIR, train=True, download=True,
transform=transform)
dataset2 = datasets.MNIST(DATA_DIR, train=False,
transform=transform)
train_dl = torch.utils.data.DataLoader(dataset1,**train_kwargs)
val_dl = torch.utils.data.DataLoader(dataset2, **test_kwargs)
max_lr = 0.005
weight_decay = 1e-5
model = SmallModel()
optimizer = torch.optim.Adam(model.parameters(),
max_lr,
weight_decay = weight_decay)
EPOCHS = 13
for epoch in range(EPOCHS):
print(f"Epoch {epoch}")
train_loss = 0
        bn_means, bn_stds, bn_kurts = 0, 0, 0
N = 0
model.train()
for i, (img, label) in enumerate(train_dl):
logit = model(img)
# Model loss
            loss = F.cross_entropy(logit, label)
# Loss modifications
bn_mean, bn_std, bn_kurt = model.get_bn_loss_metrics()
loss += (bn_mean + bn_std + bn_kurt)
loss.backward()
            # Accumulate running statistics
train_loss += loss.item()
bn_means += bn_mean.item()
bn_stds += bn_std.item()
bn_kurts += bn_kurt.item()
N += 1
optimizer.step()
optimizer.zero_grad()
print("Saving model as %s" % "small_model.pt")
torch.save(model.state_dict(), "small_model.pt")
if __name__ == "__main__":
train_small_model()

View File

@@ -0,0 +1,93 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
# export OMP_DISPLAY_ENV=TRUE
import os
from time import time
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from cnn_context import create_cnn_context
from he_cnn.utils import *
from small_model import SmallModel
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
# create HE cc and keys
mult_depth = 30
scale_factor_bits = 40
batch_size = 32 * 32 * 32 # increased batch size b/c the ring dimension is higher due to the mult_depth
# used for a small test of big shards
# batch_size = 128
# if using bootstrapping, you must increase scale_factor_bits to 59
cc, keys = create_cc_and_keys(batch_size, mult_depth=mult_depth, scale_factor_bits=scale_factor_bits,
bootstrapping=False)
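# Bootstrapped variant (sketch, per the note above on scale_factor_bits):
#   cc, keys = create_cc_and_keys(batch_size, mult_depth=mult_depth,
#                                 scale_factor_bits=59, bootstrapping=True)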
# load the model
weight_file = "palisade_he_cnn/src/weights/small_model.pt"
print(os.getcwd())
model = SmallModel(activation='gelu', pool_method='avg')
model.load_state_dict(torch.load(weight_file))
model.eval()
# load data
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
transforms.Pad(2)])
validset = torchvision.datasets.MNIST(root="./data", download=True, transform=transform)
validloader = torch.utils.data.DataLoader(validset, batch_size=1, shuffle=True)
total = 0
correct = 0
total_time = 0
for i, test_data in enumerate(validloader):
print(f"Inference {i + 1}:")
x_test, y_test = test_data
input_img = create_cnn_context(x_test[0], cc, keys.publicKey, verbose=True)
start = time()
layer = model.model_layers.conv1
conv1 = input_img.apply_conv(layer[0], layer[1])
act1 = conv1.apply_gelu()
pool1 = act1.apply_pool()
layer = model.model_layers.conv2
perm = np.random.permutation(128) # example of how to use an output permutation
conv2 = pool1.apply_conv(layer[0], layer[1], output_permutation=perm)
act2 = conv2.apply_gelu()
pool2 = act2.apply_pool()
layer = model.model_layers.conv3
conv3 = pool2.apply_conv(layer[0], layer[1])
act3 = conv3.apply_gelu()
layer = model.model_layers.classifier[1]
logits = act3.apply_linear(layer)
logits_dec = cc.decrypt(keys.secretKey, logits)[:10]
logits_pt = model(x_test).detach().numpy().ravel()
print(f"[+] decrypted logits = {logits_dec}")
print(f"[+] unencrypted logits = {logits_pt}")
inference_time = time() - start
total_time += inference_time
total += 1
y_label = y_test[0]
correct += np.argmax(logits_dec) == y_label
out_string = f"""
Count: {total}
Accuracy: {correct / total}
Average latency: {total_time / total:.02f}s
"""
print(out_string)

View File

@@ -0,0 +1,138 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import copy
from collections import OrderedDict, defaultdict
from typing import Dict, Callable
import torch
import torchvision.transforms as tt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# Pad a 3-channel first conv so it accepts 4-channel (zero-padded) inputs
def pad_conv_input_channels(conv1):
conv1 = copy.deepcopy(conv1)
conv1.in_channels = 4
data = conv1.weight.data
shape = list(data.shape)
shape[1] = 1
padding = torch.zeros(*shape)
new_data = torch.cat((data, padding), 1)
conv1.weight.data = torch.Tensor(new_data)
return conv1
class PadChannel(object):
def __init__(self, npad: int=1):
self.n = npad
def __call__(self, x):
_, width, height = x.shape
x = torch.cat([x, torch.zeros(self.n, width, height)])
return x
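# Example (illustrative): PadChannel(npad=1) maps a (3, H, W) CIFAR image to
# (4, H, W) by appending a zero channel, matching the models' 4-channel input.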
def patch_whitening(data, patch_size=(3, 3)):
# Compute weights from data such that
# torch.std(F.conv2d(data, weights), dim=(2, 3))
# is close to 1.
h, w = patch_size
c = data.size(1)
patches = data.unfold(2, h, 1).unfold(3, w, 1)
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
n, c, h, w = patches.shape
X = patches.reshape(n, c * h * w)
X = X / (X.size(0) - 1) ** 0.5
covariance = X.t() @ X
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
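# Sanity-check sketch (illustrative; requires torch.nn.functional as F):
#   w = patch_whitening(data)
#   torch.std(F.conv2d(data, w), dim=(2, 3))  # ~1 elementwise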
def get_cifar10_dataloader(batch_size,
data_dir: str='../../datasets/cifar10/',
num_workers: int=4):
stats = ((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
train_tfms = tt.Compose([
tt.RandomCrop(32,padding=4,padding_mode='reflect'),
tt.RandomHorizontalFlip(),
tt.ToTensor(),
tt.Normalize(*stats,inplace=True),
PadChannel(npad=1)
])
val_tfms = tt.Compose([
tt.ToTensor(),
tt.Normalize(*stats,inplace=True),
PadChannel(npad=1)
])
train_ds = ImageFolder(data_dir+'train',transform=train_tfms)
val_ds = ImageFolder(data_dir+'test',transform=val_tfms)
train_dl = DataLoader(train_ds,
batch_size,
pin_memory = True,
num_workers = num_workers,
shuffle = True)
val_dl = DataLoader(val_ds,
batch_size,
pin_memory = True,
num_workers = num_workers)
return train_dl, val_dl
def remove_all_hooks(model: torch.nn.Module) -> None:
    for name, child in model._modules.items():
        if child is not None:
            # These hook dicts are independent, so check each one; chained
            # elifs would only ever clear the first attribute that exists
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_forward_pre_hooks"):
                child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_backward_hooks"):
                child._backward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_hooks(child)
# Register forward hooks on every BatchNorm2d layer of a model; the returned
# dict is populated with their outputs on each forward pass
def get_intermediate_output(model):
activation = defaultdict(list)
def get_activation(name):
def hook(model, input, output):
x = output.detach()
activation[name].append(x)
return hook
BatchNorm_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
for i, b in enumerate(BatchNorm_layers):
b.register_forward_hook(
get_activation(f"bn_{i + 1}")
)
return activation
def get_all_bn_activations(model, val_dl, DEVICE):
activation = get_intermediate_output(model)
model.to(DEVICE)
model.eval()
    for img, label in val_dl:
        img, label = img.to(DEVICE), label.to(DEVICE)
        _ = model(img)  # the forward pass is only needed to populate the hooks
remove_all_hooks(model)
activation = {k:torch.cat(v) for k,v in activation.items()}
return activation
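# Usage sketch (illustrative):
#   acts = get_all_bn_activations(model, val_dl, torch.device("cpu"))
#   # acts maps "bn_1", "bn_2", ... to the concatenated BatchNorm outputs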

Binary file not shown.

248
palisade_he_cnn/test.py Normal file
View File

@@ -0,0 +1,248 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import pytest
import torch
import numpy as np
from src.cnn_context import create_cnn_context
from src.he_cnn.utils import *
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
class Info:
    def __init__(self, mult_depth=30, scale_factor_bits=40, batch_size=32 * 32 * 32,
                 max=255, min=0, h=128, w=128, channel_size=3, ker_size=3):
self.mult_depth = mult_depth
self.scale_factor_bits = scale_factor_bits
self.batch_size = batch_size
self.max = max
self.min = min
self.h = h
self.w = w
self.channel_size = channel_size
self.ker_size = ker_size
        rand_tensor = (max - min) * torch.rand((channel_size, h, w)) + min
self.rand_tensor = rand_tensor
self.cc, self.keys = create_cc_and_keys(batch_size, mult_depth=mult_depth, scale_factor_bits=scale_factor_bits, bootstrapping=False)
self.input_img = create_cnn_context(self.rand_tensor, self.cc, self.keys.publicKey, verbose=True)
@pytest.fixture
def check1():
return Info(30, 40, 32 * 32 * 32, 1, -1, 64, 64, 4, 3)
@pytest.fixture
def check2():
return Info(30, 40, 32 * 32 * 32, 1, -1, 64, 64, 1, 3)
@pytest.fixture
def check3():
return Info(30, 40, 32, 1, -1, 16, 16, 2, 3)
def test_apply_conv2d_c1(check1) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check1.channel_size, check1.channel_size, check1.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check1.rand_tensor)
    pt_conv = torch.squeeze(pt_conv, dim=0).detach().numpy()
conv1 = check1.input_img.apply_conv(layer)
dec_conv1 = conv1.decrypt_to_tensor(check1.cc, check1.keys).numpy().squeeze()
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed image < shard"
def test_apply_conv2d_c2(check2) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check2.channel_size, check2.channel_size, check2.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check2.rand_tensor)
    pt_conv = torch.squeeze(pt_conv, dim=0).detach().numpy()
conv1 = check2.input_img.apply_conv(layer)
dec_conv1 = conv1.decrypt_to_tensor(check2.cc, check2.keys).numpy().squeeze()
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed channel < shard"
def test_apply_conv2d_c3(check3) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check3.channel_size, check3.channel_size, check3.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check3.rand_tensor)
    pt_conv = torch.squeeze(pt_conv, dim=0).detach().numpy()
conv1 = check3.input_img.apply_conv(layer)
dec_conv1 = conv1.decrypt_to_tensor(check3.cc, check3.keys).numpy().squeeze()
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed channel > shard"
def test_apply_pool_c1(check1) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check1.channel_size, check1.channel_size, check1.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check1.rand_tensor)
    pt_avg_pool = torch.nn.AvgPool2d(2)
    pt_pool = pt_avg_pool(pt_conv)
pt_pool = pt_pool.detach().numpy()
conv1 = check1.input_img.apply_conv(layer)
pool = conv1.apply_pool()
dec_pool = pool.decrypt_to_tensor(check1.cc, check1.keys).numpy()
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed image < shard"
def test_apply_pool_c2(check2) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check2.channel_size, check2.channel_size, check2.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check2.rand_tensor)
    pt_avg_pool = torch.nn.AvgPool2d(2)
    pt_pool = pt_avg_pool(pt_conv)
pt_pool = pt_pool.detach().numpy()
conv1 = check2.input_img.apply_conv(layer)
pool = conv1.apply_pool()
dec_pool = pool.decrypt_to_tensor(check2.cc, check2.keys).numpy()
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed channel < shard"
def test_apply_pool_c3(check3) -> None:
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size):
super(ConvLayer, self).__init__()
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
def forward(self, x):
x = self.conv(x)
return x
model = ConvLayer(check3.channel_size, check3.channel_size, check3.ker_size)
model.eval()
layer = model.conv
pt_conv = model(check3.rand_tensor)
    pt_avg_pool = torch.nn.AvgPool2d(2)
    pt_pool = pt_avg_pool(pt_conv)
pt_pool = pt_pool.detach().numpy()
conv1 = check3.input_img.apply_conv(layer)
pool = conv1.apply_pool()
dec_pool = pool.decrypt_to_tensor(check3.cc, check3.keys).numpy()
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed channel > shard"
def test_apply_linear_c1(check1) -> None:
class LinearLayer(torch.nn.Module):
def __init__(self, input_size, output_size):
super(LinearLayer, self).__init__()
self.linear_one = torch.nn.Linear(input_size, output_size)
def forward(self, x):
x = self.linear_one(x)
return x
linear = LinearLayer(len(check1.rand_tensor.flatten()), check1.rand_tensor.shape[0])
linear.eval()
pt_linear = linear(check1.rand_tensor.flatten()).detach().numpy()
he_linear = check1.input_img.apply_linear(linear.linear_one)
dec_linear = check1.cc.decrypt(check1.keys.secretKey, he_linear)[0:check1.rand_tensor.shape[0]]
assert np.allclose(dec_linear, pt_linear, atol=1e-03), "Linear result did not match between HE and PyTorch, failed image < shard"
def test_apply_linear_c2(check2) -> None:
class LinearLayer(torch.nn.Module):
def __init__(self, input_size, output_size):
super(LinearLayer, self).__init__()
self.linear_one = torch.nn.Linear(input_size, output_size)
def forward(self, x):
x = self.linear_one(x)
return x
linear = LinearLayer(len(check2.rand_tensor.flatten()), check2.rand_tensor.shape[0])
linear.eval()
pt_linear = linear(check2.rand_tensor.flatten()).detach().numpy()
he_linear = check2.input_img.apply_linear(linear.linear_one)
dec_linear = check2.cc.decrypt(check2.keys.secretKey, he_linear)[0:check2.rand_tensor.shape[0]]
assert np.allclose(dec_linear, pt_linear, atol=1e-03), "Linear result did not match between HE and PyTorch, failed channel < shard"
def test_apply_gelu_c1(check1) -> None:
gelu = torch.nn.GELU()
pt_gelu = gelu(check1.rand_tensor)
he_gelu = check1.input_img.apply_gelu()
dec_gelu = he_gelu.decrypt_to_tensor(check1.cc, check1.keys).numpy()
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed image < shard"
def test_apply_gelu_c2(check2) -> None:
gelu = torch.nn.GELU()
pt_gelu = gelu(check2.rand_tensor)
he_gelu = check2.input_img.apply_gelu()
dec_gelu = he_gelu.decrypt_to_tensor(check2.cc, check2.keys).numpy()
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed channel < shard"
def test_apply_gelu_c3(check3) -> None:
gelu = torch.nn.GELU()
pt_gelu = gelu(check3.rand_tensor)
he_gelu = check3.input_img.apply_gelu()
dec_gelu = he_gelu.decrypt_to_tensor(check3.cc, check3.keys).numpy()
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed channel > shard"
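# These tests can be run with pytest from the package directory (sketch; adjust
# the path to your checkout):
#   pytest test.py -v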

View File

@@ -0,0 +1,75 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import ImageFile
from tqdm import tqdm
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
# Set hyperparameters
num_epochs = 1
batch_size = 128
learning_rate = 0.001
# Initialize transformations for data augmentation
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(degrees=45),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Load the ImageNet Object Localization Challenge dataset
train_dataset = torchvision.datasets.ImageFolder(
root='~/ImageNet/ILSVRC/Data/CLS-LOC/train',
transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# Load the ResNet50 model
model = torchvision.models.resnet50(weights='DEFAULT')
# Parallelize training across multiple GPUs
model = torch.nn.DataParallel(model, device_ids = device_ids)
# Set the model to run on the device
model = model.to(f'cuda:{model.device_ids[0]}')
# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Train the model...
for epoch in range(num_epochs):
for inputs, labels in tqdm(train_loader):
# Move input and label tensors to the device
inputs = inputs.to(f'cuda:{model.device_ids[0]}')
labels = labels.to(f'cuda:{model.device_ids[0]}')
        # Zero out the gradients from the previous step
optimizer.zero_grad()
# Forward pass
outputs = model(inputs)
loss = criterion(outputs, labels)
# Backward pass
loss.backward()
optimizer.step()
# Print the loss for every epoch
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')
print(f'Finished Training, Loss: {loss.item():.4f}')

View File

@@ -0,0 +1,130 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import torch
import torch.nn as nn
class Scale(nn.Module):
def __init__(self, scale: float = 0.125):
super().__init__()
self.scale = scale
def forward(self, x) -> torch.Tensor:
return self.scale * x
def conv_block(
in_ch: int,
out_ch: int,
kernel_size: int = 3,
stride: int = 1,
padding: str = "same",
pool: bool = False,
gelu: bool = False
):
layers = [nn.Conv2d(in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding),
nn.BatchNorm2d(out_ch)
]
if pool:
layers.append(nn.AvgPool2d(2, 2))
if gelu:
layers.append(nn.GELU())
return nn.Sequential(*layers)
class ResNet9(nn.Module):
def __init__(self,
c_in: int = 4,
c_out: int = 36,
num_classes: int = 10,
scale_out: float = 0.125):
super().__init__()
self.c_out = c_out
self.conv1 = nn.Conv2d(c_in,
c_out,
kernel_size=(3, 3),
padding="same",
bias=True)
self.conv2 = conv_block(c_out,
64,
kernel_size=1,
padding="same",
pool=False,
gelu=True)
self.conv3 = conv_block(64,
128,
kernel_size=3,
padding="same",
pool=True,
gelu=True)
self.res1 = nn.Sequential(
conv_block(128,
128,
kernel_size=3,
padding="same",
pool=False,
gelu=True),
conv_block(128,
128,
kernel_size=3,
padding="same",
pool=False,
gelu=True)
)
self.conv4 = conv_block(128,
256,
kernel_size=3,
padding="same",
pool=True,
gelu=True)
self.conv5 = conv_block(256,
512,
kernel_size=3,
padding="same",
pool=True,
gelu=True)
self.res2 = nn.Sequential(
conv_block(512,
512,
kernel_size=3,
padding="same",
pool=False,
gelu=True),
conv_block(512,
512,
kernel_size=3,
padding="same",
pool=False,
gelu=True)
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(512 * 4 * 4, num_classes, bias=True),
Scale(scale_out)
)
def set_conv1_weights(self,
weights: torch.Tensor,
bias: torch.Tensor):
self.conv1.weight.data = weights
self.conv1.weight.requires_grad = False
self.conv1.bias.data = bias
self.conv1.bias.requires_grad = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
res1 = self.res1(x)
x = x + res1
x = self.conv4(x)
x = self.conv5(x)
res2 = self.res2(x)
x = x + res2
return self.classifier(x)
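# Shape sketch (illustrative): with CIFAR inputs zero-padded to 4 channels,
#   m = ResNet9()
#   m(torch.randn(2, 4, 32, 32)).shape  # -> torch.Size([2, 10])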

View File

@@ -0,0 +1,321 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
from pathlib import Path
from typing import Any, List, Optional, Type
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
# Low-Complexity deep convolutional neural networks on
# FHE using multiplexed parallel convolutions
#
# https://eprint.iacr.org/2021/1688.pdf
#
# Our implementation is not an exact 1-to-1 reproduction of the paper
__all__ = [
"resnet_test",
"resnet20",
"resnet32",
"resnet44",
"resnet56",
"resnet110",
]
POOL = nn.AvgPool2d(2, 2)
BN_MOMENTUM = 0.1
class Debug(nn.Module):
def __init__(self, filename="temp.txt", debug=False):
super().__init__()
self.filename = Path("debug") / filename
self.debug = debug
# print(self.debug, filename)
def forward(self, x) -> torch.Tensor:
if self.debug:
data = x.detach().cpu().numpy().ravel()
np.savetxt(self.filename, data, fmt="%0.04f")
return x
class Scale(nn.Module):
def __init__(self, scale: float = 0.125):
super().__init__()
self.scale = scale
def forward(self, x) -> torch.Tensor:
return self.scale * x
def conv_bn(inchan: int,
outchan: int,
kernel: int = 3,
stride: int = 1,
padding: str = "same",
filenames: list = ["temp.txt"],
debug: bool = False) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
inchan,
outchan,
kernel_size=kernel,
stride=1,
padding=padding
),
nn.BatchNorm2d(outchan, momentum=BN_MOMENTUM),
Debug(filenames[0], debug)
)
def conv_bn_down(inchan: int,
outchan: int,
kernel: int = 3,
stride: int = 1,
padding: str = "same",
filenames: list = ["temp.txt", "temp.txt"],
debug: bool = False, ) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
inchan,
outchan,
kernel_size=kernel,
stride=1,
padding=padding
),
nn.BatchNorm2d(outchan, momentum=BN_MOMENTUM),
Debug(filenames[0], debug),
POOL,
Debug(filenames[1], debug)
)
class BasicBlock(nn.Module):
def __init__(
self,
inchan: int,
outchan: int,
kernel: int = 3,
stride: int = 1,
padding: str = "same",
activation: nn.Module = nn.GELU(),
downsample: Optional[nn.Module] = None,
prefix: str = "l1",
debug: bool = False
) -> None:
super().__init__()
# If a skip module is defined (defined as downsample), then our first block
# in series needs to also include a downsampling operation, aka pooling.
if downsample is not None:
self.conv_bn_1 = conv_bn_down(
inchan=inchan,
outchan=outchan,
kernel=3,
stride=1,
padding=padding,
filenames=["%s_bn1.txt" % prefix,
"%s_pool.txt" % prefix],
debug=debug
)
else:
self.conv_bn_1 = conv_bn(
inchan=outchan,
outchan=outchan,
kernel=3,
stride=1,
padding=padding,
filenames=["%s_bn1.txt" % prefix],
debug=debug
)
self.conv_bn_2 = conv_bn(
inchan=outchan,
outchan=outchan,
kernel=3,
stride=1,
padding=padding,
filenames=["%s_bn2.txt" % prefix],
debug=debug
)
self.gelu = activation
self.downsample = downsample
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv_bn_1(x)
out = self.gelu(out)
out = self.conv_bn_2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
if self.downsample is not None:
out = self.gelu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[BasicBlock],
layers: List[int],
num_classes: int = 10,
debug: bool = False
):
super().__init__()
self.debug = debug
self.conv_bn_1 = nn.Sequential(
nn.Conv2d(4, 16, kernel_size=3, stride=1, padding="same"),
nn.BatchNorm2d(16, momentum=BN_MOMENTUM),
Debug("l0_bn1.txt", debug=debug)
)
self.gelu0 = nn.GELU()
self.gelu1 = nn.GELU()
self.gelu2 = nn.GELU()
self.gelu3 = nn.GELU()
self.debug0 = Debug("l0_gelu.txt", debug)
self.debug1 = Debug("l1_gelu.txt", debug)
self.debug2 = Debug("l2_gelu.txt", debug)
self.debug3 = Debug("l3_gelu.txt", debug)
self.layer1 = self._make_layer(
block=block,
inchan=16,
outchan=16,
nblocks=layers[0],
stride=1,
prefix="l1"
)
self.layer2 = self._make_layer(
block,
inchan=16,
outchan=32,
nblocks=layers[1],
stride=2, # Triggers downsample != None, not a true stride
prefix="l2"
)
self.layer3 = self._make_layer(
block,
inchan=32,
outchan=32,
nblocks=layers[2],
stride=2, # Triggers downsample != None, not a true stride
prefix="l3"
)
self.classifier = nn.Sequential(
POOL,
nn.Flatten(),
nn.Linear(32 * 4 * 4, num_classes),
Scale()
)
def forward(self, x: Tensor) -> Tensor:
x = self.conv_bn_1(x)
x = self.gelu0(x)
x = self.debug0(x)
x = self.layer1(x)
x = self.gelu1(x)
x = self.debug1(x)
x = self.layer2(x)
x = self.gelu2(x)
x = self.debug2(x)
x = self.layer3(x)
x = self.gelu3(x)
x = self.debug3(x)
return self.classifier(x)
def _make_layer(
self,
block: Type[BasicBlock],
inchan: int,
outchan: int,
nblocks: int,
stride: int = 1,
prefix: str = "l1"
) -> nn.Sequential:
downsample = None
if stride != 1:
downsample = conv_bn_down(
inchan=inchan,
outchan=outchan,
filenames=["%s_ds_bn1.txt" % prefix,
"%s_ds_pool.txt" % prefix],
debug=self.debug,
kernel=3,
stride=1,
padding="same"
)
layers = []
        for i in range(0, nblocks):
            # The downsample path is only needed for the first block in a layer
            if i == 1:
                downsample = None
layers.append(
block(
inchan=inchan,
outchan=outchan,
kernel=3,
stride=stride,
padding="same",
activation=nn.GELU(),
downsample=downsample,
prefix=prefix + "_%s" % str(i),
debug=self.debug
)
)
return nn.Sequential(*layers)
def _resnet(
block: Type[BasicBlock],
layers: List[int],
**kwargs: Any,
) -> ResNet:
return ResNet(block, layers, **kwargs)
def resnet_test(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [1, 1, 1], **kwargs)
def resnet20(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [3, 3, 3], **kwargs)
def resnet32(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [5, 5, 5], **kwargs)
def resnet44(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [7, 7, 7], **kwargs)
def resnet56(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [9, 9, 9], **kwargs)
def resnet110(**kwargs: Any) -> ResNet:
return _resnet(BasicBlock, [18, 18, 18], **kwargs)
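# Usage sketch (illustrative): CIFAR inputs are zero-padded to 4 channels, so
#   model = resnet20(num_classes=10)
#   logits = model(torch.randn(2, 4, 32, 32))  # -> shape (2, 10)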

View File

@@ -0,0 +1,68 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
def get_optuna_params(model_type: str, dataset: str) -> dict:
params = {}
if dataset=='CIFAR10':
if model_type=='resnet20':
params["lr"] = 0.0016822249163093617
params["lr_bias"] = 63.934695046801245
params["momentum"] = 0.8484574950771097
params["weight_decay"] = 0.11450934135118791
elif model_type=='resnet32':
# 91.19: lr: 0.0013205254360784781, lr_bias: 61.138281101282544, momentum: 0.873508553678625, weight_decay: 0.26911634559915815
# 91.1 : lr: 0.0013978655308274968, lr_bias: 70.43940111170473, momentum: 0.8611100787383372, weight_decay: 0.2604742590264777
# 90.99: lr: 0.0019695910893940986, lr_bias: 60.930501987151686, momentum: 0.8831260271578129, weight_decay: 0.1456126229025426
params["lr"] = 0.0013205254360784781
params["lr_bias"] = 61.138281101282544
params["momentum"] = 0.873508553678625
params["weight_decay"] = 0.26911634559915815
elif model_type=='resnet44':
# 91.49: 0.0017177668853317557 72.4258603207131 0.8353896320183106 0.16749858871622
# 91.16: 0.0019608745758959625 67.9132255882833 0.8041541468923449 0.19517278422517992
# 91.01: 0.0009350979452929332 71.95838038824016 0.858379476548086 0.06780300392316674
params["lr"] = 0.0017177668853317557
params["lr_bias"] = 72.4258603207131
params["momentum"] = 0.8353896320183106
params["weight_decay"] = 0.16749858871622
elif model_type=='resnet56':
# 92.12: 0.0012022823706985977 71.31108702685964 0.8252747623136261 0.26463818739336625
# 91.90: 0.0010850336892236205 55.20534833175523 0.8738224946147084 0.10705317777179325
# 91.56: 0.0019151327847040805 63.38376732305882 0.9134938189630787 0.24446065595718675
params["lr"] = 0.0012022823706985977
params["lr_bias"] = 71.31108702685964
params["momentum"] = 0.8252747623136261
params["weight_decay"] = 0.26463818739336625
elif model_type=='resnet110':
# 92.23: 0.001477698037686629 61.444988882569774 0.7241645867415002 0.23586225065185779
# 92.17: 0.0017110807237653582 65.2511959805971 0.8078620231092996 0.19065715813207001
# 92.16: 0.0015513227695282382 59.89497310126697 0.7355843250067341 0.13248840913478463
params["lr"] = 0.001477698037686629
params["lr_bias"] = 61.444988882569774
params["momentum"] = 0.7241645867415002
params["weight_decay"] = 0.23586225065185779
else:
print("model_type and dataset are incorrectly specified. Returning resnet20 params.")
params["lr"] = 0.0016822249163093617
params["lr_bias"] = 63.934695046801245
params["momentum"] = 0.8484574950771097
params["weight_decay"] = 0.11450934135118791
else:
        # CIFAR100 params were only tuned for resnet32; other depths fall back to defaults
if model_type=='resnet32':
# 65.09: 0.0018636209167742187 64.96657354785438 0.9186032548289501 0.15017464467868924
# 64.70: 0.0017509006966116355 60.10884856596049 0.8921508582343675 0.10919043636429121
# 64.49: 0.0015358614659175514 59.175398449172015 0.8553794786037812 0.20824545084283141
params["lr"] = 0.0018636209167742187
params["lr_bias"] = 64.96657354785438
params["momentum"] = 0.9186032548289501
params["weight_decay"] = 0.15017464467868924
else:
            # Default bag-of-tricks params
params["lr"] = 0.001
params["lr_bias"] = 64
params["momentum"] = 0.9
params["weight_decay"] = 0.256
print("Loading params for %s, %s" % (model_type, dataset))
print(params)
return params
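# Usage sketch (illustrative):
#   params = get_optuna_params('resnet20', 'CIFAR10')
#   # -> {'lr': ..., 'lr_bias': ..., 'momentum': ..., 'weight_decay': ...}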

View File

@@ -0,0 +1,308 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import argparse
import copy
import json
import time
from palisade_he_cnn.training.models.resnet9 import ResNet9
from palisade_he_cnn.training.utils.utils_dataloading import *
from palisade_he_cnn.training.utils.utils_kurtosis import *
from palisade_he_cnn.training.utils.utils_resnetN import (
patch_whitening, update_nesterov, update_ema, label_smoothing_loss
)
def argparsing():
parser = argparse.ArgumentParser()
parser.add_argument('-bs', '--batch',
help='Batch size',
type=int,
required=False,
default=512)
parser.add_argument('-e', '--epochs',
help='Number of epochs',
type=int,
required=False,
default=100)
parser.add_argument('-c', '--cuda',
help='CUDA device number',
type=int,
required=False,
default=0)
parser.add_argument('-r', '--nruns',
help='Number of training runs',
type=int,
required=False,
default=5)
parser.add_argument('-dataset', '--dataset',
help='CIFAR10 or CIFAR100',
type=str,
choices=['CIFAR10', 'CIFAR100'],
required=False,
default='CIFAR10')
parser.add_argument('-s', '--save',
help='Save model and log files',
type=bool,
required=False,
default=True)
return vars(parser.parse_args())
def train(
dataset,
epochs,
batch_size,
momentum,
weight_decay,
weight_decay_bias,
ema_update_freq,
ema_rho,
device,
dtype,
kwargs,
use_TTA,
seed=0
):
lr_schedule = torch.cat([
torch.linspace(0e+0, 2e-3, 194),
torch.linspace(2e-3, 2e-4, 582),
])
lr_schedule_bias = 64.0 * lr_schedule
kurt_schedule = torch.cat([
torch.linspace(0, 1e-1, 2000),
])
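    # Both schedules are indexed by batch_count in the loop below, so the
    # segment lengths above are in units of training batches, not epochs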
# Print information about hardware on first run
if seed == 0:
if device.type == "cuda":
print("Device :", torch.cuda.get_device_name(device.index))
print("Dtype :", dtype)
print()
# Start measuring time
start_time = time.perf_counter()
    # Set random seed to increase chance of reproducibility
    torch.manual_seed(seed)
    # Setting cudnn.benchmark to True hampers reproducibility, but is faster
    torch.backends.cudnn.benchmark = True
# Load dataset
if dataset == "CIFAR10":
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
else:
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
train_data = torch.cat(
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
valid_data = torch.cat(
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
temp = train_data[:10000, :, 4:-4, 4:-4]
weights = patch_whitening(temp)
train_model = ResNet9(c_in=weights.size(1),
c_out=weights.size(0),
**kwargs).to(device)
train_model.set_conv1_weights(
weights=weights.to(device),
bias=torch.zeros(weights.size(0)).to(device)
)
train_model.to(dtype)
# Convert BatchNorm back to single precision for better accuracy
for module in train_model.modules():
if isinstance(module, nn.BatchNorm2d):
module.float()
# Collect weights and biases and create nesterov velocity values
weights = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) > 1
]
biases = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) <= 1
]
# Copy the model for validation
valid_model = copy.deepcopy(train_model)
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
# Train and validate
print("\nepoch batch train time [sec] validation accuracy")
train_time = 0.0
batch_count = 0
best_acc = 0.0
best_model = None
for epoch in range(1, epochs + 1):
start_time = time.perf_counter()
# Randomly shuffle training data
indices = torch.randperm(len(train_data), device=device)
data = train_data[indices]
targets = train_targets[indices]
# Crop random 32x32 patches from 40x40 training data
data = [
random_crop(data[i: i + batch_size], crop_size=(32, 32))
for i in range(0, len(data), batch_size)
]
data = torch.cat(data)
# Randomly flip half the training data
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
for i in range(0, len(data), batch_size):
# discard partial batches
if i + batch_size > len(data):
break
# Slice batch from data
inputs = data[i: i + batch_size]
target = targets[i: i + batch_size]
batch_count += 1
# Compute new gradients
train_model.zero_grad()
train_model.train(True)
# kurtosis setup
remove_all_hooks(train_model)
activations = get_intermediate_output(train_model)
logits = train_model(inputs)
loss = label_smoothing_loss(logits, target, alpha=0.2)
# kurtosis scheduler
kurt_index = min(batch_count, len(kurt_schedule) - 1)
kurt_scale = kurt_schedule[kurt_index]
# kurtosis calculation and cleanup
remove_all_hooks(train_model)
activations = {k: torch.cat(v) for k, v in activations.items()}
loss_means, loss_stds, loss_kurts = get_statistics(activations)
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
loss.sum().backward()
lr_index = min(batch_count, len(lr_schedule) - 1)
lr = lr_schedule[lr_index]
lr_bias = lr_schedule_bias[lr_index]
# Update weights and biases of training model
update_nesterov(weights, lr, weight_decay, momentum)
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
# Update validation model with exponential moving averages
if (i // batch_size % ema_update_freq) == 0:
update_ema(train_model, valid_model, ema_rho)
# Add training time
train_time += time.perf_counter() - start_time
valid_correct = []
for i in range(0, len(valid_data), batch_size):
valid_model.train(False)
        # Test-time augmentation: evaluate the model on regular and flipped data
regular_inputs = valid_data[i: i + batch_size]
logits = valid_model(regular_inputs).detach()
if use_TTA:
flipped_inputs = torch.flip(regular_inputs, [-1])
logits2 = valid_model(flipped_inputs).detach()
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
# Compute correct predictions
correct = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
valid_correct.append(correct.detach().type(torch.float64))
# Accuracy is average number of correct predictions
valid_acc = torch.mean(torch.cat(valid_correct)).item()
        if valid_acc > best_acc:
            best_acc = valid_acc
            # Snapshot the weights; a bare reference would end up tracking the
            # final model rather than the best one
            best_model = copy.deepcopy(train_model)
print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
return best_acc, best_model
def main():
args = argparsing()
model_type = "resnet9"
cifar_dataset = args["dataset"]
save = args["save"]
weight_name = 'weights/%s_%s' % (model_type, cifar_dataset)
kwargs = {
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
"scale_out": 0.125
}
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# Configurable parameters
ema_update_freq = 5
params = {
"dataset": cifar_dataset,
"epochs": args["epochs"],
"batch_size": args["batch"],
"momentum": 0.9,
"weight_decay": 0.256,
"weight_decay_bias": 0.004,
"ema_update_freq": ema_update_freq,
"ema_rho": 0.99 ** ema_update_freq,
"kwargs": kwargs,
"use_TTA": False
}
log = {
"weights": weight_name,
"model_type": model_type,
"kwargs": kwargs,
"params": params
}
nruns = args["nruns"]
accuracies = []
for run in range(nruns):
weight_name_seed = weight_name + "_run%d.pt" % run
best_acc, best_model = train(**params,
device=device,
dtype=dtype,
seed=run)
accuracies.append(best_acc)
print("Best Run Accuracy: %1.4f" % best_acc)
log["run%s" % run] = best_acc
if save:
print("Saving %s" % weight_name_seed)
torch.save(best_model.state_dict(), weight_name_seed)
mean = sum(accuracies) / len(accuracies)
variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies)
std = variance ** 0.5
print("Accuracy: %1.4f +/- %1.4f" % (mean, std))
log["accuracy"] = [mean, std]
if save:
with open("logs/logs_resnet9_%s.json" % cifar_dataset, 'w') as fp:
json.dump(log, fp)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,354 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import argparse
import copy
import json
import time
import torch.nn as nn
from optuna_params import get_optuna_params
from palisade_he_cnn.training.utils.utils_dataloading import random_crop
from palisade_he_cnn.training.utils.utils_kurtosis import *
from palisade_he_cnn.training.utils.utils_resnetN import (
get_model, update_nesterov, update_ema, label_smoothing_loss
)
# test-time augmentation (TTA)
use_TTA = False
class EarlyStopper:
def __init__(self, patience=1, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.max_accuracy = 0.0
def early_stop(self, accuracy):
if accuracy > self.max_accuracy:
self.max_accuracy = accuracy
self.counter = 0
elif accuracy <= (self.max_accuracy + self.min_delta):
self.counter += 1
if self.counter >= self.patience:
return True
return False
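# Usage sketch (illustrative): stop once accuracy has failed to improve by
# min_delta for `patience` consecutive epochs, e.g.
#   stopper = EarlyStopper(patience=120, min_delta=0.001)
#   if stopper.early_stop(valid_acc):
#       break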
def argparsing():
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--nlayers',
help='ResNet model depth',
type=int,
choices=[20, 32, 44, 56, 110],
required=True)
parser.add_argument('-bs', '--batch',
help='Batch size',
type=int,
required=False,
default=256)
parser.add_argument('-e', '--epochs',
help='Number of epochs',
type=int,
required=False,
default=100)
parser.add_argument('-d', '--debug',
help='Debugging mode',
type=bool,
required=False,
default=False)
parser.add_argument('-c', '--cuda',
help='CUDA device number',
type=int,
required=False,
default=0)
parser.add_argument('-dataset', '--dataset',
help='CIFAR10 or CIFAR100',
type=str,
choices=['CIFAR10', 'CIFAR100'],
required=False,
default='CIFAR10')
parser.add_argument('-s', '--save',
help='Save model and log files',
type=bool,
required=False,
default=True)
return vars(parser.parse_args())
def train(
dataset,
epochs,
batch_size,
lr,
lr_bias,
momentum,
weight_decay,
weight_decay_bias,
ema_update_freq,
ema_rho,
device,
dtype,
model_type,
kwargs,
seed=0
):
# Load dataset
if dataset == "CIFAR10":
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
else:
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
train_data = torch.cat(
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
valid_data = torch.cat(
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
    N = len(train_data) // batch_size  # batches per epoch (e.g. 50k / 256); the schedules below are sized in multiples of N
lr_schedule = torch.cat([
torch.linspace(0.0, lr, N),
# torch.linspace(lr, lr, 2*N),
torch.linspace(lr, 1e-4, 3 * N),
torch.linspace(1e-4, 1e-4, 50 * N),
torch.linspace(1e-5, 1e-5, 25 * N),
torch.linspace(1e-6, 1e-6, 25 * N),
])
lr_schedule_bias = lr_bias * lr_schedule
kurt_schedule = torch.cat([
torch.linspace(0, 0, 10 * N),
torch.linspace(0.05, 0.05, 2 * N),
torch.linspace(0.1, 0.1, 200 * N),
])
# Print information about hardware on first run
if seed == 0:
if device.type == "cuda":
print("Device :", torch.cuda.get_device_name(device.index))
print("Dtype :", dtype)
print()
# Start measuring time
start_time = time.perf_counter()
    # Set random seed to increase chance of reproducibility
    torch.manual_seed(seed)
    # Setting cudnn.benchmark to True hampers reproducibility, but is faster
    torch.backends.cudnn.benchmark = True
    # Move the model weights to the target device and dtype
train_model = get_model(model_type, kwargs).to(device)
train_model.to(dtype)
# Convert BatchNorm back to single precision for better accuracy
for module in train_model.modules():
if isinstance(module, nn.BatchNorm2d):
module.float()
# Collect weights and biases and create nesterov velocity values
weights = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) > 1
]
biases = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) <= 1
]
# Copy the model for validation
valid_model = copy.deepcopy(train_model)
    # Early stopping on validation accuracy; min_delta is an absolute accuracy
    # fraction (0.001 = 0.1%)
    early_stopper = EarlyStopper(patience=120, min_delta=0.001)
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
print("\nepoch batch train time [sec] validation accuracy")
train_time = 0.0
batch_count = 0
best_acc = 0.0
best_model = None
for epoch in range(1, epochs + 1):
start_time = time.perf_counter()
# Randomly shuffle training data
indices = torch.randperm(len(train_data), device=device)
data = train_data[indices]
targets = train_targets[indices]
# Crop random 32x32 patches from 40x40 training data
data = [
random_crop(data[i: i + batch_size], crop_size=(32, 32))
for i in range(0, len(data), batch_size)
]
data = torch.cat(data)
# Randomly flip half the training data
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
for i in range(0, len(data), batch_size):
# Discard partial batches
if i + batch_size > len(data):
break
# Slice batch from data
inputs = data[i: i + batch_size]
target = targets[i: i + batch_size]
batch_count += 1
# Compute new gradients
train_model.zero_grad()
train_model.train(True)
# kurtosis setup
remove_all_hooks(train_model)
activations = get_intermediate_output(train_model)
logits = train_model(inputs)
loss = label_smoothing_loss(logits, target, alpha=0.2)
# kurtosis scheduler
kurt_index = min(batch_count, len(kurt_schedule) - 1)
kurt_scale = kurt_schedule[kurt_index]
# kurtosis calculation and cleanup
remove_all_hooks(train_model)
activations = {k: torch.cat(v) for k, v in activations.items()}
loss_means, loss_stds, loss_kurts = get_statistics(activations)
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
loss.sum().backward()
lr_index = min(batch_count, len(lr_schedule) - 1)
lr = lr_schedule[lr_index]
lr_bias = lr_schedule_bias[lr_index]
# Update weights and biases of training model
update_nesterov(weights, lr, weight_decay, momentum)
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
# Update validation model with exponential moving averages
if (i // batch_size % ema_update_freq) == 0:
update_ema(train_model, valid_model, ema_rho)
# Add training time
train_time += time.perf_counter() - start_time
valid_correct = []
for i in range(0, len(valid_data), batch_size):
valid_model.train(False)
regular_inputs = valid_data[i: i + batch_size]
logits = valid_model(regular_inputs).detach()
if use_TTA:
flipped_inputs = torch.flip(regular_inputs, [-1])
logits2 = valid_model(flipped_inputs).detach()
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
# Compute correct predictions
correct = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
valid_correct.append(correct.detach().type(torch.float64))
# Accuracy is average number of correct predictions
valid_acc = torch.mean(torch.cat(valid_correct)).item()
        if valid_acc > best_acc:
            best_acc = valid_acc
            # Snapshot the weights rather than keeping a live reference
            best_model = copy.deepcopy(train_model)
        print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
        if early_stopper.early_stop(valid_acc):
            print("Early stopping")
            break
return best_acc, best_model
def main():
args = argparsing()
model_type = "resnet%s" % args["nlayers"]
cifar_dataset = args["dataset"]
save = args["save"]
weight_name = 'weights/%s_%s' % (model_type, cifar_dataset)
print("ResNet%s" % args["nlayers"])
print("Weight file:", weight_name)
kwargs = {
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
"debug": args["debug"]
}
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
dtype = torch.float32
# Optuna:
optuna_params = get_optuna_params(model_type, cifar_dataset)
lr = optuna_params["lr"]
lr_bias = optuna_params["lr_bias"]
momentum = optuna_params["momentum"]
weight_decay = optuna_params["weight_decay"]
# Configurable parameters
ema_update_freq = 5
params = {
"dataset": cifar_dataset,
"epochs": args["epochs"],
"batch_size": args["batch"],
"lr": lr,
"lr_bias": lr_bias,
"momentum": momentum,
"weight_decay": weight_decay,
"weight_decay_bias": 0.004,
"ema_update_freq": ema_update_freq,
"ema_rho": 0.99 ** ema_update_freq,
"model_type": model_type,
"kwargs": kwargs
}
nruns = 5
log = {
"weights": weight_name,
"model_type": model_type,
"kwargs": kwargs,
"params": params
}
accuracies = []
for run in range(nruns):
weight_name_seed = weight_name + "_run%d.pt" % run
best_acc, best_model = train(**params,
device=device,
dtype=dtype,
seed=run)
accuracies.append(best_acc)
print("Best Run Accuracy: %1.4f" % best_acc)
log["run%s" % run] = best_acc
if save:
print("Saving %s" % weight_name_seed)
torch.save(best_model.state_dict(), weight_name_seed)
mean = sum(accuracies) / len(accuracies)
variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies)
std = variance ** 0.5
print("Accuracy: %1.4f +/- %1.4f" % (mean, std))
log["accuracy"] = [mean, std]
if save:
with open("logs/logs_resnet%s_%s.json" % (args["nlayers"], cifar_dataset), 'w') as fp:
json.dump(log, fp)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,326 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import argparse
import copy
import time
import optuna
import joblib
from optuna.trial import TrialState
from palisade_he_cnn.training.utils.utils_dataloading import *
from palisade_he_cnn.training.utils.utils_kurtosis import *
from palisade_he_cnn.training.utils.utils_resnetN import (
get_model, update_nesterov, update_ema, label_smoothing_loss
)
use_TTA = False
def argparsing():
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--nlayers',
help='ResNet model depth',
type=int,
choices=[20, 32, 44, 56, 110],
required=True)
parser.add_argument('-bs', '--batch',
help='Batch size',
type=int,
required=False,
default=256)
parser.add_argument('-e', '--epochs',
help='Number of epochs',
type=int,
required=False,
default=100)
parser.add_argument('-d', '--debug',
help='Debugging mode',
type=bool,
required=False,
default=False)
parser.add_argument('-c', '--cuda',
help='CUDA device number',
type=int,
required=False,
default=0)
parser.add_argument('-dataset', '--dataset',
help='CIFAR10 or CIFAR100',
type=str,
choices=['CIFAR10', 'CIFAR100'],
required=False,
default='CIFAR10')
parser.add_argument('-s', '--save',
help='Save model and log files',
type=bool,
required=False,
default=True)
return vars(parser.parse_args())
def train(
trial,
dataset,
epochs,
batch_size,
weight_decay_bias,
ema_update_freq,
ema_rho,
device,
dtype,
model_type,
kwargs,
seed=0
):
# Print information about hardware on first run
if seed == 0:
if device.type == "cuda":
print("Device :", torch.cuda.get_device_name(device.index))
print("Dtype :", dtype)
print()
# Start measuring time
start_time = time.perf_counter()
    # Set random seed to increase chance of reproducibility
torch.manual_seed(seed)
# Load dataset
if dataset == "CIFAR10":
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
else:
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
train_data = torch.cat(
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
valid_data = torch.cat(
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
    # Move the model weights to the target device and dtype
train_model = get_model(model_type, kwargs).to(device)
train_model.to(dtype)
train_model.train()
# Generate the optimizers.
lr = trial.suggest_float("lr", 9e-4, 2e-3)
lr_bias = trial.suggest_float("lr_bias", 54, 74)
momentum = trial.suggest_float("momentum", 0.7, .99)
weight_decay = trial.suggest_float("weight_decay", 0.01, .3)
    N = len(train_data) // batch_size  # batches per epoch
lr_schedule = torch.cat([
torch.linspace(0.0, lr, N),
# torch.linspace(lr, lr, 2*N),
torch.linspace(lr, 1e-4, 3 * N),
torch.linspace(1e-4, 1e-4, 50 * N),
torch.linspace(1e-5, 1e-5, 25 * N),
torch.linspace(1e-6, 1e-6, 25 * N),
])
lr_schedule_bias = lr_bias * lr_schedule
kurt_schedule = torch.cat([
torch.linspace(0, 0, 10 * N),
torch.linspace(0.05, 0.05, 2 * N),
torch.linspace(0.1, 0.1, 200 * N),
])
# Convert BatchNorm back to single precision for better accuracy
for module in train_model.modules():
if isinstance(module, nn.BatchNorm2d):
module.float()
# Collect weights and biases and create nesterov velocity values
weights = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) > 1
]
biases = [
(w, torch.zeros_like(w))
for w in train_model.parameters()
if w.requires_grad and len(w.shape) <= 1
]
# Copy the model for validation
valid_model = copy.deepcopy(train_model)
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
# Train and validate
print("\nepoch batch train time [sec] validation accuracy")
train_time = 0.0
batch_count = 0
best_acc = []
for epoch in range(1, epochs + 1):
start_time = time.perf_counter()
# Randomly shuffle training data
indices = torch.randperm(len(train_data), device=device)
data = train_data[indices]
targets = train_targets[indices]
# Crop random 32x32 patches from 40x40 training data
data = [
random_crop(data[i: i + batch_size], crop_size=(32, 32))
for i in range(0, len(data), batch_size)
]
data = torch.cat(data)
# Randomly flip half the training data
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
loss_epoch = 0.0
for i in range(0, len(data), batch_size):
# discard partial batches
if i + batch_size > len(data):
break
# Slice batch from data
inputs = data[i: i + batch_size]
target = targets[i: i + batch_size]
batch_count += 1
# Compute new gradients
train_model.zero_grad()
train_model.train(True)
# kurtosis setup
remove_all_hooks(train_model)
activations = get_intermediate_output(train_model)
logits = train_model(inputs)
loss = label_smoothing_loss(logits, target, alpha=0.2)
# kurtosis scheduler
kurt_index = min(batch_count, len(kurt_schedule) - 1)
kurt_scale = kurt_schedule[kurt_index]
# kurtosis calculation and cleanup
remove_all_hooks(train_model)
activations = {k: torch.cat(v) for k, v in activations.items()}
loss_means, loss_stds, loss_kurts = get_statistics(activations)
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
loss.sum().backward()
loss_epoch += loss.sum().item() / batch_size
lr_index = min(batch_count, len(lr_schedule) - 1)
lr = lr_schedule[lr_index]
lr_bias = lr_schedule_bias[lr_index]
# Update weights and biases of training model
update_nesterov(weights, lr, weight_decay, momentum)
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
# Update validation model with exponential moving averages
if (i // batch_size % ema_update_freq) == 0:
update_ema(train_model, valid_model, ema_rho)
# Add training time
train_time += time.perf_counter() - start_time
correct = []
with torch.no_grad():
for i in range(0, len(valid_data), batch_size):
valid_model.train(False)
regular_inputs = valid_data[i: i + batch_size]
logits = valid_model(regular_inputs).detach()
if use_TTA:
flipped_inputs = torch.flip(regular_inputs, [-1])
logits2 = valid_model(flipped_inputs).detach()
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
# Compute correct predictions
temp = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
correct.append(temp.detach().type(torch.float64))
# Accuracy is average number of correct predictions
accuracy = torch.mean(torch.cat(correct)).item()
best_acc.append(accuracy)
print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {accuracy:22.4f}")
trial.report(accuracy, epoch)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return max(best_acc)
def objective(trial):
args = argparsing()
model_type = "resnet%s" % args["nlayers"]
cifar_dataset = args["dataset"]
save = args["save"]
weight_name = 'weights/optuna/%s_%s' % (model_type, cifar_dataset)
print("ResNet%s" % args["nlayers"])
print("Weight file:", weight_name)
kwargs = {
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
"debug": args["debug"]
}
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
dtype = torch.float32
ema_update_freq = 5
params = {
"trial": trial,
"dataset": cifar_dataset,
"epochs": args["epochs"],
"batch_size": args["batch"],
"weight_decay_bias": 0.004,
"ema_update_freq": ema_update_freq,
"ema_rho": 0.99 ** ema_update_freq,
"model_type": model_type,
"kwargs": kwargs
}
accuracy = train(**params,
device=device,
dtype=dtype,
seed=0)
return accuracy
def main():
args = argparsing()
model_type = "resnet%s" % args["nlayers"]
cifar_dataset = args["dataset"]
save = args["save"]
study = optuna.create_study(direction="maximize",
storage="sqlite:///db.sqlite3",
pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=5,
min_delta=0.0))
study.optimize(objective, n_trials=10)
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
print("Study statistics: ")
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", len(pruned_trials))
print(" Number of complete trials: ", len(complete_trials))
print("Best trial:")
trial = study.best_trial
print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))
joblib.dump(study, "study_%s_%s.pkl" % (model_type, cifar_dataset))
if __name__ == "__main__":
main()
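# To inspect a saved study later (sketch; example filename):
#   study = joblib.load("study_resnet20_CIFAR10.pkl")
#   print(study.best_trial.params)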

View File

@@ -0,0 +1,76 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import subprocess
import argparse
def argparsing():
parser = argparse.ArgumentParser()
parser.add_argument('-c','--cuda',
help='CUDA device number',
type=int,
required=False,
default=0)
parser.add_argument('-dataset', '--dataset',
help='CIFAR10 or CIFAR100',
type=str,
choices=['CIFAR10','CIFAR100'],
required=False,
default='CIFAR10')
    parser.add_argument('-n', '--nlayers',
                        help='Comma-separated list of ResNet depths, e.g. 20,32,44,56,110',
                        type=str,
                        required=True,
                        default='20,32,44,56,110')
    parser.add_argument('-e', '--epochs',
                        help='Comma-separated list of epoch counts, e.g. 100,100,100,100,100',
                        type=str,
                        required=True,
                        default='100,100,100,100,100')
    parser.add_argument('-bs', '--batch_size',
                        help='Comma-separated list of batch sizes, e.g. 256,256,256,256,64',
                        type=str,
                        required=True,
                        default='256,256,256,256,64')
return vars(parser.parse_args())
def str2list(arg):
return [int(item) for item in arg.split(',') if item!='']
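# e.g. str2list("20,32,44") == [20, 32, 44]; a trailing comma such as
# "20,32," is tolerated because empty items are filtered out.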
def main():
args = argparsing()
layer = str2list(args["nlayers"])
    batch = str2list(args["batch_size"])
    epoch = str2list(args["epochs"])
dataset = args["dataset"]
cuda = args["cuda"]
for i in range(len(layer)):
cmd = "python3 train_resnetN.py -n %s -bs %s -e %s -dataset %s -c %d" \
% (layer[i], batch[i], epoch[i], dataset, cuda)
print("\n")
print(cmd)
subprocess.run(
[
"python3",
"train_resnetN.py",
"-n", "%s"%str(layer[i]),
"-bs", "%s"%str(batch[i]),
"-e", "%s"%str(epoch[i]),
"-c", "%s"%str(cuda),
"-dataset", dataset
],
shell=False)
if __name__ == "__main__":
main()
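# Example invocation (illustrative; substitute this script's actual filename):
#   python3 run_all.py -n 20,32 -bs 256,256 -e 100,100 -dataset CIFAR10 -c 0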

View File

@@ -0,0 +1,52 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import subprocess
import argparse
layer = [20, 32, 44, 56, 110]
batch = [256, 256, 256, 256, 64]
epoch = [50, 50, 50, 50, 50]
def argparsing():
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--cuda',
help='CUDA device number',
type=int,
required=False,
default=0)
parser.add_argument('-dataset', '--dataset',
help='CIFAR10 or CIFAR100',
type=str,
choices=['CIFAR10', 'CIFAR100'],
required=False,
default='CIFAR10')
return vars(parser.parse_args())
def main():
args = argparsing()
dataset = args["dataset"]
cuda = args["cuda"]
for i in range(len(layer)):
cmd = "python3 train_resnetN_optuna.py -n %s -bs %s -e %s -dataset %s -c %d" \
% (layer[i], batch[i], epoch[i], dataset, cuda)
print("\n")
print(cmd)
subprocess.run(
[
"python3",
"train_resnetN_optuna.py",
"-n", "%s" % str(layer[i]),
"-bs", "%s" % str(batch[i]),
"-e", "%s" % str(epoch[i]),
"-c", "%s" % str(cuda),
"-dataset", dataset
],
shell=False)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,146 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import ast
from collections import OrderedDict, defaultdict
from typing import Dict, Callable
import torch
import torchvision.transforms as tt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
class PadChannel(object):
def __init__(self, npad: int=1):
self.n = npad
def __call__(self, x):
_, width, height = x.shape
x = torch.cat([x, torch.zeros(self.n, width, height)])
return x
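# e.g. padding RGB inputs with one extra zero channel (3 -> 4 channels):
#   >>> PadChannel(npad=1)(torch.zeros(3, 32, 32)).shape
#   torch.Size([4, 32, 32])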
def get_gelu_poly_coeffs(degree, filename='gelu_poly_approx_params.txt'):
    # Each line of the file holds one literal list of polynomial coefficients,
    # ordered by degree: 2, 4, 8, 16, 32.
    with open(filename, 'r') as fp:
        params = [ast.literal_eval(line.strip()) for line in fp]
    if degree == 2:
        return params[0]
    elif degree == 4:
        return params[1]
    elif degree == 8:
        return params[2]
    elif degree == 16:
        return params[3]
    elif degree == 32:
        return params[4]
    else:
        print("Unsupported degree %s; defaulting to the degree-8 coefficients" % degree)
        return params[2]
def patch_whitening(data, patch_size=(3, 3)):
# Compute weights from data such that
# torch.std(F.conv2d(data, weights), dim=(2, 3))
# is close to 1.
h, w = patch_size
c = data.size(1)
patches = data.unfold(2, h, 1).unfold(3, w, 1)
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
n, c, h, w = patches.shape
X = patches.reshape(n, c * h * w)
X = X / (X.size(0) - 1) ** 0.5
covariance = X.t() @ X
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
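# A quick sanity check of the property described above (illustrative sketch):
#   import torch.nn.functional as F
#   data = torch.randn(512, 3, 32, 32)
#   weights = patch_whitening(data)
#   print(F.conv2d(data, weights).std(dim=(2, 3)).mean())  # close to 1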
def get_cifar10_dataloader(batch_size,
data_dir: str='../../datasets/cifar10/',
num_workers: int=4):
stats = ((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010))
train_tfms = tt.Compose([
tt.RandomCrop(32,padding=4,padding_mode='reflect'),
tt.RandomHorizontalFlip(),
tt.ToTensor(),
tt.Normalize(*stats,inplace=True),
PadChannel(npad=1)
])
val_tfms = tt.Compose([
tt.ToTensor(),
tt.Normalize(*stats,inplace=True),
PadChannel(npad=1)
])
train_ds = ImageFolder(data_dir+'train',transform=train_tfms)
val_ds = ImageFolder(data_dir+'test',transform=val_tfms)
train_dl = DataLoader(train_ds,
batch_size,
pin_memory = True,
num_workers = num_workers,
shuffle = True)
val_dl = DataLoader(val_ds,
batch_size,
pin_memory = True,
num_workers = num_workers)
return train_dl, val_dl
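# Illustrative usage (assumes data_dir holds ImageFolder-style train/ and
# test/ subdirectories, as the default path above implies):
#   train_dl, val_dl = get_cifar10_dataloader(batch_size=256)
#   images, labels = next(iter(train_dl))  # images: (256, 4, 32, 32)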
def remove_all_hooks(model: torch.nn.Module) -> None:
    for name, child in model._modules.items():
        if child is not None:
            # A module can carry all three hook dicts at once, so clear each
            # one unconditionally (an elif chain would only ever reset
            # _forward_hooks, leaving the other hooks installed).
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_forward_pre_hooks"):
                child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_backward_hooks"):
                child._backward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_hooks(child)
# Register a forward hook on every BatchNorm2d layer; the returned dict is
# filled with each layer's output as the model runs forward passes.
def get_intermediate_output(model):
activation = defaultdict(list)
def get_activation(name):
def hook(model, input, output):
x = output.detach().cpu()
activation[name].append(x)
return hook
BatchNorm_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
for i, b in enumerate(BatchNorm_layers):
b.register_forward_hook(
get_activation(f"bn_{i + 1}")
)
return activation
def get_all_bn_activations(model, val_dl, DEVICE):
    activation = get_intermediate_output(model)
    model.to(DEVICE)
    model.eval()
    # Inference only: the hooks already detach outputs, so no graph is needed.
    with torch.no_grad():
        for img, label in val_dl:
            img, label = img.to(DEVICE), label.to(DEVICE)
            model(img)
    remove_all_hooks(model)
    activation = {k: torch.cat(v) for k, v in activation.items()}
    return activation
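# Sketch of collecting and inspecting the BatchNorm activations (assumes a
# model and a validation DataLoader are already constructed):
#   acts = get_all_bn_activations(model, val_dl, torch.device("cpu"))
#   for name, x in acts.items():
#       print(name, tuple(x.shape))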

View File

@@ -0,0 +1,75 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import torch
import torch.nn as nn
import torchvision
def load_cifar10(device, dtype, data_dir='./datasets/cifar10/'):
print("Loading CIFAR10")
train = torchvision.datasets.CIFAR10(root=data_dir, download=True)
valid = torchvision.datasets.CIFAR10(root=data_dir, train=False)
train_data = preprocess_cifar10_data(train.data, device, dtype)
valid_data = preprocess_cifar10_data(valid.data, device, dtype)
train_targets = torch.tensor(train.targets).to(device)
valid_targets = torch.tensor(valid.targets).to(device)
# Pad 32x32 to 40x40
train_data = nn.ReflectionPad2d(4)(train_data)
return train_data, train_targets, valid_data, valid_targets
def load_cifar100(device, dtype, data_dir='./datasets/cifar100/'):
print("Loading CIFAR100")
train = torchvision.datasets.CIFAR100(root=data_dir, download=True)
valid = torchvision.datasets.CIFAR100(root=data_dir, train=False)
train_data = preprocess_cifar100_data(train.data, device, dtype)
valid_data = preprocess_cifar100_data(valid.data, device, dtype)
train_targets = torch.tensor(train.targets).to(device)
valid_targets = torch.tensor(valid.targets).to(device)
# Pad 32x32 to 40x40
train_data = nn.ReflectionPad2d(4)(train_data)
return train_data, train_targets, valid_data, valid_targets
def random_crop(data, crop_size):
crop_h, crop_w = crop_size
h = data.size(2)
w = data.size(3)
x = torch.randint(w - crop_w, size=(1,))[0]
y = torch.randint(h - crop_h, size=(1,))[0]
return data[:, :, y : y + crop_h, x : x + crop_w]
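# Because load_cifar10/load_cifar100 reflection-pad training images from
# 32x32 to 40x40, a typical call recovers a randomly shifted 32x32 view:
#   train_batch = random_crop(train_data, crop_size=(32, 32))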
def preprocess_cifar10_data(data, device, dtype):
    # Convert to a torch tensor on the target device and dtype
data = torch.tensor(data, device=device).to(dtype)
# Normalize
mean = torch.tensor([125.31, 122.95, 113.87], device=device).to(dtype)
std = torch.tensor([62.99, 62.09, 66.70], device=device).to(dtype)
data = (data - mean) / std
# Permute data from NHWC to NCHW format
data = data.permute(0, 3, 1, 2)
return data
def preprocess_cifar100_data(data, device, dtype):
    # Convert to a torch tensor on the target device and dtype
data = torch.tensor(data, device=device).to(dtype)
# Normalize
mean = torch.tensor([129.30, 124.07, 112.43], device=device).to(dtype)
std = torch.tensor([68.17, 65.39, 70.42], device=device).to(dtype)
data = (data - mean) / std
# Permute data from NHWC to NCHW format
data = data.permute(0, 3, 1, 2)
return data

View File

@@ -0,0 +1,63 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
from collections import defaultdict, OrderedDict
from typing import Dict, Callable
import torch
import torch.nn.functional as F
def get_intermediate_output(model):
activations = defaultdict(list)
def get_activation(name):
        def hook(model, input, output):
            # Record the GELU *input* (pre-activation); it is not detached,
            # so the statistics losses below can backpropagate through it.
            x = input[0]
            activations[name].append(x)
return hook
GELU_layers = [m for m in model.modules() if isinstance(m, torch.nn.GELU)]
for i, b in enumerate(GELU_layers):
b.register_forward_hook(
get_activation(f"GELU_{i + 1}")
)
return activations
def remove_all_hooks(model: torch.nn.Module) -> None:
    for name, child in model._modules.items():
        if child is not None:
            # A module can carry all three hook dicts at once, so clear each
            # one unconditionally (an elif chain would only ever reset
            # _forward_hooks, leaving the other hooks installed).
            if hasattr(child, "_forward_hooks"):
                child._forward_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_forward_pre_hooks"):
                child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
            if hasattr(child, "_backward_hooks"):
                child._backward_hooks: Dict[int, Callable] = OrderedDict()
            remove_all_hooks(child)
def moment(x: torch.Tensor, std: float, mean: float, deg: int=4, eps: float=1e-4) -> torch.Tensor:
    # Standardized central moment of degree `deg`; deg=4 gives the kurtosis.
    x = x.double()
    temp = (x - mean) ** deg / x.shape[0]
    return torch.sum(temp) / (std ** deg + eps)
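# Quick check of the deg=4 case: the standardized fourth moment (kurtosis)
# of a large standard-normal sample should be close to 3 (illustrative):
#   x = torch.randn(100000)
#   std, mean = torch.std_mean(x)
#   print(moment(x, std, mean, deg=4))  # ~3.0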
def get_statistics(activations):
    # Penalize each recorded distribution for deviating from a standard
    # Gaussian: mean 0, std 1, and kurtosis 3.
    n = len(activations)
    means = torch.zeros(n)
    stds = torch.zeros(n)
    kurts = torch.zeros(n)
    layer_names = sorted(activations.keys(), key=lambda x: int(x.split('_')[1]))
    for layer_index, name in enumerate(layer_names):
        dist = activations[name].flatten()
        std, mean = torch.std_mean(dist)
        kurt = moment(dist, std, mean, deg=4)
        means[layer_index] = mean
        stds[layer_index] = std
        kurts[layer_index] = kurt
    loss_means = F.mse_loss(means, torch.zeros(n))
    loss_stds = F.mse_loss(stds, torch.ones(n))
    loss_kurts = F.mse_loss(kurts, 3 * torch.ones(n))
    return loss_means, loss_stds, loss_kurts
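# End-to-end sketch of how these helpers combine (assumes a model containing
# GELU layers and an input batch on the same device as the model):
#   acts = get_intermediate_output(model)   # registers the forward hooks
#   model(batch)                            # hooks record the GELU inputs
#   acts = {k: torch.cat(v) for k, v in acts.items()}
#   loss_means, loss_stds, loss_kurts = get_statistics(acts)
#   remove_all_hooks(model)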

View File

@@ -0,0 +1,95 @@
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
import numpy as np
import torch
import json
import glob
from palisade_he_cnn.training.models.resnetN_multiplexed import *
def get_model(model_type, kwargs):
if model_type=='resnet20':
return resnet20(**kwargs)
elif model_type=='resnet32':
return resnet32(**kwargs)
elif model_type=='resnet44':
return resnet44(**kwargs)
elif model_type=='resnet56':
return resnet56(**kwargs)
elif model_type=='resnet110':
return resnet110(**kwargs)
elif model_type=='resnet_test':
return resnet_test(**kwargs)
    else:
        raise ValueError("Unknown model_type '%s'; expected one of resnet20/32/44/56/110 or resnet_test" % model_type)
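# Illustrative usage (kwargs mirror those built in the Optuna objective):
#   model = get_model("resnet20", {"num_classes": 10, "debug": False})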
def get_best_weights(loc, dataset, model_type):
loc = '%s%s/' % (loc,dataset)
log_file = None
for log in glob.glob(loc+'logs/*.json'):
if model_type in log:
log_file = log
break
    if log_file is None:
        raise ValueError("No log file matching model_type '%s' found in %slogs/" % (model_type, loc))
with open(log_file) as f:
contents = json.load(f)
print("Finding the best model according to logs...")
print(contents)
runs = {"run%d"%i : contents["run%d"%i] for i in range(5)}
mean, std = contents["accuracy"]
accs = [contents["run%d"%i] for i in range(5)]
idx = accs.index(max(accs))
print("\nAverage (5 runs): %1.3f%% +/- %1.3f%%" % (100*mean, 100*std))
print("Best (idx %d): %1.3f" % (idx,runs["run%d"%idx]))
weight_file = loc + "weights/%s_%s_run%d.pt" % (model_type, dataset, idx)
return weight_file
def num_params(model) -> int:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
def update_ema(train_model, valid_model, rho):
# The trained model is not used for validation directly. Instead, the
# validation model weights are updated with exponential moving averages.
train_weights = train_model.state_dict().values()
valid_weights = valid_model.state_dict().values()
for train_weight, valid_weight in zip(train_weights, valid_weights):
if valid_weight.dtype in [torch.float16, torch.float32]:
valid_weight *= rho
valid_weight += (1 - rho) * train_weight
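# On the time scale involved (values from the Optuna script above): with
# ema_update_freq = 5 and rho = 0.99 ** 5 ~= 0.951, each EMA update keeps
# about 95% of the validation weights and mixes in about 5% of the trained ones.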
def update_nesterov(weights, lr, weight_decay, momentum):
for weight, velocity in weights:
if weight.requires_grad:
gradient = weight.grad.data
weight = weight.data
gradient.add_(weight, alpha=weight_decay).mul_(-lr)
velocity.mul_(momentum).add_(gradient)
weight.add_(gradient.add_(velocity, alpha=momentum))
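# The loop above implements Nesterov momentum with L2 weight decay folded
# into the gradient; per parameter it is equivalent to (a sketch of the
# algebra, not executable code):
#   g <- -lr * (grad + weight_decay * w)
#   v <- momentum * v + g
#   w <- w + g + momentum * v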
def label_smoothing_loss(inputs, targets, alpha):
log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
kl = -log_probs.mean(dim=1)
xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
loss = (1 - alpha) * xent + alpha * kl
return loss
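# With alpha = 0 this reduces to plain cross-entropy; the KL term spreads a
# fraction alpha of the probability mass uniformly. Illustrative check:
#   logits = torch.randn(8, 10)
#   targets = torch.randint(0, 10, (8,))
#   a = label_smoothing_loss(logits, targets, alpha=0.0)
#   b = torch.nn.functional.cross_entropy(logits, targets, reduction="none")
#   # a and b should agree up to floating-point error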
def patch_whitening(data, patch_size=(3, 3)):
h, w = patch_size
c = data.size(1)
patches = data.unfold(2, h, 1).unfold(3, w, 1)
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
n, c, h, w = patches.shape
X = patches.reshape(n, c * h * w)
X = X / (X.size(0) - 1) ** 0.5
covariance = X.t() @ X
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
eigenvalues = eigenvalues.flip(0)
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)

1069
poetry.lock generated Normal file

File diff suppressed because it is too large

29
pyproject.toml Normal file
View File

@@ -0,0 +1,29 @@
[tool.poetry]
name = "palisade_he_cnn"
version = "0.1.0"
description = ""
authors = [
"Vikram Saraph <vikram.saraph@jhuapl.edu>",
"Vivian Maloney <vivian.maloney@jhuapl.edu>",
"Freddy Obrecht <freddy.obrecht@jhuapl.edu>",
"Kate Tallaksen <kate.tallaksen@jhuapl.edu>",
"Prathibha Rama <prathibha.rama@jhuapl.edu>"
]
readme = "README.md"
packages = [
{include = "palisade_he_cnn"}
]
[tool.poetry.dependencies]
python = "^3.10"
torch = "^1.13.1"
torchvision = "^0.14.1"
optuna = "^3.2.0"
joblib = "^1.3.2"
pytest = "^7.4.0"
openfhe = {path = "../palisade-python/wheelhouse/OpenFHE-1.0.5a0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"}
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"