mirror of
https://github.com/JHUAPL/SHIELD.git
synced 2026-01-07 22:03:53 -05:00
Initial commit
This commit is contained in:
31
LICENSE
Normal file
31
LICENSE
Normal file
@@ -0,0 +1,31 @@
|
||||
(c) 2021-2023 The Johns Hopkins University Applied Physics
|
||||
Laboratory LLC (JHU/APL).
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions
|
||||
are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following
|
||||
disclaimer in the documentation and/or other materials provided
|
||||
with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived
|
||||
from this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
|
||||
FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
|
||||
COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
|
||||
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
|
||||
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||||
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
|
||||
OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
21
LICENSES/Optuna/LICENSE
Normal file
21
LICENSES/Optuna/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2018 Preferred Networks, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
80
LICENSES/PyTorch/LICENSE
Normal file
80
LICENSES/PyTorch/LICENSE
Normal file
@@ -0,0 +1,80 @@
|
||||
From PyTorch:
|
||||
|
||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
|
||||
From Caffe2:
|
||||
|
||||
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Kakao Brain:
|
||||
Copyright 2019-2020 Kakao Brain
|
||||
|
||||
All contributions by Cruise LLC:
|
||||
Copyright (c) 2022 Cruise LLC.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Arm:
|
||||
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
mark their specific copyright on a particular contribution, they should
|
||||
indicate their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
and IDIAP Research Institute nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
76
README.md
Normal file
76
README.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# SHIELD: Secure Homomorphic Inference for Encrypted Learning on Data
|
||||
|
||||
SHIELD is a library for evalating pre-trained convolutional neural networks on homomorphically encrypted images. It includes code for training models that are suitable for homomorphic evaluation. Implemented neural network operations include convolution, average pooling, GELU, and linear layers.
|
||||
|
||||
This code was used to run the experiments supporting the following paper: [High-Resolution Convolutional Neural Networks on Homomorphically Encrypted Data via Sharding Ciphertexts
|
||||
](https://arxiv.org/abs/2306.09189). However, operators defined in this project are generic enough to build arbitrary convolutional neural networks as specified in the paper.
|
||||
|
||||
## Requirements
|
||||
This project's dependencies are managed by Poetry, so installing [Poetry](https://python-poetry.org/) is a requirement. OpenFHE Python bindings are used to interface with OpenFHE, so the wheel file for these bindings will also need to be built. See the OpenFHE Python bindings repository for further instructions.
|
||||
|
||||
Once the bindings are builts, ensure that the `pyproject.toml` file contains a correct path to the bindings. Then to install the Python environment for this project, run `poetry install`. For running unit tests and the small neural network as described below, 32GB of RAM is recommended. For hardware requirements needed to reproduce results for the larger ResNet architecures, see the paper for details.
|
||||
|
||||
Code was developed and tested on Ubuntu 20.04. While it should run on Windows platforms as well, this has not been explicitly tested.
|
||||
|
||||
## Features
|
||||
|
||||
SHIELD implements the following neural network operators:
|
||||
|
||||
- Convolution
|
||||
- Average pooling
|
||||
- Batch normalization (which are fused with convolution operators for performance)
|
||||
- Linear
|
||||
- GELU (Gaussian Error Linear Unit, a smooth alternative to ReLU)
|
||||
- Upsample
|
||||
|
||||
For performance reasons, the core of these algorithms are mostly implemented in the companion OpenFHE Python bindings project (in C++), with this project providing a minimal but more user-friendly Python inference for using them.
|
||||
|
||||
The following neural network architectures are implemented using homomorphic implementations of the above operators: a neural network consisting of three convolution blocks (mainly for integration testing), and variations on ResNet including ResNet9 and ResNet50. In addition, code for training models suitable for homomorphic evaluation, using these architectures is included. Training code includes kurtosis regularization required for homomorphic inference. See the referenced paper for more details on the algorithms implemented, as well as performance metrics for homomorphic inference using these neural networks.
|
||||
|
||||
## Running the code
|
||||
|
||||
### Units tests
|
||||
|
||||
Tests are run with `pytest`:
|
||||
|
||||
```
|
||||
poetry run python palisade_he_cnn/test.py
|
||||
```
|
||||
|
||||
### A small neural network
|
||||
|
||||
`small_model.py` includes code defining a 3-layer convolutional neural network, as well as code to train a model, on MNIST, instantiated from this network. The training code can be run with:
|
||||
|
||||
```
|
||||
poetry run python palisade_he_cnn/src/small_model.py
|
||||
```
|
||||
|
||||
This will save model weights to `small_model.pt`. To run homomorphic inference with these weights, move the weights to `palisade_he_cnn/src/weights/` and then run:
|
||||
|
||||
```
|
||||
poetry run python palisade_he_cnn/src/small_model_inference.py
|
||||
```
|
||||
|
||||
This script builds an equivalent homomorphic architecture, extracting weights from the plaintext model, and runs inference on MNIST. It prints out inference times to the terminal. For convenience, example weights are already included in `palisade_he_cnn/src/weights`.
|
||||
|
||||
### Larger neural networks
|
||||
|
||||
Scripts to train larger models are included in `palisade_he_cnn/training`. Scripts that run inference with these models are in `palisade_he_cnn/inference`. Due to significant resources required to train and run homomorphic inference with these larger models, weights used in the paper will be added to this repository in the future.
|
||||
|
||||
## Citation and Acknowledgements
|
||||
|
||||
Please cite this work as follows:
|
||||
|
||||
```
|
||||
@misc{maloney2024highresolutionconvolutionalneuralnetworks,
|
||||
title={High-Resolution Convolutional Neural Networks on Homomorphically Encrypted Data via Sharding Ciphertexts},
|
||||
author={Vivian Maloney and Richard F. Obrecht and Vikram Saraph and Prathibha Rama and Kate Tallaksen},
|
||||
year={2024},
|
||||
eprint={2306.09189},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CR},
|
||||
url={https://arxiv.org/abs/2306.09189},
|
||||
}
|
||||
```
|
||||
|
||||
In addition to the authors on the supporting manuscript (Vivian Maloney, Freddy Obrect, Vikram Saraph, Prathibha Rama, and Kate Tallaksen), Lindsay Spriggs and Court Climer also contributed to this work by testing the software and integrating it with internal infrastructure.
|
||||
132
palisade_he_cnn/inference/resnet50_cifar_inference.py
Normal file
132
palisade_he_cnn/inference/resnet50_cifar_inference.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
|
||||
from palisade_he_cnn.src.he_cnn.utils import *
|
||||
from palisade_he_cnn.src.utils import pad_conv_input_channels, PadChannel
|
||||
|
||||
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--idx", default="0")
|
||||
args = vars(parser.parse_args())
|
||||
img_idx = int(args["idx"])
|
||||
|
||||
print("img_idx", img_idx)
|
||||
|
||||
# create HE cc and keys
|
||||
mult_depth = 35
|
||||
scale_factor_bits = 59
|
||||
batch_size = 32 * 32 * 32
|
||||
|
||||
|
||||
# if using bootstrapping, you must increase scale_factor_bits to 59
|
||||
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
|
||||
|
||||
|
||||
stats = ((0.4914, 0.4822, 0.4465), # mean
|
||||
(0.247, 0.243, 0.261)) # std
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1),
|
||||
transforms.Resize(32)
|
||||
])
|
||||
validset = torchvision.datasets.CIFAR10(root="./data", download=True, transform=transform)
|
||||
validloader = torch.utils.data.DataLoader(validset, batch_size=1, shuffle=True)
|
||||
|
||||
# top level model
|
||||
resnet_model = torch.load("palisade_he_cnn/src/weights/resnet50_cifar_gelu_kurt.pt")
|
||||
resnet_model.eval()
|
||||
|
||||
print(resnet_model)
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
conv1 = resnet_model.conv1
|
||||
bn1 = resnet_model.bn1
|
||||
|
||||
padded_conv1 = pad_conv_input_channels(conv1)
|
||||
|
||||
embedder = copy.deepcopy(torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
|
||||
|
||||
for i, (padded_test_data, test_label) in enumerate(validloader):
|
||||
if i == img_idx:
|
||||
break
|
||||
|
||||
unpadded_test_data = padded_test_data[:,:3]
|
||||
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
|
||||
|
||||
start = time()
|
||||
|
||||
# embedding layer
|
||||
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
|
||||
cnn_context = cnn_context.apply_gelu(bound=15.0)
|
||||
|
||||
unencrypted = ptxt_embedded
|
||||
|
||||
compare_accuracy(keys, cnn_context, unencrypted, "embedding", num_digits=7)
|
||||
|
||||
###############################################################################
|
||||
|
||||
|
||||
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
|
||||
for j, bottleneck in enumerate(layer):
|
||||
|
||||
bootstrap = False if (i == 0 and j == 0) else True
|
||||
name = f"bottleneck #{i+1}-{j}"
|
||||
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=bootstrap, bootstrap_params={"meta" : True})
|
||||
unencrypted = bottleneck(unencrypted)
|
||||
compare_accuracy(keys, cnn_context, unencrypted, name, num_digits=7)
|
||||
|
||||
###############################################################################
|
||||
|
||||
linear = resnet_model.fc
|
||||
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
|
||||
|
||||
inference_time = time() - start
|
||||
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
|
||||
|
||||
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
|
||||
ptxt_logits = linear(flattened)
|
||||
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
|
||||
|
||||
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
|
||||
|
||||
print(f"[+] decrypted logits = {decrypted_logits}")
|
||||
print(f"[+] plaintext logits = {ptxt_logits}")
|
||||
|
||||
###############################################################################
|
||||
|
||||
dataset = "cifar10"
|
||||
model_type = "resnet50_metaBTS"
|
||||
|
||||
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
|
||||
filename.parent.mkdir(exist_ok=True, parents=True)
|
||||
data = dict(TIMING_DICT)
|
||||
data["decrypted logits"] = decrypted_logits.tolist()
|
||||
data["unencrypted logits"] = ptxt_logits.tolist()
|
||||
data["inference time"] = inference_time
|
||||
|
||||
# avoid double-counting the strided conv operations
|
||||
data['Pool'] = data['Pool'][1:2]
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
159
palisade_he_cnn/inference/resnet50_imagenet128_inference.py
Normal file
159
palisade_he_cnn/inference/resnet50_imagenet128_inference.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
# srun -p hybrid -n 128 --mem=300G --pty bash -i
|
||||
# srun -p himem -n 128 --mem=300G --pty bash -i
|
||||
|
||||
# export OMP_DISPLAY_ENV=TRUE
|
||||
# export OMP_NUM_THREADS=32
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from time import time
|
||||
import copy
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
from palisade_he_cnn.src.he_cnn.utils import compare_accuracy, get_keys
|
||||
from palisade_he_cnn.src.utils import pad_conv_input_channels
|
||||
from palisade_he_cnn.training.utils.utils import PadChannel
|
||||
|
||||
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--idx", default="0")
|
||||
args = vars(parser.parse_args())
|
||||
img_idx = int(args["idx"])
|
||||
|
||||
print("img_idx", img_idx)
|
||||
|
||||
IMAGENET_CHANNEL_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_CHANNEL_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
stats = (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD)
|
||||
|
||||
IMAGENET_DIR = Path("/aoscluster/he-cnn/vivian/imagenet/datasets/ILSVRC/Data/CLS-LOC")
|
||||
resize_size = 136
|
||||
crop_size = 128
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1),
|
||||
transforms.Resize(resize_size),
|
||||
transforms.CenterCrop(crop_size)
|
||||
])
|
||||
|
||||
validset = ImageFolder(IMAGENET_DIR / "val", transform=transform)
|
||||
|
||||
validloader = DataLoader(validset,
|
||||
batch_size = 1,
|
||||
pin_memory = True,
|
||||
num_workers = 1,
|
||||
shuffle=True)
|
||||
|
||||
# top level model
|
||||
resnet_model = torch.load("weights/resnet50_imagenet128_gelu_kurt.pt")
|
||||
resnet_model.eval()
|
||||
|
||||
print(resnet_model)
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
conv1 = resnet_model.conv1
|
||||
bn1 = resnet_model.bn1
|
||||
|
||||
padded_conv1 = pad_conv_input_channels(conv1)
|
||||
|
||||
embedder = copy.deepcopy(torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
|
||||
|
||||
for i, (padded_test_data, test_label) in enumerate(validloader):
|
||||
if i == img_idx:
|
||||
break
|
||||
|
||||
unpadded_test_data = padded_test_data[:,:3]
|
||||
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
|
||||
|
||||
unencrypted = ptxt_embedded
|
||||
|
||||
##############################################################################
|
||||
|
||||
# create HE cc and keys
|
||||
mult_depth = 34
|
||||
scale_factor_bits = 59
|
||||
batch_size = 32 * 32 * 32
|
||||
|
||||
# if using bootstrapping, you must increase scale_factor_bits to 59
|
||||
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
|
||||
|
||||
|
||||
##############################################################################
|
||||
|
||||
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
|
||||
|
||||
while cnn_context.shards[0].getTowersRemaining() > 18:
|
||||
for i in range(cnn_context.num_shards):
|
||||
cnn_context.shards[i] *= 1.0
|
||||
|
||||
start = time()
|
||||
|
||||
# embedding layer
|
||||
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
|
||||
cnn_context = cnn_context.apply_gelu(bound=50.0, degree=200)
|
||||
cnn_context = cnn_context.apply_pool(conv=True)
|
||||
|
||||
compare_accuracy(keys, cnn_context, unencrypted, "embedding")
|
||||
|
||||
###############################################################################
|
||||
|
||||
|
||||
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
|
||||
for j, bottleneck in enumerate(layer):
|
||||
|
||||
name = f"bottleneck #{i+1}-{j}"
|
||||
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=True, gelu_params={"bound" : 15.0, "degree": 59})
|
||||
unencrypted = bottleneck(unencrypted)
|
||||
compare_accuracy(keys, cnn_context, unencrypted, name)
|
||||
|
||||
###############################################################################
|
||||
|
||||
linear = resnet_model.fc
|
||||
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
|
||||
|
||||
inference_time = time() - start
|
||||
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
|
||||
|
||||
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
|
||||
ptxt_logits = linear(flattened)
|
||||
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
|
||||
|
||||
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
|
||||
|
||||
print(f"[+] decrypted logits = {decrypted_logits}")
|
||||
print(f"[+] plaintext logits = {ptxt_logits}")
|
||||
|
||||
###############################################################################
|
||||
|
||||
dataset = "imagenet"
|
||||
model_type = "resnet50_128"
|
||||
|
||||
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
|
||||
filename.parent.mkdir(exist_ok=True, parents=True)
|
||||
data = dict(TIMING_DICT)
|
||||
data["decrypted logits"] = decrypted_logits.tolist()
|
||||
data["unencrypted logits"] = ptxt_logits.tolist()
|
||||
data["inference time"] = inference_time
|
||||
|
||||
# avoid double-counting the strided conv operations
|
||||
data['Pool'] = data['Pool'][:1]
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
|
||||
155
palisade_he_cnn/inference/resnet50_imagenet256_inference.py
Normal file
155
palisade_he_cnn/inference/resnet50_imagenet256_inference.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
# srun -p hybrid -n 128 --mem=300G --pty bash -i
|
||||
# srun -p himem -n 128 --mem=300G --pty bash -i
|
||||
|
||||
# export OMP_DISPLAY_ENV=TRUE
|
||||
# export OMP_NUM_THREADS=32
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
from palisade_he_cnn.src.cnn_context import create_cnn_context, TIMING_DICT
|
||||
from palisade_he_cnn.src.cnn_context.utils import *
|
||||
from palisade_he_cnn.src.he_cnn.utils import get_keys, compare_accuracy
|
||||
from palisade_he_cnn.src.utils import pad_conv_input_channels
|
||||
from palisade_he_cnn.training.utils.utils import PadChannel
|
||||
|
||||
np.set_printoptions(formatter={'float': lambda x: "{0:0.4f}".format(x)})
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-i", "--idx", default="0")
|
||||
args = vars(parser.parse_args())
|
||||
img_idx = int(args["idx"])
|
||||
|
||||
print("img_idx", img_idx)
|
||||
|
||||
IMAGENET_CHANNEL_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_CHANNEL_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
stats = (IMAGENET_CHANNEL_MEAN, IMAGENET_CHANNEL_STD)
|
||||
|
||||
IMAGENET_DIR = Path("/aoscluster/he-cnn/vivian/imagenet/datasets/ILSVRC/Data/CLS-LOC")
|
||||
resize_size = 264
|
||||
crop_size = 256
|
||||
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(*stats, inplace=True),
|
||||
PadChannel(npad=1),
|
||||
transforms.Resize(resize_size),
|
||||
transforms.CenterCrop(crop_size)
|
||||
])
|
||||
|
||||
validset = ImageFolder(IMAGENET_DIR / "val", transform=transform)
|
||||
|
||||
validloader = DataLoader(validset,
|
||||
batch_size=1,
|
||||
pin_memory=True,
|
||||
num_workers=1,
|
||||
shuffle=True)
|
||||
|
||||
# top level model
|
||||
resnet_model = torch.load("weights/resnet50_imagenet256_gelu_kurt.pt")
|
||||
resnet_model.eval()
|
||||
|
||||
print(resnet_model)
|
||||
|
||||
##############################################################################
|
||||
|
||||
conv1 = resnet_model.conv1
|
||||
bn1 = resnet_model.bn1
|
||||
|
||||
padded_conv1 = pad_conv_input_channels(conv1)
|
||||
|
||||
embedder = copy.deepcopy(
|
||||
torch.nn.Sequential(resnet_model.conv1, resnet_model.bn1, resnet_model.relu, resnet_model.maxpool))
|
||||
|
||||
for i, (padded_test_data, test_label) in enumerate(validloader):
|
||||
if i == img_idx:
|
||||
break
|
||||
|
||||
unpadded_test_data = padded_test_data[:, :3]
|
||||
ptxt_embedded = embedder(unpadded_test_data).detach().cpu()
|
||||
|
||||
unencrypted = ptxt_embedded
|
||||
|
||||
##############################################################################
|
||||
|
||||
# create HE cc and keys
|
||||
mult_depth = 34
|
||||
scale_factor_bits = 59
|
||||
batch_size = 32 * 32 * 32
|
||||
|
||||
# if using bootstrapping, you must increase scale_factor_bits to 59
|
||||
cc, keys = get_keys(mult_depth, scale_factor_bits, batch_size, bootstrapping=True)
|
||||
|
||||
##############################################################################
|
||||
|
||||
cnn_context = create_cnn_context(padded_test_data[0], cc, keys.publicKey, verbose=True)
|
||||
|
||||
while cnn_context.shards[0].getTowersRemaining() > 18:
|
||||
for i in range(cnn_context.num_shards):
|
||||
cnn_context.shards[i] *= 1.0
|
||||
|
||||
start = time()
|
||||
|
||||
# embedding layer
|
||||
cnn_context = cnn_context.apply_conv(padded_conv1, bn1)
|
||||
cnn_context = cnn_context.apply_gelu(bound=50.0, degree=200)
|
||||
cnn_context = cnn_context.apply_pool(conv=True)
|
||||
|
||||
compare_accuracy(keys, cnn_context, unencrypted, "embedding")
|
||||
|
||||
###############################################################################
|
||||
|
||||
|
||||
for i, layer in enumerate([resnet_model.layer1, resnet_model.layer2, resnet_model.layer3, resnet_model.layer4]):
|
||||
for j, bottleneck in enumerate(layer):
|
||||
name = f"bottleneck #{i + 1}-{j}"
|
||||
cnn_context = cnn_context.apply_bottleneck(bottleneck, bootstrap=True,
|
||||
gelu_params={"bound": 15.0, "degree": 59})
|
||||
unencrypted = bottleneck(unencrypted)
|
||||
compare_accuracy(keys, cnn_context, unencrypted, name)
|
||||
|
||||
###############################################################################
|
||||
|
||||
linear = resnet_model.fc
|
||||
ctxt_logits = cnn_context.apply_fused_pool_linear(linear)
|
||||
|
||||
inference_time = time() - start
|
||||
print(f"\nTotal Time: {inference_time:.0f} s = {inference_time / 60:.01f} min")
|
||||
|
||||
flattened = torch.nn.Flatten()(resnet_model.avgpool(unencrypted))
|
||||
ptxt_logits = linear(flattened)
|
||||
ptxt_logits = ptxt_logits.detach().cpu().numpy().ravel()
|
||||
|
||||
decrypted_logits = cc.decrypt(keys.secretKey, ctxt_logits)[:linear.out_features]
|
||||
|
||||
print(f"[+] decrypted logits = {decrypted_logits}")
|
||||
print(f"[+] plaintext logits = {ptxt_logits}")
|
||||
|
||||
###############################################################################
|
||||
|
||||
dataset = "imagenet"
|
||||
model_type = "resnet50_256"
|
||||
|
||||
filename = Path("logs") / dataset / model_type / f"log_{img_idx}.json"
|
||||
filename.parent.mkdir(exist_ok=True, parents=True)
|
||||
data = dict(TIMING_DICT)
|
||||
data["decrypted logits"] = decrypted_logits.tolist()
|
||||
data["unencrypted logits"] = ptxt_logits.tolist()
|
||||
data["inference time"] = inference_time
|
||||
|
||||
# avoid double-counting the strided conv operations
|
||||
data['Pool'] = data['Pool'][:1]
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(data, f, indent=4)
|
||||
17
palisade_he_cnn/inference/run_inference.sh
Normal file
17
palisade_he_cnn/inference/run_inference.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
#!/bin/bash
|
||||
|
||||
# srun -p hybrid -n 128 --mem=800G --pty bash -i
|
||||
|
||||
# ./run_vivian.sh 2>&1 >> logs/imagenet/resnet50_256_log_v2.txt
|
||||
|
||||
export OMP_DISPLAY_ENV=TRUE
|
||||
export OMP_NUM_THREADS=64
|
||||
export OMP_PROC_BIND=TRUE
|
||||
|
||||
for i in `seq 0 50` ; do
|
||||
python resnet50_cifar_inference.py -i $i
|
||||
# python resnet50_imagenet128_inference.py -i $i
|
||||
# python resnet50_imagenet256_inference.py -i $i
|
||||
done
|
||||
386
palisade_he_cnn/notebooks/analyze_layers_logits.ipynb
Normal file
386
palisade_he_cnn/notebooks/analyze_layers_logits.ipynb
Normal file
File diff suppressed because one or more lines are too long
340
palisade_he_cnn/src/cnn_context.py
Normal file
340
palisade_he_cnn/src/cnn_context.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from time import time
|
||||
from collections import defaultdict
|
||||
import palisade_he_cnn.src.he_cnn.utils as utils
|
||||
import palisade_he_cnn.src.he_cnn.conv as conv
|
||||
import palisade_he_cnn.src.he_cnn.pool as pool
|
||||
import palisade_he_cnn.src.he_cnn.linear as linear
|
||||
|
||||
from pyOpenFHE import CKKS as pal
|
||||
|
||||
TIMING_DICT = defaultdict(list)
|
||||
|
||||
DUPLICATED, IMAGE_SHARDED, CHANNEL_SHARDED = range(3)
|
||||
|
||||
|
||||
def reset_timing_dict():
|
||||
global TIMING_DICT
|
||||
TIMING_DICT.clear()
|
||||
|
||||
|
||||
# an image is a PyTorch 3-tensor
|
||||
def create_cnn_context(image, cc, publicKey, verbose=False):
|
||||
# create these to encrypt
|
||||
shard_size = cc.getBatchSize()
|
||||
|
||||
if len(image.shape) != 3:
|
||||
raise ValueError("Input image must be a PyTorch 3-tensor")
|
||||
|
||||
# do we want to address rectangular images at some point...?
|
||||
if image.shape[1] != image.shape[2]:
|
||||
raise ValueError("Non-square channels not currently supported")
|
||||
|
||||
if not utils.is_power_of_2(image.shape[0]):
|
||||
raise ValueError("Number of channels must be a power-of-two")
|
||||
|
||||
if not utils.is_power_of_2(image.shape[1]):
|
||||
raise ValueError("Image dimensions must be a power-of-two")
|
||||
|
||||
mtx_size = image.shape[1]
|
||||
num_channels = image.shape[0]
|
||||
total_size = mtx_size * mtx_size * num_channels
|
||||
num_shards = math.ceil(total_size / shard_size)
|
||||
|
||||
if total_size <= shard_size:
|
||||
duplication_factor = shard_size // total_size
|
||||
else:
|
||||
duplication_factor = 1
|
||||
|
||||
duplicated_image = np.repeat(image.numpy(), duplication_factor, axis=0).flatten()
|
||||
shards = []
|
||||
for s in range(num_shards):
|
||||
shard = cc.encrypt(publicKey, duplicated_image[shard_size * s: shard_size * (s + 1)])
|
||||
shards.append(shard)
|
||||
|
||||
# return the cc and keys as well for decryption at the end
|
||||
cnn_context = CNNContext(shards, mtx_size, num_channels, permutation=None, verbose=verbose)
|
||||
|
||||
return cnn_context
|
||||
|
||||
|
||||
def timing_decorator_factory(prefix=""):
|
||||
def timing_decorator(func):
|
||||
def wrapper_function(*args, **kwargs):
|
||||
global TIMING_DICT
|
||||
start = time()
|
||||
res = func(*args, **kwargs)
|
||||
layer_time = time() - start
|
||||
self = args[0]
|
||||
TIMING_DICT[prefix.strip()].append(layer_time)
|
||||
if self.verbose:
|
||||
print(prefix + f"Layer took {layer_time:.02f} seconds")
|
||||
return res
|
||||
|
||||
return wrapper_function
|
||||
|
||||
return timing_decorator
|
||||
|
||||
|
||||
class CNNContext:
|
||||
r"""This class contains methods for applying network layers to an image."""
|
||||
|
||||
def __init__(self, shards, mtx_size, num_channels, permutation=None, verbose=False):
|
||||
r"""Initializes the CNNContext object. We only needs shards and channel/matrix size to compute all other metadata."""
|
||||
|
||||
if permutation is None:
|
||||
permutation = np.array(range(num_channels))
|
||||
|
||||
self.shards = shards
|
||||
self.mtx_size = mtx_size
|
||||
self.num_channels = num_channels
|
||||
self.permutation = permutation
|
||||
self.verbose = verbose
|
||||
|
||||
self.compute_metadata()
|
||||
|
||||
def compute_metadata(self):
|
||||
# Shard information
|
||||
self.num_shards = len(self.shards)
|
||||
self.shard_size = self.shards[0].getBatchSize()
|
||||
self.total_size = self.num_shards * self.shard_size
|
||||
|
||||
# Channel information
|
||||
self.channel_size = self.mtx_size * self.mtx_size
|
||||
|
||||
# Duplication factor
|
||||
self.duplication_factor = (self.total_size // self.channel_size) // self.num_channels
|
||||
|
||||
if self.duplication_factor > 1:
|
||||
self.shard_type = DUPLICATED
|
||||
elif self.channel_size <= self.shard_size:
|
||||
self.shard_type = IMAGE_SHARDED
|
||||
else:
|
||||
self.shard_type = CHANNEL_SHARDED
|
||||
|
||||
# Channel and shard info
|
||||
self.num_phys_chan_per_shard = self.shard_size // self.channel_size
|
||||
self.num_phys_chan_total = self.num_shards * self.num_phys_chan_per_shard
|
||||
self.num_log_chan_per_shard = self.num_phys_chan_per_shard // self.duplication_factor
|
||||
self.num_log_chan_total = self.num_shards * self.num_log_chan_per_shard
|
||||
|
||||
def print_metadata(self):
|
||||
# shard information
|
||||
print(f"num_shards: {self.num_shards}")
|
||||
print(f"shard_size: {self.shard_size}")
|
||||
print(f"total_size: {self.total_size}")
|
||||
|
||||
# Channel information
|
||||
print(f"channel_size: {self.channel_size}")
|
||||
|
||||
# Duplication factor
|
||||
print(f"duplication_factor: {self.duplication_factor}")
|
||||
print(f"shard_type: {self.shard_type}")
|
||||
|
||||
# Channel and shard info
|
||||
print(f"num_phys_chan_per_shard: {self.num_phys_chan_per_shard}")
|
||||
print(f"num_phys_chan_total: {self.num_phys_chan_total}")
|
||||
print(f"num_log_chan_per_shard: {self.num_log_chan_per_shard}")
|
||||
print(f"num_log_chan_total: {self.num_log_chan_total}")
|
||||
|
||||
def decrypt_to_tensor(self, cc, keys):
|
||||
# decrypt the shards
|
||||
decrypted_shards = [cc.decrypt(keys.secretKey, shard) for shard in self.shards]
|
||||
decrypted_output = np.concatenate(decrypted_shards)
|
||||
|
||||
# reshape with possible duplication
|
||||
duplicated_output = decrypted_output.reshape(
|
||||
self.num_channels * self.duplication_factor,
|
||||
self.mtx_size,
|
||||
self.mtx_size
|
||||
)
|
||||
|
||||
decrypted_deduplicated_output = duplicated_output[0 :: self.duplication_factor]
|
||||
|
||||
return torch.from_numpy(decrypted_deduplicated_output)
|
||||
|
||||
@timing_decorator_factory("Conv ")
|
||||
def apply_conv(self, conv_layer, bn_layer=None, output_permutation=None, drop_levels=False):
|
||||
pal.CNN.omp_set_nested(0)
|
||||
# pal.CNN.omp_set_dynamic(0)
|
||||
|
||||
# Get filters, biases
|
||||
filters, biases = utils.get_filters_and_biases_from_conv2d(conv_layer)
|
||||
|
||||
# Get batch norm info if one is passed in
|
||||
if bn_layer:
|
||||
scale, shift = utils.get_scale_and_shift_from_bn(bn_layer)
|
||||
else:
|
||||
scale = None
|
||||
shift = None
|
||||
|
||||
num_out_channels = filters.shape[1]
|
||||
if output_permutation is None:
|
||||
output_permutation = np.array(range(num_out_channels))
|
||||
elif len(output_permutation) != num_out_channels:
|
||||
raise ValueError("output permutation is incorrect length")
|
||||
|
||||
# TODO this should be a Compress() call
|
||||
if drop_levels:
|
||||
L = self.shards[0].getTowersRemaining() - 4
|
||||
for j in range(self.num_shards):
|
||||
for i in range(L):
|
||||
self.shards[j] *= 1.0
|
||||
|
||||
# Apply conv
|
||||
new_shards = conv.conv2d(
|
||||
ciphertext_shards=self.shards,
|
||||
filters=filters,
|
||||
mtx_size=self.mtx_size,
|
||||
biases=biases,
|
||||
permutation=self.permutation,
|
||||
bn_scale=scale,
|
||||
bn_shift=shift,
|
||||
output_permutation=output_permutation
|
||||
)
|
||||
|
||||
# Create new CNN Context
|
||||
stride = conv_layer.stride
|
||||
cnn_context = CNNContext(new_shards, self.mtx_size, num_out_channels, output_permutation, self.verbose)
|
||||
if stride == (1, 1):
|
||||
return cnn_context
|
||||
elif stride == (2, 2):
|
||||
return cnn_context.apply_pool(conv=False)
|
||||
else:
|
||||
raise ValueError("Unsupported stride: {stride}")
|
||||
|
||||
@timing_decorator_factory("Pool ")
|
||||
def apply_pool(self, conv=True):
|
||||
pal.CNN.omp_set_nested(0)
|
||||
# pal.CNN.omp_set_dynamic(0)
|
||||
|
||||
# Apply pool
|
||||
new_shards = pool.pool(self.shards, self.mtx_size, conv)
|
||||
|
||||
# Get permutation
|
||||
new_permutation = pool.get_pool_permutation(self.shards, self.num_channels, self.mtx_size)
|
||||
new_permutation = pool.compose_permutations(self.permutation, new_permutation)
|
||||
new_permutation = np.array(new_permutation)
|
||||
|
||||
# Create new CNN Context
|
||||
return CNNContext(new_shards, self.mtx_size // 2, self.num_channels, new_permutation, self.verbose)
|
||||
|
||||
@timing_decorator_factory("Fused adaptive pool and linear ")
|
||||
def apply_fused_pool_linear(self, linear_layer):
|
||||
has_bias = hasattr(linear_layer, "bias")
|
||||
return self.apply_linear(linear_layer, has_bias, pool_factor=self.mtx_size)
|
||||
|
||||
@timing_decorator_factory("Bottleneck block ")
|
||||
def apply_bottleneck(self, bottleneck_block, debug=False, gelu_params={}, bootstrap_params={}, bootstrap=True):
|
||||
# Bottleneck block's forward pass is here: https://pytorch.org/vision/0.8/_modules/torchvision/models/resnet.html
|
||||
|
||||
skip_connection = self
|
||||
downsample_block = bottleneck_block.downsample
|
||||
if downsample_block:
|
||||
conv_downsample_layer = downsample_block[0]
|
||||
bn_downsample_layer = downsample_block[1]
|
||||
skip_connection = skip_connection.apply_conv(conv_downsample_layer, bn_downsample_layer)
|
||||
|
||||
conv1_layer = bottleneck_block.conv1
|
||||
bn1_layer = bottleneck_block.bn1
|
||||
cnn_context = self.apply_conv(conv1_layer, bn1_layer)
|
||||
|
||||
if not debug:
|
||||
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
|
||||
cnn_context = cnn_context.apply_gelu(**gelu_params)
|
||||
|
||||
conv2_layer = bottleneck_block.conv2
|
||||
bn2_layer = bottleneck_block.bn2
|
||||
cnn_context = cnn_context.apply_conv(conv2_layer, bn2_layer)
|
||||
|
||||
if not debug:
|
||||
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
|
||||
cnn_context = cnn_context.apply_gelu(**gelu_params)
|
||||
|
||||
conv3_layer = bottleneck_block.conv3
|
||||
bn3_layer = bottleneck_block.bn3
|
||||
cnn_context = cnn_context.apply_conv(conv3_layer, bn3_layer, output_permutation=skip_connection.permutation)
|
||||
|
||||
cnn_context = cnn_context.apply_residual(skip_connection)
|
||||
|
||||
if not debug:
|
||||
if bootstrap: cnn_context = cnn_context.apply_bootstrapping(**bootstrap_params)
|
||||
cnn_context = cnn_context.apply_gelu(**gelu_params)
|
||||
|
||||
return cnn_context
|
||||
|
||||
# This operation doesn't return a CNNContext, that's returned by linear
|
||||
@timing_decorator_factory("Linear ")
|
||||
def apply_linear(self, linear_layer, bias=True, scale=1.0, pool_factor=1):
|
||||
pal.CNN.omp_set_nested(0)
|
||||
pal.CNN.omp_set_dynamic(1)
|
||||
|
||||
linear_weights, linear_biases = utils.get_weights_and_biases_from_linear(linear_layer,
|
||||
self.mtx_size,
|
||||
bias,
|
||||
pool_factor)
|
||||
final_shard = linear.linear(self.shards, linear_weights, linear_biases, self.mtx_size, self.permutation, scale,
|
||||
pool_factor)
|
||||
|
||||
return final_shard
|
||||
|
||||
@timing_decorator_factory("Square ")
|
||||
def apply_square(self):
|
||||
new_shards = [shard * shard for shard in self.shards]
|
||||
|
||||
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
|
||||
|
||||
@timing_decorator_factory("GELU ")
|
||||
def apply_gelu(self, bound=10.0, degree=59):
|
||||
"""
|
||||
bound:
|
||||
bound = an upper bound on the absolute value of the inputs.
|
||||
the polynomial approximation is valid for [-bound, bound]
|
||||
degree:
|
||||
degree of Chebyshev polynomial
|
||||
"""
|
||||
if self.num_shards < 8:
|
||||
pal.CNN.omp_set_nested(1)
|
||||
pal.CNN.omp_set_dynamic(1)
|
||||
else:
|
||||
pal.CNN.omp_set_nested(0)
|
||||
pal.CNN.omp_set_dynamic(1)
|
||||
|
||||
# TODO this can be absorbed into the BN
|
||||
new_shards = [x * (1 / bound) for x in self.shards]
|
||||
new_shards = pal.CNN.fhe_gelu(new_shards, degree, bound)
|
||||
|
||||
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
|
||||
|
||||
@timing_decorator_factory("Bootstrapping ")
|
||||
def apply_bootstrapping(self, meta=False):
|
||||
if self.num_shards < 8:
|
||||
pal.CNN.omp_set_nested(1)
|
||||
pal.CNN.omp_set_dynamic(1)
|
||||
else:
|
||||
pal.CNN.omp_set_nested(0)
|
||||
pal.CNN.omp_set_dynamic(1)
|
||||
|
||||
cc = self.shards[0].getCryptoContext()
|
||||
if meta:
|
||||
new_shards = cc.evalMetaBootstrap(self.shards)
|
||||
else:
|
||||
new_shards = cc.evalBootstrap(self.shards)
|
||||
|
||||
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
|
||||
|
||||
@timing_decorator_factory("Residual ")
|
||||
def apply_residual(self, C2):
|
||||
if len(self.permutation) != len(C2.permutation):
|
||||
raise ValueError("Incompatible number of channels")
|
||||
if self.mtx_size != C2.mtx_size:
|
||||
raise ValueError("Incompatible matrix size")
|
||||
if any([i != j for i, j in zip(self.permutation, C2.permutation)]):
|
||||
raise ValueError("Incompatible permutations")
|
||||
|
||||
new_shards = [i + j for i, j in zip(self.shards, C2.shards)]
|
||||
return CNNContext(new_shards, self.mtx_size, self.num_channels, self.permutation, self.verbose)
|
||||
166
palisade_he_cnn/src/he_cnn/activations.py
Normal file
166
palisade_he_cnn/src/he_cnn/activations.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import math
|
||||
from copy import copy
|
||||
|
||||
SIN_COEFFS = [
|
||||
0,
|
||||
9.99984594193494365437e-01,
|
||||
0,
|
||||
-1.66632595072086745320e-01,
|
||||
0,
|
||||
8.31238887417884598346e-03,
|
||||
0,
|
||||
-1.93162796407356830500e-04,
|
||||
0,
|
||||
2.17326217498596729611e-06,
|
||||
]
|
||||
COS_COEFFS = [
|
||||
9.99971094606182687341e-01,
|
||||
0,
|
||||
-4.99837602272995734437e-01,
|
||||
0,
|
||||
4.15223086250910767516e-02,
|
||||
0,
|
||||
-1.34410769349285321733e-03,
|
||||
0,
|
||||
1.90652668840074246305e-05,
|
||||
0,
|
||||
]
|
||||
# you technically don't need the . to specify float division in python3
|
||||
LOG_COEFFS = [
|
||||
0,
|
||||
1,
|
||||
-0.5,
|
||||
1.0 / 3,
|
||||
-1.0 / 4,
|
||||
1.0 / 5,
|
||||
-1.0 / 6,
|
||||
1.0 / 7,
|
||||
-1.0 / 8,
|
||||
1.0 / 9,
|
||||
-1.0 / 10,
|
||||
]
|
||||
EXP_COEFFS = [
|
||||
1,
|
||||
1,
|
||||
0.5,
|
||||
1.0 / 6,
|
||||
1.0 / 24,
|
||||
1.0 / 120,
|
||||
1.0 / 720,
|
||||
1.0 / 5040,
|
||||
1.0 / 40320,
|
||||
1.0 / 362880,
|
||||
1.0 / 3628800,
|
||||
]
|
||||
SIGMOID_COEFFS = [
|
||||
1.0 / 2,
|
||||
1.0 / 4,
|
||||
0,
|
||||
-1.0 / 48,
|
||||
0,
|
||||
1.0 / 480,
|
||||
0,
|
||||
-17.0 / 80640,
|
||||
0,
|
||||
31.0 / 1451520,
|
||||
0,
|
||||
]
|
||||
|
||||
|
||||
def powerOf2Extended(cipher, logDegree):
|
||||
res = [copy(cipher)]
|
||||
for i in range(logDegree):
|
||||
t = res[-1]
|
||||
res.append(t * t)
|
||||
return res
|
||||
|
||||
|
||||
def powerExtended(cipher, degree):
|
||||
res = []
|
||||
logDegree = int(
|
||||
math.log2(degree)
|
||||
) # both python and C++ truncate when casting float->int
|
||||
cpows = powerOf2Extended(cipher, logDegree)
|
||||
|
||||
idx = 0
|
||||
for i in range(logDegree):
|
||||
powi = pow(2, i)
|
||||
res.append(cpows[i])
|
||||
|
||||
for j in range(powi - 1):
|
||||
res.append(copy(res[j]))
|
||||
res[-1] *= cpows[i]
|
||||
|
||||
res.append(cpows[logDegree])
|
||||
|
||||
degree2 = pow(2, logDegree)
|
||||
|
||||
for i in range(degree - degree2):
|
||||
res.append(copy(res[i]))
|
||||
res[-1] *= cpows[logDegree]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def polynomial_series_function(cipher, coeffs, verbose=False):
|
||||
"""
|
||||
Cipher is a CKKSCiphertext, coeffs should be array-like (generally either native list or numpy array)
|
||||
"""
|
||||
degree = len(coeffs)
|
||||
|
||||
if verbose:
|
||||
print("initial ciphertext level = {}".format(cipher.getTowersRemaining()))
|
||||
|
||||
cpows = powerExtended(cipher, degree) # array of ciphertexts
|
||||
|
||||
# cpows[0] == cipher, i.e. x^1
|
||||
res = cpows[0] * coeffs[1] # this should be defined
|
||||
res += coeffs[0]
|
||||
|
||||
for i in range(2, degree):
|
||||
coeff = coeffs[i]
|
||||
if abs(coeff) > 1e-27:
|
||||
aixi = cpows[i - 1] * coeff
|
||||
res += aixi
|
||||
|
||||
if verbose:
|
||||
print("final ciphertext level = {}".format(res.getTowersRemaining()))
|
||||
|
||||
return res
|
||||
|
||||
|
||||
"""
|
||||
example:
|
||||
|
||||
to approximate the sine function, do:
|
||||
polynomial_series_function(c1, SIN_COEFFS)
|
||||
"""
|
||||
|
||||
|
||||
def sqrt_helper(cipher, steps):
|
||||
a = copy(cipher)
|
||||
b = a - 1
|
||||
|
||||
for i in range(steps):
|
||||
a *= 1 - (0.5 * b)
|
||||
|
||||
# there must be a better way to do this...
|
||||
if i < steps - 1:
|
||||
b = (b * b) * (0.25 * (b - 3))
|
||||
|
||||
return a
|
||||
|
||||
|
||||
def sqrt(cipher, steps, upper_bound):
|
||||
if upper_bound == 1:
|
||||
return sqrt_helper(cipher, steps)
|
||||
return sqrt_helper(cipher * (1 / upper_bound), steps) * math.sqrt(upper_bound)
|
||||
|
||||
|
||||
def relu(cipher, steps, upper_bound):
|
||||
x = cipher * cipher
|
||||
|
||||
res = cipher + sqrt(x, steps, upper_bound)
|
||||
return 0.5 * res
|
||||
56
palisade_he_cnn/src/he_cnn/conv.py
Normal file
56
palisade_he_cnn/src/he_cnn/conv.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
from pyOpenFHE import CKKS as pal
|
||||
|
||||
conv2d_cpp = pal.CNN.conv2d
|
||||
|
||||
|
||||
def conv2d(ciphertext_shards, filters, mtx_size, biases, permutation=None, bn_scale=None, bn_shift=None,
|
||||
output_permutation=None):
|
||||
# if we're combining with a batch norm, fold the batch norm scale factor into the filters
|
||||
# with sharded convolutions, filters are not duplicated or permuted in any way.
|
||||
scaled_filters = filters
|
||||
if bn_scale is not None and bn_shift is not None:
|
||||
scaled_filters = filters * bn_scale.reshape(1, -1, 1, 1)
|
||||
|
||||
# if we're combining with a batch norm, fold the batch norm shift factor into the biases
|
||||
shifted_biases = biases
|
||||
if bn_scale is not None and bn_shift is not None:
|
||||
shifted_biases = biases * bn_scale + bn_shift
|
||||
|
||||
# all of this should happen somewhere in the CNNContext class
|
||||
shard_size = ciphertext_shards[0].getBatchSize()
|
||||
num_out_channels = filters.shape[1]
|
||||
channel_size = mtx_size * mtx_size
|
||||
if channel_size < shard_size:
|
||||
channels_per_shard = shard_size // (mtx_size * mtx_size)
|
||||
output_dup_factor = math.ceil(channels_per_shard / num_out_channels)
|
||||
else:
|
||||
output_dup_factor = 1
|
||||
|
||||
num_in_channels = filters.shape[0]
|
||||
if permutation is None:
|
||||
permutation = np.array(range(num_in_channels))
|
||||
|
||||
if output_permutation is None:
|
||||
output_permutation = np.array(range(num_out_channels))
|
||||
|
||||
if len(permutation) != num_in_channels:
|
||||
raise ValueError("incorrect number of input channels")
|
||||
|
||||
if len(output_permutation) != num_out_channels:
|
||||
raise ValueError("incorrect number of output channels")
|
||||
|
||||
scaled_filters = scaled_filters[:, output_permutation, :, :]
|
||||
shifted_biases = shifted_biases[output_permutation]
|
||||
|
||||
# compute the convolution
|
||||
conv_shards = conv2d_cpp(ciphertext_shards, scaled_filters, mtx_size, permutation)
|
||||
|
||||
repeated_shifted_biases = np.repeat(shifted_biases, mtx_size * mtx_size * output_dup_factor)
|
||||
for s in range(len(conv_shards)):
|
||||
conv_shards[s] += repeated_shifted_biases[s * shard_size: (s + 1) * shard_size]
|
||||
|
||||
return conv_shards
|
||||
28
palisade_he_cnn/src/he_cnn/linear.py
Normal file
28
palisade_he_cnn/src/he_cnn/linear.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import numpy as np
|
||||
from pyOpenFHE import CKKS as pal_ckks
|
||||
|
||||
linear_cpp = pal_ckks.CNN.linear
|
||||
|
||||
|
||||
def linear(channel_shards, weights, biases, mtx_size, permutation=None, scale=1.0, pool_factor=1):
|
||||
shard_size = channel_shards[0].getBatchSize()
|
||||
num_shards = len(channel_shards)
|
||||
num_inputs = weights.shape[1]
|
||||
channel_size = mtx_size * mtx_size
|
||||
duplication_factor = max(shard_size // num_inputs, 1)
|
||||
num_physical_channels_per_shard = shard_size // channel_size
|
||||
num_physical_channels = num_physical_channels_per_shard * num_shards
|
||||
num_logical_channels = num_physical_channels // duplication_factor
|
||||
|
||||
if permutation is None:
|
||||
permutation = np.array(range(num_logical_channels))
|
||||
|
||||
output = linear_cpp(channel_shards, weights * scale, mtx_size, permutation, pool_factor)
|
||||
|
||||
# FO: if np.all(biases==0), then we do not need to compute biases*scale
|
||||
num_out_activs = biases.shape[0]
|
||||
output += np.pad(biases * scale, [(0, shard_size - num_out_activs)])
|
||||
|
||||
return output
|
||||
72
palisade_he_cnn/src/he_cnn/pool.py
Normal file
72
palisade_he_cnn/src/he_cnn/pool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
from .utils import *
|
||||
import math
|
||||
from pyOpenFHE import CKKS as pal
|
||||
|
||||
pool = pal.CNN.pool
|
||||
|
||||
|
||||
def divide_chunks(l, n):
|
||||
# looping till length l
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
def interleave_lists(lists):
|
||||
return [val for tup in zip(*lists) for val in tup]
|
||||
|
||||
|
||||
def invert_permutation(P):
|
||||
inverse_permutation = [0] * len(P)
|
||||
for i, v in enumerate(P):
|
||||
inverse_permutation[v] = i
|
||||
return inverse_permutation
|
||||
|
||||
|
||||
def compose_permutations(P1, P2):
|
||||
if len(P1) != len(P2):
|
||||
raise ValueError("permutations must have equal size")
|
||||
permutation = [P1[P2[i]] for i in range(len(P1))]
|
||||
return permutation
|
||||
|
||||
|
||||
"""
|
||||
metadata includes:
|
||||
- the new channel permutation
|
||||
- the duplication factor
|
||||
- the new number of shards
|
||||
"""
|
||||
|
||||
|
||||
def get_pool_permutation(shards, num_channels, mtx_size):
|
||||
initial_num_shards = len(shards)
|
||||
shard_size = shards[0].getBatchSize()
|
||||
channel_size = mtx_size * mtx_size
|
||||
initial_num_physical_channels_per_shard = math.ceil(shard_size / channel_size)
|
||||
num_physical_channels = initial_num_shards * initial_num_physical_channels_per_shard
|
||||
initial_dup_factor = math.ceil(num_physical_channels / num_channels)
|
||||
|
||||
# if we have channel sharding, then no permutation
|
||||
if channel_size >= shard_size:
|
||||
C = num_channels
|
||||
P = list(range(C))
|
||||
return P
|
||||
|
||||
if (initial_dup_factor > 1) and (initial_num_shards > 1):
|
||||
raise ValueError("Should not have both duplication and shards at the same time")
|
||||
|
||||
# if we have duplication, then no permutation
|
||||
if initial_dup_factor > 1:
|
||||
C = initial_num_physical_channels_per_shard // initial_dup_factor
|
||||
P = list(range(C))
|
||||
return P
|
||||
|
||||
C = initial_num_physical_channels_per_shard * initial_num_shards
|
||||
I = list(range(C))
|
||||
I = list(divide_chunks(I, initial_num_physical_channels_per_shard))
|
||||
I = list(divide_chunks(I, 4))
|
||||
P = [interleave_lists(J) for J in I]
|
||||
P = sum(P, start=[])
|
||||
|
||||
return np.array(P)
|
||||
105
palisade_he_cnn/src/he_cnn/upsample.py
Normal file
105
palisade_he_cnn/src/he_cnn/upsample.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from pyOpenFHE import CKKS as pal
|
||||
|
||||
upsample_cpp = pal.CNN.upsample
|
||||
|
||||
|
||||
def divide_chunks(l, n):
|
||||
# looping till length l
|
||||
for i in range(0, len(l), n):
|
||||
yield l[i:i + n]
|
||||
|
||||
|
||||
def interleave_lists(lists):
|
||||
return [val for tup in zip(*lists) for val in tup]
|
||||
|
||||
|
||||
def invert_permutation(P):
|
||||
inverse_permutation = [0] * len(P)
|
||||
for i, v in enumerate(P):
|
||||
inverse_permutation[v] = i
|
||||
return inverse_permutation
|
||||
|
||||
|
||||
def compose_permutations(P1, P2):
|
||||
if len(P1) != len(P2):
|
||||
raise ValueError("permutations must have equal size")
|
||||
permutation = [P1[P2[i]] for i in range(len(P1))]
|
||||
return permutation
|
||||
|
||||
|
||||
"""
|
||||
metadata includes:
|
||||
- the new channel permutation
|
||||
- the duplication factor
|
||||
- the new number of shards
|
||||
"""
|
||||
|
||||
|
||||
def get_upsample_permutation(shards, num_channels, mtx_size):
|
||||
initial_num_shards = len(shards)
|
||||
shard_size = shards[0].getBatchSize()
|
||||
channel_size = mtx_size * mtx_size
|
||||
initial_num_physical_channels_per_shard = math.ceil(shard_size / channel_size)
|
||||
final_num_physical_channels_per_shard = math.ceil(shard_size / channel_size / 4)
|
||||
num_physical_channels = initial_num_shards * initial_num_physical_channels_per_shard
|
||||
initial_dup_factor = math.ceil(num_physical_channels / num_channels)
|
||||
|
||||
# if we start with channel sharding, then no permutation
|
||||
if channel_size >= shard_size:
|
||||
P = list(range(num_channels))
|
||||
return P
|
||||
|
||||
if (initial_dup_factor > 1) and (initial_num_shards > 1):
|
||||
raise ValueError("Should not have both duplication and shards at the same time")
|
||||
|
||||
# if we have duplication factor >= 4, then no permutation
|
||||
if initial_dup_factor > 2:
|
||||
P = list(range(num_channels))
|
||||
return P
|
||||
|
||||
# if we have two-fold duplication
|
||||
if initial_dup_factor == 2:
|
||||
P = list(range(num_channels))
|
||||
if num_channels == 1: return P
|
||||
P = P[::2] + P[1::2]
|
||||
return P
|
||||
|
||||
I = list(range(num_channels))
|
||||
I = list(divide_chunks(I, initial_num_physical_channels_per_shard))
|
||||
I = [list(divide_chunks(J, 4)) for J in I]
|
||||
P = [interleave_lists(J) for J in I]
|
||||
P = sum(P, start=[])
|
||||
|
||||
return np.array(P)
|
||||
|
||||
|
||||
"""
|
||||
This takes a permuted list of ciphertexts stored using channel sharding,
|
||||
and it reorders them into the identity permutation.
|
||||
|
||||
mtx_size and permutation refer to the values after upsampling, not of the input shards
|
||||
"""
|
||||
|
||||
|
||||
def undo_channel_sharding_permutation(shards, num_channels, mtx_size, permutation):
|
||||
num_shards = len(shards)
|
||||
shard_size = shards[0].getBatchSize()
|
||||
channel_size = mtx_size * mtx_size
|
||||
|
||||
if shard_size > channel_size:
|
||||
raise ValueError("This function should only be called on a channel sharded image")
|
||||
|
||||
num_shards_per_channel = channel_size // shard_size
|
||||
|
||||
final_shards = [None for _ in range(num_shards)]
|
||||
for i, x in enumerate(shards):
|
||||
channel_idx = i // num_shards_per_channel
|
||||
subshard_idx = i % num_shards_per_channel
|
||||
correct_idx = permutation[channel_idx] * num_shards_per_channel + subshard_idx
|
||||
final_shards[correct_idx] = x
|
||||
|
||||
return final_shards
|
||||
228
palisade_he_cnn/src/he_cnn/utils.py
Normal file
228
palisade_he_cnn/src/he_cnn/utils.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyOpenFHE as pal
|
||||
from pyOpenFHE import CKKS as pal_ckks
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
serial = pal_ckks.serial
|
||||
|
||||
|
||||
def is_power_of_2(x):
|
||||
return x > 0 and x & (x - 1) == 0
|
||||
|
||||
|
||||
def next_power_of_2(n):
|
||||
p = 1
|
||||
if n and not (n & (n - 1)):
|
||||
return n
|
||||
while p < n:
|
||||
p <<= 1
|
||||
return p
|
||||
|
||||
|
||||
def load_cc_and_keys(batch_size, mult_depth=10, scale_factor_bits=40, bootstrapping=False):
|
||||
f = "{}-{}-{}-{}".format(batch_size, mult_depth, scale_factor_bits, int(bootstrapping))
|
||||
path = Path("serialized") / f
|
||||
|
||||
P = (path / "PublicKey.bin").as_posix()
|
||||
publicKey = pal_ckks.serial.DeserializeFromFile_PublicKey(P, pal_ckks.serial.SerType.BINARY)
|
||||
P = (path / "PrivateKey.bin").as_posix()
|
||||
secretKey = pal_ckks.serial.DeserializeFromFile_PrivateKey(P, pal_ckks.serial.SerType.BINARY)
|
||||
|
||||
keys = pal_ckks.KeyPair(publicKey, secretKey)
|
||||
cc = publicKey.getCryptoContext()
|
||||
|
||||
P = (path / "EvalMultKey.bin").as_posix()
|
||||
pal_ckks.serial.DeserializeFromFile_EvalMultKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
|
||||
P = (path / "EvalAutomorphismKey.bin").as_posix()
|
||||
pal_ckks.serial.DeserializeFromFile_EvalAutomorphismKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
|
||||
|
||||
if bootstrapping:
|
||||
cc.evalBootstrapSetup()
|
||||
|
||||
return cc, keys
|
||||
|
||||
|
||||
def save_cc_and_keys(cc, keys, path):
|
||||
P = (path / "PublicKey.bin").as_posix()
|
||||
assert pal_ckks.serial.SerializeToFile(P, keys.publicKey, pal_ckks.serial.SerType.BINARY)
|
||||
P = (path / "PrivateKey.bin").as_posix()
|
||||
assert pal_ckks.serial.SerializeToFile(P, keys.secretKey, pal_ckks.serial.SerType.BINARY)
|
||||
P = (path / "EvalMultKey.bin").as_posix()
|
||||
assert pal_ckks.serial.SerializeToFile_EvalMultKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
|
||||
P = (path / "EvalAutomorphismKey.bin").as_posix()
|
||||
assert pal_ckks.serial.SerializeToFile_EvalAutomorphismKey_CryptoContext(cc, P, pal_ckks.serial.SerType.BINARY)
|
||||
|
||||
|
||||
def create_cc_and_keys(batch_size, mult_depth=10, scale_factor_bits=40, bootstrapping=False, save=False):
|
||||
# We make use of palisade HE by creating a crypto context object
|
||||
# this specifies things like multiplicative depth
|
||||
cc = pal_ckks.genCryptoContextCKKS(
|
||||
mult_depth, # number of multiplications you can perform
|
||||
scale_factor_bits, # kindof like number of bits of precision
|
||||
batch_size, # length of your vector, can be any power-of-2 up to 2^14
|
||||
)
|
||||
|
||||
print(f"CKKS scheme is using ring dimension = {cc.getRingDimension()}, batch size = {cc.getBatchSize()}")
|
||||
|
||||
cc.enable(pal.enums.PKESchemeFeature.PKE)
|
||||
cc.enable(pal.enums.PKESchemeFeature.KEYSWITCH)
|
||||
cc.enable(pal.enums.PKESchemeFeature.LEVELEDSHE)
|
||||
cc.enable(pal.enums.PKESchemeFeature.ADVANCEDSHE)
|
||||
cc.enable(pal.enums.PKESchemeFeature.FHE)
|
||||
|
||||
# generate keys
|
||||
keys = cc.keyGen()
|
||||
cc.evalMultKeyGen(keys.secretKey)
|
||||
cc.evalPowerOf2RotationKeyGen(keys.secretKey)
|
||||
|
||||
if bootstrapping:
|
||||
cc.evalBootstrapSetup()
|
||||
cc.evalBootstrapKeyGen(keys.secretKey)
|
||||
|
||||
if save:
|
||||
f = "{}-{}-{}-{}".format(batch_size, mult_depth, scale_factor_bits, int(bootstrapping))
|
||||
path = Path("serialized") / f
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
save_cc_and_keys(cc, keys, path)
|
||||
|
||||
return cc, keys
|
||||
|
||||
|
||||
def get_keys(mult_depth,
|
||||
scale_factor_bits,
|
||||
batch_size,
|
||||
bootstrapping):
|
||||
try:
|
||||
cc, keys = load_cc_and_keys(batch_size,
|
||||
mult_depth=mult_depth,
|
||||
scale_factor_bits=scale_factor_bits,
|
||||
bootstrapping=bootstrapping)
|
||||
except:
|
||||
cc, keys = create_cc_and_keys(batch_size,
|
||||
mult_depth=mult_depth,
|
||||
scale_factor_bits=scale_factor_bits,
|
||||
bootstrapping=bootstrapping,
|
||||
save=True)
|
||||
return cc, keys
|
||||
|
||||
|
||||
def get_filters_and_biases_from_conv2d(layer):
|
||||
filters = layer.weight.detach().numpy()
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
biases = layer.bias.detach().numpy()
|
||||
else:
|
||||
# without bias
|
||||
# same as number of output channels (each bias is broadcast over the channel)
|
||||
biases = np.zeros((filters.shape[0],))
|
||||
|
||||
filters = filters.transpose(1, 0, 2, 3)
|
||||
pad_to = next_power_of_2(filters.shape[0])
|
||||
|
||||
if pad_to is not None:
|
||||
if filters.shape[0] < pad_to:
|
||||
filters = np.concatenate(
|
||||
[filters, np.zeros((pad_to - filters.shape[0],) + filters.shape[1:])]
|
||||
)
|
||||
|
||||
return filters, biases
|
||||
|
||||
|
||||
def get_scale_and_shift_from_bn(layer):
|
||||
mu = layer.running_mean.detach().numpy()
|
||||
var = layer.running_var.detach().numpy()
|
||||
gamma = (
|
||||
layer.weight.detach().numpy()
|
||||
) # https://discuss.pytorch.org/t/getting-parameters-of-torch-nn-batchnorm2d-during-training/38913/3
|
||||
beta = layer.bias.detach().numpy()
|
||||
eps = layer.eps
|
||||
|
||||
sigma = np.sqrt(var + eps) # std dev
|
||||
|
||||
# compute scale factor
|
||||
scale = gamma / sigma
|
||||
|
||||
# compute shift factor
|
||||
shift = -gamma * mu / sigma + beta
|
||||
|
||||
return scale, shift
|
||||
|
||||
|
||||
# needs to know either number of channels or matrix size
|
||||
def get_weights_and_biases_from_linear(layer, mtx_size, bias, pool_factor=1):
|
||||
nout = layer.weight.size(0)
|
||||
weights = layer.weight.detach().numpy()
|
||||
num_channels = weights.shape[1] // (mtx_size * mtx_size)
|
||||
weights = weights.reshape(nout, num_channels, mtx_size, mtx_size)
|
||||
weights = weights.reshape(nout, -1)
|
||||
weights = np.repeat(weights, pool_factor * pool_factor, axis=1)
|
||||
|
||||
if bias:
|
||||
biases = layer.bias.detach().numpy()
|
||||
else:
|
||||
biases = np.zeros(nout)
|
||||
|
||||
return weights, biases
|
||||
|
||||
|
||||
# Given a model and an input, get intermediate layer output
|
||||
def get_intermediate_output(model, layer, inputs):
|
||||
layer_name = "layer"
|
||||
activation = {}
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
activation[name] = output.detach()
|
||||
|
||||
return hook
|
||||
|
||||
layer.register_forward_hook(
|
||||
get_activation(layer_name)
|
||||
)
|
||||
_ = model(inputs)
|
||||
return activation[layer_name]
|
||||
|
||||
|
||||
def compare_accuracy(keys, cnn_context, unencrypted, name="block", num_digits=4):
|
||||
A = decrypt_and_reshape(cnn_context, keys.secretKey, cnn_context.mtx_size)
|
||||
B = unencrypted.detach().cpu().numpy()[0]
|
||||
diff = np.abs(A - B[cnn_context.permutation])
|
||||
print(f"error in {name}:\nmax = {np.max(diff):.0{num_digits}f}\nmean = {np.mean(diff):.0{num_digits}f}")
|
||||
|
||||
|
||||
def decrypt_and_reshape(cnn_context, secret_key, mtx_size):
|
||||
cc = secret_key.getCryptoContext()
|
||||
decrypted_output = [cc.decrypt(secret_key, ctxt) for ctxt in cnn_context.shards]
|
||||
decrypted_output = np.hstack(decrypted_output)
|
||||
num_out_chan = int(round(len(decrypted_output) / (mtx_size * mtx_size)))
|
||||
decrypted_output = decrypted_output.reshape((num_out_chan, mtx_size, mtx_size))
|
||||
decrypted_output = decrypted_output[0:: cnn_context.duplication_factor]
|
||||
|
||||
return decrypted_output
|
||||
|
||||
|
||||
def serialize(cc, keys, ctxt):
|
||||
path = Path("serialized")
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
shutil.rmtree(path)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
assert serial.SerializeToFile("serialized/CryptoContext.bin", ctxt, serial.SerType.BINARY)
|
||||
assert serial.SerializeToFile("serialized/ciphertext.bin", ctxt, serial.SerType.BINARY)
|
||||
assert serial.SerializeToFile("serialized/PublicKey.bin", keys.publicKey, serial.SerType.BINARY)
|
||||
assert serial.SerializeToFile("serialized/PrivateKey.bin", keys.secretKey, serial.SerType.BINARY)
|
||||
assert serial.SerializeToFile_EvalMultKey_CryptoContext(cc, "serialized/EvalMultKey.bin", serial.SerType.BINARY)
|
||||
assert serial.SerializeToFile_EvalAutomorphismKey_CryptoContext(cc, "serialized/EvalAutomorphismKey.bin",
|
||||
serial.SerType.BINARY)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cc, keys = get_keys(mult_depth=34, scale_factor_bits=59, batch_size=32 * 32 * 32, bootstrapping=True)
|
||||
print(cc.getBatchSize())
|
||||
shard = cc.encrypt(keys.publicKey, [0.0 for _ in range(32768)])
|
||||
serialize(cc, keys, shard)
|
||||
265
palisade_he_cnn/src/small_model.py
Normal file
265
palisade_he_cnn/src/small_model.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torchvision import datasets
|
||||
import torchvision.transforms as transforms
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
class Square(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return torch.square(x)
|
||||
|
||||
|
||||
def moment(x: torch.Tensor, std: float, mean: float, deg: int = 4, eps: float = 1e-4) -> torch.Tensor:
|
||||
N = x.shape[0]
|
||||
return (1.0 / N) * torch.sum((x - mean) ** deg) / (std ** deg + eps)
|
||||
|
||||
|
||||
def activation_helper(activation: str = 'gelu',
|
||||
gelu_degree: int = 16):
|
||||
if activation == 'relu':
|
||||
return nn.ReLU()
|
||||
elif activation == 'gelu':
|
||||
return nn.GELU()
|
||||
elif activation == 'polygelu':
|
||||
raise ValueError("Not supported.")
|
||||
elif activation == 'square':
|
||||
return Square()
|
||||
else:
|
||||
return nn.ReLU()
|
||||
|
||||
|
||||
def conv_block(in_ch: int,
|
||||
out_ch: int,
|
||||
activation: str = 'relu',
|
||||
gelu_degree: int = 16,
|
||||
pool: bool = False,
|
||||
pool_method: str = 'avg',
|
||||
kernel: int = 3,
|
||||
stride: int = 1,
|
||||
padding: Union[int, str] = 1):
|
||||
layers = [nn.Conv2d(in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel,
|
||||
stride=stride,
|
||||
padding=padding),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
activation_helper(activation, gelu_degree)
|
||||
]
|
||||
if pool:
|
||||
layers.append(nn.MaxPool2d(2, 2) if pool_method == 'max' else nn.AvgPool2d(2, 2))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def get_small_model_dict(activation='gelu',
|
||||
gelu_degree: int = 16,
|
||||
pool_method: str = 'avg') -> nn.ModuleDict:
|
||||
classifier = nn.Sequential(nn.Flatten(),
|
||||
nn.Linear(8 * 8 * 128, 10))
|
||||
return nn.ModuleDict(
|
||||
{
|
||||
"conv1": conv_block(in_ch=1,
|
||||
out_ch=64,
|
||||
kernel=4,
|
||||
pool=True,
|
||||
pool_method=pool_method,
|
||||
padding='same',
|
||||
activation=activation,
|
||||
gelu_degree=gelu_degree),
|
||||
"conv2": conv_block(in_ch=64,
|
||||
out_ch=128,
|
||||
kernel=4,
|
||||
pool=True,
|
||||
pool_method=pool_method,
|
||||
padding='same',
|
||||
activation=activation,
|
||||
gelu_degree=gelu_degree),
|
||||
"conv3": conv_block(in_ch=128,
|
||||
out_ch=128,
|
||||
kernel=4,
|
||||
pool=False,
|
||||
padding='same',
|
||||
activation=activation,
|
||||
gelu_degree=gelu_degree),
|
||||
"classifier": classifier
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SmallModel(nn.Module):
|
||||
|
||||
def __init__(self, activation='gelu', gelu_degree: int = 16, pool_method: str = 'avg'):
|
||||
super(SmallModel, self).__init__()
|
||||
self.model_layers = get_small_model_dict(activation=activation, gelu_degree=gelu_degree,
|
||||
pool_method=pool_method)
|
||||
self.n_bn_classes = self.count_instances_of_a_class()
|
||||
|
||||
def count_instances_of_a_class(self, cls: nn.BatchNorm2d = nn.BatchNorm2d) -> int:
|
||||
n_classes = 0
|
||||
for _, block in self.model_layers.items():
|
||||
for layer in block:
|
||||
# Handle the nested case
|
||||
if isinstance(layer, nn.Sequential):
|
||||
for sublayer in layer:
|
||||
if isinstance(sublayer, cls):
|
||||
n_classes += 1
|
||||
# Handle the unnested case
|
||||
else:
|
||||
if isinstance(layer, cls):
|
||||
n_classes += 1
|
||||
return n_classes
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
self.bn_outputs = {} # key=layer name, v=list of torch.Tensors
|
||||
self.outputs = {}
|
||||
|
||||
for name, block in self.model_layers.items():
|
||||
block_output, block_bn_output = self.block_pass(block, x)
|
||||
self.bn_outputs[name] = block_bn_output
|
||||
|
||||
# Residual Connection
|
||||
if "res" in name:
|
||||
x = x + block_output
|
||||
# Normal
|
||||
else:
|
||||
x = block_output
|
||||
|
||||
return x
|
||||
|
||||
def block_pass(self, block: nn.Sequential, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
bn_output = []
|
||||
|
||||
# Iterate through a block, which may be nested (residual connections are nested)
|
||||
for layer in block:
|
||||
# Handle the nested case
|
||||
if isinstance(layer, nn.Sequential):
|
||||
for sublayer in layer:
|
||||
x = sublayer(x)
|
||||
if isinstance(sublayer, nn.BatchNorm2d):
|
||||
bn_output.append(x)
|
||||
|
||||
# Handlle the unnested case
|
||||
else:
|
||||
x = layer(x)
|
||||
if isinstance(layer, nn.BatchNorm2d):
|
||||
bn_output.append(x)
|
||||
|
||||
self.outputs[layer] = x
|
||||
|
||||
return x, bn_output
|
||||
|
||||
# Must be called after forward method to set self.bn_outputs
|
||||
def get_bn_loss_metrics(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
means, stds, skews, kurts = self.get_moments_by_layer()
|
||||
|
||||
# Aggregating
|
||||
loss_means = F.mse_loss(means, torch.zeros(self.n_bn_classes))
|
||||
loss_stds = F.mse_loss(stds, torch.ones(self.n_bn_classes))
|
||||
# loss_skews = F.mse_loss(skews, torch.zeros(self.n_bn_classes))
|
||||
loss_kurts = F.mse_loss(kurts, 3 * torch.ones(self.n_bn_classes))
|
||||
return loss_means, loss_stds, loss_kurts
|
||||
|
||||
def get_moments_by_layer(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
means, stds = torch.zeros(self.n_bn_classes), torch.zeros(self.n_bn_classes)
|
||||
skews, kurts = torch.zeros(self.n_bn_classes), torch.zeros(self.n_bn_classes)
|
||||
|
||||
layer_index = 0
|
||||
for name, block in self.bn_outputs.items():
|
||||
|
||||
# Residual blocks are nested
|
||||
for sublayer in range(0, len(block), 1):
|
||||
dist = block[sublayer].flatten()
|
||||
std, mean = torch.std_mean(dist)
|
||||
skew = moment(dist, std, mean, deg=3)
|
||||
kurt = moment(dist, std, mean, deg=4)
|
||||
means[layer_index] = mean
|
||||
stds[layer_index] = std
|
||||
skews[layer_index] = skew
|
||||
kurts[layer_index] = kurt
|
||||
layer_index += 1
|
||||
return means, stds, skews, kurts
|
||||
|
||||
def get_intermediate_layer_output(self, layer, inputs):
|
||||
layer_name = "layer"
|
||||
activation = {}
|
||||
|
||||
def get_activation(name):
|
||||
def hook(self, input, output):
|
||||
print("calling hook")
|
||||
activation[name] = output.detach()
|
||||
|
||||
return hook
|
||||
|
||||
layer.register_forward_hook(
|
||||
get_activation(layer_name)
|
||||
)
|
||||
_ = self(inputs)
|
||||
return activation[layer_name]
|
||||
|
||||
def train_small_model():
|
||||
DATA_DIR = "../data"
|
||||
transform=transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,)),
|
||||
transforms.Pad(2)
|
||||
])
|
||||
BATCH_SIZE = 512
|
||||
train_kwargs = {'batch_size': BATCH_SIZE}
|
||||
test_kwargs = {'batch_size': BATCH_SIZE}
|
||||
dataset1 = datasets.MNIST(DATA_DIR, train=True, download=True,
|
||||
transform=transform)
|
||||
dataset2 = datasets.MNIST(DATA_DIR, train=False,
|
||||
transform=transform)
|
||||
|
||||
train_dl = torch.utils.data.DataLoader(dataset1,**train_kwargs)
|
||||
val_dl = torch.utils.data.DataLoader(dataset2, **test_kwargs)
|
||||
max_lr = 0.005
|
||||
weight_decay = 1e-5
|
||||
model = SmallModel()
|
||||
optimizer = torch.optim.Adam(model.parameters(),
|
||||
max_lr,
|
||||
weight_decay = weight_decay)
|
||||
EPOCHS = 13
|
||||
|
||||
for epoch in range(EPOCHS):
|
||||
print(f"Epoch {epoch}")
|
||||
|
||||
train_loss = 0
|
||||
bn_means, bn_stds, bn_kurts = 0,0,0
|
||||
N = 0
|
||||
|
||||
model.train()
|
||||
|
||||
for i, (img, label) in enumerate(train_dl):
|
||||
logit = model(img)
|
||||
|
||||
# Model loss
|
||||
loss = F.cross_entropy(logit,label)
|
||||
|
||||
# Loss modifications
|
||||
bn_mean, bn_std, bn_kurt = model.get_bn_loss_metrics()
|
||||
loss += (bn_mean + bn_std + bn_kurt)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Save stuff
|
||||
train_loss += loss.item()
|
||||
bn_means += bn_mean.item()
|
||||
bn_stds += bn_std.item()
|
||||
bn_kurts += bn_kurt.item()
|
||||
N += 1
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
print("Saving model as %s" % "small_model.pt")
|
||||
torch.save(model.state_dict(), "small_model.pt")
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_small_model()
|
||||
93
palisade_he_cnn/src/small_model_inference.py
Normal file
93
palisade_he_cnn/src/small_model_inference.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
# export OMP_DISPLAY_ENV=TRUE
|
||||
import os
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from cnn_context import create_cnn_context
|
||||
from he_cnn.utils import *
|
||||
from small_model import SmallModel
|
||||
|
||||
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
|
||||
|
||||
# create HE cc and keys
|
||||
mult_depth = 30
|
||||
scale_factor_bits = 40
|
||||
batch_size = 32 * 32 * 32 # increased batch size b/c the ring dimension is higher due to the mult_depth
|
||||
|
||||
# used for a small test of big shards
|
||||
# batch_size = 128
|
||||
|
||||
# if using bootstrapping, you must increase scale_factor_bits to 59
|
||||
cc, keys = create_cc_and_keys(batch_size, mult_depth=mult_depth, scale_factor_bits=scale_factor_bits,
|
||||
bootstrapping=False)
|
||||
|
||||
# load the model
|
||||
weight_file = "palisade_he_cnn/src/weights/small_model.pt"
|
||||
print(os.getcwd())
|
||||
model = SmallModel(activation='gelu', pool_method='avg')
|
||||
model.load_state_dict(torch.load(weight_file))
|
||||
model.eval()
|
||||
|
||||
# load data
|
||||
transform = transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307,), (0.3081,)),
|
||||
transforms.Pad(2)])
|
||||
validset = torchvision.datasets.MNIST(root="./data", download=True, transform=transform)
|
||||
validloader = torch.utils.data.DataLoader(validset, batch_size=1, shuffle=True)
|
||||
|
||||
total = 0
|
||||
correct = 0
|
||||
total_time = 0
|
||||
|
||||
for i, test_data in enumerate(validloader):
|
||||
print(f"Inference {i + 1}:")
|
||||
|
||||
x_test, y_test = test_data
|
||||
|
||||
input_img = create_cnn_context(x_test[0], cc, keys.publicKey, verbose=True)
|
||||
|
||||
start = time()
|
||||
|
||||
layer = model.model_layers.conv1
|
||||
conv1 = input_img.apply_conv(layer[0], layer[1])
|
||||
act1 = conv1.apply_gelu()
|
||||
pool1 = act1.apply_pool()
|
||||
|
||||
layer = model.model_layers.conv2
|
||||
perm = np.random.permutation(128) # example of how to use an output permutation
|
||||
conv2 = pool1.apply_conv(layer[0], layer[1], output_permutation=perm)
|
||||
act2 = conv2.apply_gelu()
|
||||
pool2 = act2.apply_pool()
|
||||
|
||||
layer = model.model_layers.conv3
|
||||
conv3 = pool2.apply_conv(layer[0], layer[1])
|
||||
act3 = conv3.apply_gelu()
|
||||
|
||||
layer = model.model_layers.classifier[1]
|
||||
logits = act3.apply_linear(layer)
|
||||
|
||||
logits_dec = cc.decrypt(keys.secretKey, logits)[:10]
|
||||
logits_pt = model(x_test).detach().numpy().ravel()
|
||||
|
||||
print(f"[+] decrypted logits = {logits_dec}")
|
||||
print(f"[+] unencrypted logits = {logits_pt}")
|
||||
|
||||
inference_time = time() - start
|
||||
total_time += inference_time
|
||||
total += 1
|
||||
|
||||
y_label = y_test[0]
|
||||
correct += np.argmax(logits_dec) == y_label
|
||||
|
||||
out_string = f"""
|
||||
Count: {total}
|
||||
Accuracy: {correct / total}
|
||||
Average latency: {total_time / total:.02f}s
|
||||
"""
|
||||
|
||||
print(out_string)
|
||||
138
palisade_he_cnn/src/utils.py
Normal file
138
palisade_he_cnn/src/utils.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import copy
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Dict, Callable
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tt
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
|
||||
def pad_conv_input_channels(conv1):
|
||||
conv1 = copy.deepcopy(conv1)
|
||||
conv1.in_channels = 4
|
||||
data = conv1.weight.data
|
||||
shape = list(data.shape)
|
||||
shape[1] = 1
|
||||
padding = torch.zeros(*shape)
|
||||
new_data = torch.cat((data, padding), 1)
|
||||
conv1.weight.data = torch.Tensor(new_data)
|
||||
return conv1
|
||||
|
||||
|
||||
|
||||
class PadChannel(object):
|
||||
def __init__(self, npad: int=1):
|
||||
self.n = npad
|
||||
|
||||
def __call__(self, x):
|
||||
_, width, height = x.shape
|
||||
x = torch.cat([x, torch.zeros(self.n, width, height)])
|
||||
return x
|
||||
|
||||
def patch_whitening(data, patch_size=(3, 3)):
|
||||
# Compute weights from data such that
|
||||
# torch.std(F.conv2d(data, weights), dim=(2, 3))
|
||||
# is close to 1.
|
||||
h, w = patch_size
|
||||
c = data.size(1)
|
||||
patches = data.unfold(2, h, 1).unfold(3, w, 1)
|
||||
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
|
||||
|
||||
n, c, h, w = patches.shape
|
||||
X = patches.reshape(n, c * h * w)
|
||||
X = X / (X.size(0) - 1) ** 0.5
|
||||
covariance = X.t() @ X
|
||||
|
||||
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
|
||||
|
||||
eigenvalues = eigenvalues.flip(0)
|
||||
|
||||
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
|
||||
|
||||
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
|
||||
|
||||
def get_cifar10_dataloader(batch_size,
|
||||
data_dir: str='../../datasets/cifar10/',
|
||||
num_workers: int=4):
|
||||
stats = ((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010))
|
||||
|
||||
train_tfms = tt.Compose([
|
||||
tt.RandomCrop(32,padding=4,padding_mode='reflect'),
|
||||
tt.RandomHorizontalFlip(),
|
||||
tt.ToTensor(),
|
||||
tt.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1)
|
||||
])
|
||||
|
||||
val_tfms = tt.Compose([
|
||||
tt.ToTensor(),
|
||||
tt.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1)
|
||||
])
|
||||
|
||||
train_ds = ImageFolder(data_dir+'train',transform=train_tfms)
|
||||
val_ds = ImageFolder(data_dir+'test',transform=val_tfms)
|
||||
|
||||
train_dl = DataLoader(train_ds,
|
||||
batch_size,
|
||||
pin_memory = True,
|
||||
num_workers = num_workers,
|
||||
shuffle = True)
|
||||
val_dl = DataLoader(val_ds,
|
||||
batch_size,
|
||||
pin_memory = True,
|
||||
num_workers = num_workers)
|
||||
return train_dl, val_dl
|
||||
|
||||
|
||||
def remove_all_hooks(model: torch.nn.Module) -> None:
|
||||
for name, child in model._modules.items():
|
||||
if child is not None:
|
||||
if hasattr(child, "_forward_hooks"):
|
||||
child._forward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_forward_pre_hooks"):
|
||||
child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_backward_hooks"):
|
||||
child._backward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
remove_all_hooks(child)
|
||||
|
||||
# Given a model and an input, get intermediate layer output
|
||||
def get_intermediate_output(model):
|
||||
activation = defaultdict(list)
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
x = output.detach()
|
||||
activation[name].append(x)
|
||||
|
||||
return hook
|
||||
|
||||
BatchNorm_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
||||
for i, b in enumerate(BatchNorm_layers):
|
||||
b.register_forward_hook(
|
||||
get_activation(f"bn_{i + 1}")
|
||||
)
|
||||
return activation
|
||||
|
||||
|
||||
def get_all_bn_activations(model, val_dl, DEVICE):
|
||||
activation = get_intermediate_output(model)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
for img, label in (val_dl):
|
||||
img, label = img.to(DEVICE), label.to(DEVICE)
|
||||
out = model(img)
|
||||
|
||||
remove_all_hooks(model)
|
||||
|
||||
activation = {k:torch.cat(v) for k,v in activation.items()}
|
||||
|
||||
return activation
|
||||
|
||||
|
||||
BIN
palisade_he_cnn/src/weights/small_model.pt
Normal file
BIN
palisade_he_cnn/src/weights/small_model.pt
Normal file
Binary file not shown.
248
palisade_he_cnn/test.py
Normal file
248
palisade_he_cnn/test.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from src.cnn_context import create_cnn_context
|
||||
from src.he_cnn.utils import *
|
||||
|
||||
np.set_printoptions(formatter={'float': lambda x: "{0:0.3f}".format(x)})
|
||||
|
||||
class Info():
|
||||
def __init__(self, mult_depth = 30, scale_factor_bits = 40, batch_size = 32 * 32 * 32, max = 255, min = 0, h = 128, w = 128, channel_size = 3, ker_size = 3):
|
||||
self.mult_depth = mult_depth
|
||||
self.scale_factor_bits = scale_factor_bits
|
||||
self.batch_size = batch_size
|
||||
self.max = max
|
||||
self.min = min
|
||||
self.h = h
|
||||
self.w = w
|
||||
self.channel_size = channel_size
|
||||
self.ker_size = ker_size
|
||||
|
||||
rand_tensor = (max-min)*torch.rand((channel_size, h, w)) + min
|
||||
self.rand_tensor = rand_tensor
|
||||
|
||||
self.cc, self.keys = create_cc_and_keys(batch_size, mult_depth=mult_depth, scale_factor_bits=scale_factor_bits, bootstrapping=False)
|
||||
|
||||
self.input_img = create_cnn_context(self.rand_tensor, self.cc, self.keys.publicKey, verbose=True)
|
||||
|
||||
@pytest.fixture
|
||||
def check1():
|
||||
return Info(30, 40, 32 * 32 * 32, 1, -1, 64, 64, 4, 3)
|
||||
|
||||
@pytest.fixture
|
||||
def check2():
|
||||
return Info(30, 40, 32 * 32 * 32, 1, -1, 64, 64, 1, 3)
|
||||
|
||||
@pytest.fixture
|
||||
def check3():
|
||||
return Info(30, 40, 32, 1, -1, 16, 16, 2, 3)
|
||||
|
||||
def test_apply_conv2d_c1(check1) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check1.channel_size, check1.channel_size, check1.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check1.rand_tensor)
|
||||
pt_conv = torch.squeeze(pt_conv, axis=0).detach().numpy()
|
||||
|
||||
conv1 = check1.input_img.apply_conv(layer)
|
||||
dec_conv1 = conv1.decrypt_to_tensor(check1.cc, check1.keys).numpy().squeeze()
|
||||
|
||||
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed image < shard"
|
||||
|
||||
def test_apply_conv2d_c2(check2) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check2.channel_size, check2.channel_size, check2.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check2.rand_tensor)
|
||||
pt_conv = torch.squeeze(pt_conv, axis=0).detach().numpy()
|
||||
|
||||
conv1 = check2.input_img.apply_conv(layer)
|
||||
dec_conv1 = conv1.decrypt_to_tensor(check2.cc, check2.keys).numpy().squeeze()
|
||||
|
||||
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed channel < shard"
|
||||
|
||||
def test_apply_conv2d_c3(check3) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check3.channel_size, check3.channel_size, check3.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check3.rand_tensor)
|
||||
pt_conv = torch.squeeze(pt_conv, axis=0).detach().numpy()
|
||||
|
||||
conv1 = check3.input_img.apply_conv(layer)
|
||||
dec_conv1 = conv1.decrypt_to_tensor(check3.cc, check3.keys).numpy().squeeze()
|
||||
|
||||
assert np.allclose(dec_conv1, pt_conv, atol=1e-03), "Convolution result did not match between HE and PyTorch, failed channel > shard"
|
||||
|
||||
|
||||
def test_apply_pool_c1(check1) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check1.channel_size, check1.channel_size, check1.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check1.rand_tensor)
|
||||
pt_max_pool = torch.nn.AvgPool2d(2)
|
||||
pt_pool = pt_max_pool(pt_conv)
|
||||
pt_pool = pt_pool.detach().numpy()
|
||||
|
||||
conv1 = check1.input_img.apply_conv(layer)
|
||||
pool = conv1.apply_pool()
|
||||
dec_pool = pool.decrypt_to_tensor(check1.cc, check1.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed image < shard"
|
||||
|
||||
def test_apply_pool_c2(check2) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check2.channel_size, check2.channel_size, check2.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check2.rand_tensor)
|
||||
pt_max_pool = torch.nn.AvgPool2d(2)
|
||||
pt_pool = pt_max_pool(pt_conv)
|
||||
pt_pool = pt_pool.detach().numpy()
|
||||
|
||||
conv1 = check2.input_img.apply_conv(layer)
|
||||
pool = conv1.apply_pool()
|
||||
dec_pool = pool.decrypt_to_tensor(check2.cc, check2.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed channel < shard"
|
||||
|
||||
def test_apply_pool_c3(check3) -> None:
|
||||
class ConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size):
|
||||
super(ConvLayer, self).__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, padding="same")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
model = ConvLayer(check3.channel_size, check3.channel_size, check3.ker_size)
|
||||
model.eval()
|
||||
layer = model.conv
|
||||
|
||||
pt_conv = model(check3.rand_tensor)
|
||||
pt_max_pool = torch.nn.AvgPool2d(2)
|
||||
pt_pool = pt_max_pool(pt_conv)
|
||||
pt_pool = pt_pool.detach().numpy()
|
||||
|
||||
conv1 = check3.input_img.apply_conv(layer)
|
||||
pool = conv1.apply_pool()
|
||||
dec_pool = pool.decrypt_to_tensor(check3.cc, check3.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_pool, pt_pool, atol=1e-03), "Pooling result did not match between HE and PyTorch, failed channel > shard"
|
||||
|
||||
def test_apply_linear_c1(check1) -> None:
|
||||
class LinearLayer(torch.nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
super(LinearLayer, self).__init__()
|
||||
self.linear_one = torch.nn.Linear(input_size, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_one(x)
|
||||
return x
|
||||
|
||||
linear = LinearLayer(len(check1.rand_tensor.flatten()), check1.rand_tensor.shape[0])
|
||||
linear.eval()
|
||||
pt_linear = linear(check1.rand_tensor.flatten()).detach().numpy()
|
||||
|
||||
he_linear = check1.input_img.apply_linear(linear.linear_one)
|
||||
dec_linear = check1.cc.decrypt(check1.keys.secretKey, he_linear)[0:check1.rand_tensor.shape[0]]
|
||||
|
||||
assert np.allclose(dec_linear, pt_linear, atol=1e-03), "Linear result did not match between HE and PyTorch, failed image < shard"
|
||||
|
||||
def test_apply_linear_c2(check2) -> None:
|
||||
class LinearLayer(torch.nn.Module):
|
||||
def __init__(self, input_size, output_size):
|
||||
super(LinearLayer, self).__init__()
|
||||
self.linear_one = torch.nn.Linear(input_size, output_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_one(x)
|
||||
return x
|
||||
|
||||
linear = LinearLayer(len(check2.rand_tensor.flatten()), check2.rand_tensor.shape[0])
|
||||
linear.eval()
|
||||
pt_linear = linear(check2.rand_tensor.flatten()).detach().numpy()
|
||||
|
||||
he_linear = check2.input_img.apply_linear(linear.linear_one)
|
||||
dec_linear = check2.cc.decrypt(check2.keys.secretKey, he_linear)[0:check2.rand_tensor.shape[0]]
|
||||
|
||||
assert np.allclose(dec_linear, pt_linear, atol=1e-03), "Linear result did not match between HE and PyTorch, failed channel < shard"
|
||||
|
||||
def test_apply_gelu_c1(check1) -> None:
|
||||
gelu = torch.nn.GELU()
|
||||
pt_gelu = gelu(check1.rand_tensor)
|
||||
|
||||
he_gelu = check1.input_img.apply_gelu()
|
||||
dec_gelu = he_gelu.decrypt_to_tensor(check1.cc, check1.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed image < shard"
|
||||
|
||||
def test_apply_gelu_c2(check2) -> None:
|
||||
gelu = torch.nn.GELU()
|
||||
pt_gelu = gelu(check2.rand_tensor)
|
||||
|
||||
he_gelu = check2.input_img.apply_gelu()
|
||||
dec_gelu = he_gelu.decrypt_to_tensor(check2.cc, check2.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed channel < shard"
|
||||
|
||||
def test_apply_gelu_c3(check3) -> None:
|
||||
gelu = torch.nn.GELU()
|
||||
pt_gelu = gelu(check3.rand_tensor)
|
||||
|
||||
he_gelu = check3.input_img.apply_gelu()
|
||||
dec_gelu = he_gelu.decrypt_to_tensor(check3.cc, check3.keys).numpy()
|
||||
|
||||
assert np.allclose(dec_gelu, pt_gelu, atol=1e-03), "GELU result did not match between HE and PyTorch, failed channel > shard"
|
||||
75
palisade_he_cnn/training/models/resnet50.py
Normal file
75
palisade_he_cnn/training/models/resnet50.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
from PIL import ImageFile
|
||||
from tqdm import tqdm
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
# Set device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
device_ids = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
# Set hyperparameters
|
||||
num_epochs = 1
|
||||
batch_size = 128
|
||||
learning_rate = 0.001
|
||||
|
||||
# Initialize transformations for data augmentation
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.RandomVerticalFlip(),
|
||||
transforms.RandomRotation(degrees=45),
|
||||
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# Load the ImageNet Object Localization Challenge dataset
|
||||
train_dataset = torchvision.datasets.ImageFolder(
|
||||
root='~/ImageNet/ILSVRC/Data/CLS-LOC/train',
|
||||
transform=transform
|
||||
)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
|
||||
|
||||
# Load the ResNet50 model
|
||||
model = torchvision.models.resnet50(weights='DEFAULT')
|
||||
|
||||
# Parallelize training across multiple GPUs
|
||||
model = torch.nn.DataParallel(model, device_ids = device_ids)
|
||||
|
||||
# Set the model to run on the device
|
||||
model = model.to(f'cuda:{model.device_ids[0]}')
|
||||
|
||||
# Define the loss function and optimizer
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# Train the model...
|
||||
for epoch in range(num_epochs):
|
||||
for inputs, labels in tqdm(train_loader):
|
||||
# Move input and label tensors to the device
|
||||
inputs = inputs.to(f'cuda:{model.device_ids[0]}')
|
||||
labels = labels.to(f'cuda:{model.device_ids[0]}')
|
||||
|
||||
# Zero out the optimizer
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
outputs = model(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Print the loss for every epoch
|
||||
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}')
|
||||
|
||||
print(f'Finished Training, Loss: {loss.item():.4f}')
|
||||
130
palisade_he_cnn/training/models/resnet9.py
Normal file
130
palisade_he_cnn/training/models/resnet9.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Scale(nn.Module):
|
||||
def __init__(self, scale: float = 0.125):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
return self.scale * x
|
||||
|
||||
|
||||
def conv_block(
|
||||
in_ch: int,
|
||||
out_ch: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: str = "same",
|
||||
pool: bool = False,
|
||||
gelu: bool = False
|
||||
):
|
||||
layers = [nn.Conv2d(in_ch,
|
||||
out_ch,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding),
|
||||
nn.BatchNorm2d(out_ch)
|
||||
]
|
||||
if pool:
|
||||
layers.append(nn.AvgPool2d(2, 2))
|
||||
|
||||
if gelu:
|
||||
layers.append(nn.GELU())
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class ResNet9(nn.Module):
|
||||
def __init__(self,
|
||||
c_in: int = 4,
|
||||
c_out: int = 36,
|
||||
num_classes: int = 10,
|
||||
scale_out: float = 0.125):
|
||||
super().__init__()
|
||||
self.c_out = c_out
|
||||
self.conv1 = nn.Conv2d(c_in,
|
||||
c_out,
|
||||
kernel_size=(3, 3),
|
||||
padding="same",
|
||||
bias=True)
|
||||
self.conv2 = conv_block(c_out,
|
||||
64,
|
||||
kernel_size=1,
|
||||
padding="same",
|
||||
pool=False,
|
||||
gelu=True)
|
||||
self.conv3 = conv_block(64,
|
||||
128,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=True,
|
||||
gelu=True)
|
||||
self.res1 = nn.Sequential(
|
||||
conv_block(128,
|
||||
128,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=False,
|
||||
gelu=True),
|
||||
conv_block(128,
|
||||
128,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=False,
|
||||
gelu=True)
|
||||
|
||||
)
|
||||
self.conv4 = conv_block(128,
|
||||
256,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=True,
|
||||
gelu=True)
|
||||
self.conv5 = conv_block(256,
|
||||
512,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=True,
|
||||
gelu=True)
|
||||
self.res2 = nn.Sequential(
|
||||
conv_block(512,
|
||||
512,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=False,
|
||||
gelu=True),
|
||||
conv_block(512,
|
||||
512,
|
||||
kernel_size=3,
|
||||
padding="same",
|
||||
pool=False,
|
||||
gelu=True)
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(512 * 4 * 4, num_classes, bias=True),
|
||||
Scale(scale_out)
|
||||
)
|
||||
|
||||
def set_conv1_weights(self,
|
||||
weights: torch.Tensor,
|
||||
bias: torch.Tensor):
|
||||
self.conv1.weight.data = weights
|
||||
self.conv1.weight.requires_grad = False
|
||||
self.conv1.bias.data = bias
|
||||
self.conv1.bias.requires_grad = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
res1 = self.res1(x)
|
||||
x = x + res1
|
||||
x = self.conv4(x)
|
||||
x = self.conv5(x)
|
||||
res2 = self.res2(x)
|
||||
x = x + res2
|
||||
return self.classifier(x)
|
||||
321
palisade_he_cnn/training/models/resnetN_multiplexed.py
Normal file
321
palisade_he_cnn/training/models/resnetN_multiplexed.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
# Low-Complexity deep convolutional neural networks on
|
||||
# FHE using multiplexed parallel convolutions
|
||||
#
|
||||
# https://eprint.iacr.org/2021/1688.pdf
|
||||
#
|
||||
# Our implementation is not an exact 1-to-1
|
||||
|
||||
__all__ = [
|
||||
"resnet_test",
|
||||
"resnet20",
|
||||
"resnet32",
|
||||
"resnet44",
|
||||
"resnet56",
|
||||
"resnet110",
|
||||
]
|
||||
|
||||
POOL = nn.AvgPool2d(2, 2)
|
||||
BN_MOMENTUM = 0.1
|
||||
|
||||
|
||||
class Debug(nn.Module):
|
||||
def __init__(self, filename="temp.txt", debug=False):
|
||||
super().__init__()
|
||||
self.filename = Path("debug") / filename
|
||||
self.debug = debug
|
||||
# print(self.debug, filename)
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
if self.debug:
|
||||
data = x.detach().cpu().numpy().ravel()
|
||||
np.savetxt(self.filename, data, fmt="%0.04f")
|
||||
return x
|
||||
|
||||
|
||||
class Scale(nn.Module):
|
||||
def __init__(self, scale: float = 0.125):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
return self.scale * x
|
||||
|
||||
|
||||
def conv_bn(inchan: int,
|
||||
outchan: int,
|
||||
kernel: int = 3,
|
||||
stride: int = 1,
|
||||
padding: str = "same",
|
||||
filenames: list = ["temp.txt"],
|
||||
debug: bool = False) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inchan,
|
||||
outchan,
|
||||
kernel_size=kernel,
|
||||
stride=1,
|
||||
padding=padding
|
||||
),
|
||||
nn.BatchNorm2d(outchan, momentum=BN_MOMENTUM),
|
||||
Debug(filenames[0], debug)
|
||||
)
|
||||
|
||||
|
||||
def conv_bn_down(inchan: int,
|
||||
outchan: int,
|
||||
kernel: int = 3,
|
||||
stride: int = 1,
|
||||
padding: str = "same",
|
||||
filenames: list = ["temp.txt", "temp.txt"],
|
||||
debug: bool = False, ) -> nn.Sequential:
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inchan,
|
||||
outchan,
|
||||
kernel_size=kernel,
|
||||
stride=1,
|
||||
padding=padding
|
||||
),
|
||||
nn.BatchNorm2d(outchan, momentum=BN_MOMENTUM),
|
||||
Debug(filenames[0], debug),
|
||||
POOL,
|
||||
Debug(filenames[1], debug)
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
inchan: int,
|
||||
outchan: int,
|
||||
kernel: int = 3,
|
||||
stride: int = 1,
|
||||
padding: str = "same",
|
||||
activation: nn.Module = nn.GELU(),
|
||||
downsample: Optional[nn.Module] = None,
|
||||
prefix: str = "l1",
|
||||
debug: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# If a skip module is defined (defined as downsample), then our first block
|
||||
# in series needs to also include a downsampling operation, aka pooling.
|
||||
if downsample is not None:
|
||||
self.conv_bn_1 = conv_bn_down(
|
||||
inchan=inchan,
|
||||
outchan=outchan,
|
||||
kernel=3,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
filenames=["%s_bn1.txt" % prefix,
|
||||
"%s_pool.txt" % prefix],
|
||||
debug=debug
|
||||
)
|
||||
|
||||
else:
|
||||
self.conv_bn_1 = conv_bn(
|
||||
inchan=outchan,
|
||||
outchan=outchan,
|
||||
kernel=3,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
filenames=["%s_bn1.txt" % prefix],
|
||||
debug=debug
|
||||
)
|
||||
|
||||
self.conv_bn_2 = conv_bn(
|
||||
inchan=outchan,
|
||||
outchan=outchan,
|
||||
kernel=3,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
filenames=["%s_bn2.txt" % prefix],
|
||||
debug=debug
|
||||
)
|
||||
|
||||
self.gelu = activation
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv_bn_1(x)
|
||||
out = self.gelu(out)
|
||||
|
||||
out = self.conv_bn_2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
|
||||
if self.downsample is not None:
|
||||
out = self.gelu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
block: Type[BasicBlock],
|
||||
layers: List[int],
|
||||
num_classes: int = 10,
|
||||
debug: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
self.debug = debug
|
||||
|
||||
self.conv_bn_1 = nn.Sequential(
|
||||
nn.Conv2d(4, 16, kernel_size=3, stride=1, padding="same"),
|
||||
nn.BatchNorm2d(16, momentum=BN_MOMENTUM),
|
||||
Debug("l0_bn1.txt", debug=debug)
|
||||
)
|
||||
self.gelu0 = nn.GELU()
|
||||
self.gelu1 = nn.GELU()
|
||||
self.gelu2 = nn.GELU()
|
||||
self.gelu3 = nn.GELU()
|
||||
|
||||
self.debug0 = Debug("l0_gelu.txt", debug)
|
||||
self.debug1 = Debug("l1_gelu.txt", debug)
|
||||
self.debug2 = Debug("l2_gelu.txt", debug)
|
||||
self.debug3 = Debug("l3_gelu.txt", debug)
|
||||
|
||||
self.layer1 = self._make_layer(
|
||||
block=block,
|
||||
inchan=16,
|
||||
outchan=16,
|
||||
nblocks=layers[0],
|
||||
stride=1,
|
||||
prefix="l1"
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
block,
|
||||
inchan=16,
|
||||
outchan=32,
|
||||
nblocks=layers[1],
|
||||
stride=2, # Triggers downsample != None, not a true stride
|
||||
prefix="l2"
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
block,
|
||||
inchan=32,
|
||||
outchan=32,
|
||||
nblocks=layers[2],
|
||||
stride=2, # Triggers downsample != None, not a true stride
|
||||
prefix="l3"
|
||||
)
|
||||
self.classifier = nn.Sequential(
|
||||
POOL,
|
||||
nn.Flatten(),
|
||||
nn.Linear(32 * 4 * 4, num_classes),
|
||||
Scale()
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
x = self.conv_bn_1(x)
|
||||
x = self.gelu0(x)
|
||||
x = self.debug0(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.gelu1(x)
|
||||
x = self.debug1(x)
|
||||
|
||||
x = self.layer2(x)
|
||||
x = self.gelu2(x)
|
||||
x = self.debug2(x)
|
||||
|
||||
x = self.layer3(x)
|
||||
x = self.gelu3(x)
|
||||
x = self.debug3(x)
|
||||
|
||||
return self.classifier(x)
|
||||
|
||||
def _make_layer(
|
||||
self,
|
||||
block: Type[BasicBlock],
|
||||
inchan: int,
|
||||
outchan: int,
|
||||
nblocks: int,
|
||||
stride: int = 1,
|
||||
prefix: str = "l1"
|
||||
) -> nn.Sequential:
|
||||
|
||||
downsample = None
|
||||
|
||||
if stride != 1:
|
||||
downsample = conv_bn_down(
|
||||
inchan=inchan,
|
||||
outchan=outchan,
|
||||
filenames=["%s_ds_bn1.txt" % prefix,
|
||||
"%s_ds_pool.txt" % prefix],
|
||||
debug=self.debug,
|
||||
kernel=3,
|
||||
stride=1,
|
||||
padding="same"
|
||||
)
|
||||
|
||||
layers = []
|
||||
for i in range(0, nblocks):
|
||||
# Only need it for first iter
|
||||
if i == 1:
|
||||
downsample = None
|
||||
|
||||
layers.append(
|
||||
block(
|
||||
inchan=inchan,
|
||||
outchan=outchan,
|
||||
kernel=3,
|
||||
stride=stride,
|
||||
padding="same",
|
||||
activation=nn.GELU(),
|
||||
downsample=downsample,
|
||||
prefix=prefix + "_%s" % str(i),
|
||||
debug=self.debug
|
||||
)
|
||||
)
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def _resnet(
|
||||
block: Type[BasicBlock],
|
||||
layers: List[int],
|
||||
**kwargs: Any,
|
||||
) -> ResNet:
|
||||
return ResNet(block, layers, **kwargs)
|
||||
|
||||
|
||||
def resnet_test(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [1, 1, 1], **kwargs)
|
||||
|
||||
|
||||
def resnet20(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [3, 3, 3], **kwargs)
|
||||
|
||||
|
||||
def resnet32(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [5, 5, 5], **kwargs)
|
||||
|
||||
|
||||
def resnet44(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [7, 7, 7], **kwargs)
|
||||
|
||||
|
||||
def resnet56(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [9, 9, 9], **kwargs)
|
||||
|
||||
|
||||
def resnet110(**kwargs: Any) -> ResNet:
|
||||
return _resnet(BasicBlock, [18, 18, 18], **kwargs)
|
||||
68
palisade_he_cnn/training/optuna_params.py
Normal file
68
palisade_he_cnn/training/optuna_params.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
def get_optuna_params(model_type: str, dataset: str) -> dict:
|
||||
params = {}
|
||||
if dataset=='CIFAR10':
|
||||
if model_type=='resnet20':
|
||||
params["lr"] = 0.0016822249163093617
|
||||
params["lr_bias"] = 63.934695046801245
|
||||
params["momentum"] = 0.8484574950771097
|
||||
params["weight_decay"] = 0.11450934135118791
|
||||
elif model_type=='resnet32':
|
||||
# 91.19: lr: 0.0013205254360784781, lr_bias: 61.138281101282544, momentum: 0.873508553678625, weight_decay: 0.26911634559915815
|
||||
# 91.1 : lr: 0.0013978655308274968, lr_bias: 70.43940111170473, momentum: 0.8611100787383372, weight_decay: 0.2604742590264777
|
||||
# 90.99: lr: 0.0019695910893940986, lr_bias: 60.930501987151686, momentum: 0.8831260271578129, weight_decay: 0.1456126229025426
|
||||
params["lr"] = 0.0013205254360784781
|
||||
params["lr_bias"] = 61.138281101282544
|
||||
params["momentum"] = 0.873508553678625
|
||||
params["weight_decay"] = 0.26911634559915815
|
||||
elif model_type=='resnet44':
|
||||
# 91.49: 0.0017177668853317557 72.4258603207131 0.8353896320183106 0.16749858871622
|
||||
# 91.16: 0.0019608745758959625 67.9132255882833 0.8041541468923449 0.19517278422517992
|
||||
# 91.01: 0.0009350979452929332 71.95838038824016 0.858379476548086 0.06780300392316674
|
||||
params["lr"] = 0.0017177668853317557
|
||||
params["lr_bias"] = 72.4258603207131
|
||||
params["momentum"] = 0.8353896320183106
|
||||
params["weight_decay"] = 0.16749858871622
|
||||
elif model_type=='resnet56':
|
||||
# 92.12: 0.0012022823706985977 71.31108702685964 0.8252747623136261 0.26463818739336625
|
||||
# 91.90: 0.0010850336892236205 55.20534833175523 0.8738224946147084 0.10705317777179325
|
||||
# 91.56: 0.0019151327847040805 63.38376732305882 0.9134938189630787 0.24446065595718675
|
||||
params["lr"] = 0.0012022823706985977
|
||||
params["lr_bias"] = 71.31108702685964
|
||||
params["momentum"] = 0.8252747623136261
|
||||
params["weight_decay"] = 0.26463818739336625
|
||||
elif model_type=='resnet110':
|
||||
# 92.23: 0.001477698037686629 61.444988882569774 0.7241645867415002 0.23586225065185779
|
||||
# 92.17: 0.0017110807237653582 65.2511959805971 0.8078620231092996 0.19065715813207001
|
||||
# 92.16: 0.0015513227695282382 59.89497310126697 0.7355843250067341 0.13248840913478463
|
||||
params["lr"] = 0.001477698037686629
|
||||
params["lr_bias"] = 61.444988882569774
|
||||
params["momentum"] = 0.7241645867415002
|
||||
params["weight_decay"] = 0.23586225065185779
|
||||
else:
|
||||
print("model_type and dataset are incorrectly specified. Returning resnet20 params.")
|
||||
params["lr"] = 0.0016822249163093617
|
||||
params["lr_bias"] = 63.934695046801245
|
||||
params["momentum"] = 0.8484574950771097
|
||||
params["weight_decay"] = 0.11450934135118791
|
||||
else:
|
||||
# only return resnet32 since CIFAR100 only needs this
|
||||
if model_type=='resnet32':
|
||||
# 65.09: 0.0018636209167742187 64.96657354785438 0.9186032548289501 0.15017464467868924
|
||||
# 64.70: 0.0017509006966116355 60.10884856596049 0.8921508582343675 0.10919043636429121
|
||||
# 64.49: 0.0015358614659175514 59.175398449172015 0.8553794786037812 0.20824545084283141
|
||||
params["lr"] = 0.0018636209167742187
|
||||
params["lr_bias"] = 64.96657354785438
|
||||
params["momentum"] = 0.9186032548289501
|
||||
params["weight_decay"] = 0.15017464467868924
|
||||
else:
|
||||
# Default bag of trick params
|
||||
params["lr"] = 0.001
|
||||
params["lr_bias"] = 64
|
||||
params["momentum"] = 0.9
|
||||
params["weight_decay"] = 0.256
|
||||
|
||||
print("Loading params for %s, %s" % (model_type, dataset))
|
||||
print(params)
|
||||
return params
|
||||
308
palisade_he_cnn/training/train_resnet9.py
Normal file
308
palisade_he_cnn/training/train_resnet9.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
|
||||
from palisade_he_cnn.training.models.resnet9 import ResNet9
|
||||
from palisade_he_cnn.training.utils.utils_dataloading import *
|
||||
from palisade_he_cnn.training.utils.utils_kurtosis import *
|
||||
from palisade_he_cnn.training.utils.utils_resnetN import (
|
||||
patch_whitening, update_nesterov, update_ema, label_smoothing_loss
|
||||
)
|
||||
|
||||
|
||||
def argparsing():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-bs', '--batch',
|
||||
help='Batch size',
|
||||
type=int,
|
||||
required=False,
|
||||
default=512)
|
||||
parser.add_argument('-e', '--epochs',
|
||||
help='Number of epochs',
|
||||
type=int,
|
||||
required=False,
|
||||
default=100)
|
||||
parser.add_argument('-c', '--cuda',
|
||||
help='CUDA device number',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0)
|
||||
parser.add_argument('-r', '--nruns',
|
||||
help='Number of training runs',
|
||||
type=int,
|
||||
required=False,
|
||||
default=5)
|
||||
parser.add_argument('-dataset', '--dataset',
|
||||
help='CIFAR10 or CIFAR100',
|
||||
type=str,
|
||||
choices=['CIFAR10', 'CIFAR100'],
|
||||
required=False,
|
||||
default='CIFAR10')
|
||||
parser.add_argument('-s', '--save',
|
||||
help='Save model and log files',
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True)
|
||||
|
||||
return vars(parser.parse_args())
|
||||
|
||||
|
||||
def train(
|
||||
dataset,
|
||||
epochs,
|
||||
batch_size,
|
||||
momentum,
|
||||
weight_decay,
|
||||
weight_decay_bias,
|
||||
ema_update_freq,
|
||||
ema_rho,
|
||||
device,
|
||||
dtype,
|
||||
kwargs,
|
||||
use_TTA,
|
||||
seed=0
|
||||
):
|
||||
lr_schedule = torch.cat([
|
||||
torch.linspace(0e+0, 2e-3, 194),
|
||||
torch.linspace(2e-3, 2e-4, 582),
|
||||
])
|
||||
|
||||
lr_schedule_bias = 64.0 * lr_schedule
|
||||
|
||||
kurt_schedule = torch.cat([
|
||||
torch.linspace(0, 1e-1, 2000),
|
||||
])
|
||||
|
||||
# Print information about hardware on first run
|
||||
if seed == 0:
|
||||
if device.type == "cuda":
|
||||
print("Device :", torch.cuda.get_device_name(device.index))
|
||||
|
||||
print("Dtype :", dtype)
|
||||
print()
|
||||
|
||||
# Start measuring time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Set random seed to increase chance of reproducability
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Setting cudnn.benchmark to True hampers reproducability, but is faster
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# Load dataset
|
||||
if dataset == "CIFAR10":
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
|
||||
else:
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
|
||||
|
||||
train_data = torch.cat(
|
||||
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
|
||||
valid_data = torch.cat(
|
||||
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
|
||||
|
||||
temp = train_data[:10000, :, 4:-4, 4:-4]
|
||||
weights = patch_whitening(temp)
|
||||
|
||||
train_model = ResNet9(c_in=weights.size(1),
|
||||
c_out=weights.size(0),
|
||||
**kwargs).to(device)
|
||||
train_model.set_conv1_weights(
|
||||
weights=weights.to(device),
|
||||
bias=torch.zeros(weights.size(0)).to(device)
|
||||
)
|
||||
train_model.to(dtype)
|
||||
|
||||
# Convert BatchNorm back to single precision for better accuracy
|
||||
for module in train_model.modules():
|
||||
if isinstance(module, nn.BatchNorm2d):
|
||||
module.float()
|
||||
|
||||
# Collect weights and biases and create nesterov velocity values
|
||||
weights = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) > 1
|
||||
]
|
||||
biases = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) <= 1
|
||||
]
|
||||
|
||||
# Copy the model for validation
|
||||
valid_model = copy.deepcopy(train_model)
|
||||
|
||||
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
|
||||
|
||||
# Train and validate
|
||||
print("\nepoch batch train time [sec] validation accuracy")
|
||||
train_time = 0.0
|
||||
batch_count = 0
|
||||
best_acc = 0.0
|
||||
best_model = None
|
||||
for epoch in range(1, epochs + 1):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Randomly shuffle training data
|
||||
indices = torch.randperm(len(train_data), device=device)
|
||||
data = train_data[indices]
|
||||
targets = train_targets[indices]
|
||||
|
||||
# Crop random 32x32 patches from 40x40 training data
|
||||
data = [
|
||||
random_crop(data[i: i + batch_size], crop_size=(32, 32))
|
||||
for i in range(0, len(data), batch_size)
|
||||
]
|
||||
data = torch.cat(data)
|
||||
|
||||
# Randomly flip half the training data
|
||||
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
|
||||
|
||||
for i in range(0, len(data), batch_size):
|
||||
# discard partial batches
|
||||
if i + batch_size > len(data):
|
||||
break
|
||||
|
||||
# Slice batch from data
|
||||
inputs = data[i: i + batch_size]
|
||||
target = targets[i: i + batch_size]
|
||||
batch_count += 1
|
||||
|
||||
# Compute new gradients
|
||||
train_model.zero_grad()
|
||||
train_model.train(True)
|
||||
|
||||
# kurtosis setup
|
||||
remove_all_hooks(train_model)
|
||||
activations = get_intermediate_output(train_model)
|
||||
|
||||
logits = train_model(inputs)
|
||||
loss = label_smoothing_loss(logits, target, alpha=0.2)
|
||||
|
||||
# kurtosis scheduler
|
||||
kurt_index = min(batch_count, len(kurt_schedule) - 1)
|
||||
kurt_scale = kurt_schedule[kurt_index]
|
||||
|
||||
# kurtosis calculation and cleanup
|
||||
remove_all_hooks(train_model)
|
||||
activations = {k: torch.cat(v) for k, v in activations.items()}
|
||||
loss_means, loss_stds, loss_kurts = get_statistics(activations)
|
||||
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
|
||||
|
||||
loss.sum().backward()
|
||||
|
||||
lr_index = min(batch_count, len(lr_schedule) - 1)
|
||||
lr = lr_schedule[lr_index]
|
||||
lr_bias = lr_schedule_bias[lr_index]
|
||||
|
||||
# Update weights and biases of training model
|
||||
update_nesterov(weights, lr, weight_decay, momentum)
|
||||
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
|
||||
|
||||
# Update validation model with exponential moving averages
|
||||
if (i // batch_size % ema_update_freq) == 0:
|
||||
update_ema(train_model, valid_model, ema_rho)
|
||||
|
||||
# Add training time
|
||||
train_time += time.perf_counter() - start_time
|
||||
|
||||
valid_correct = []
|
||||
for i in range(0, len(valid_data), batch_size):
|
||||
valid_model.train(False)
|
||||
|
||||
# Test time agumentation: Test model on regular and flipped data
|
||||
regular_inputs = valid_data[i: i + batch_size]
|
||||
logits = valid_model(regular_inputs).detach()
|
||||
|
||||
if use_TTA:
|
||||
flipped_inputs = torch.flip(regular_inputs, [-1])
|
||||
logits2 = valid_model(flipped_inputs).detach()
|
||||
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
|
||||
|
||||
# Compute correct predictions
|
||||
correct = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
|
||||
valid_correct.append(correct.detach().type(torch.float64))
|
||||
|
||||
# Accuracy is average number of correct predictions
|
||||
valid_acc = torch.mean(torch.cat(valid_correct)).item()
|
||||
if valid_acc > best_acc:
|
||||
best_acc = valid_acc
|
||||
best_model = train_model
|
||||
|
||||
print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
|
||||
|
||||
return best_acc, best_model
|
||||
|
||||
|
||||
def main():
|
||||
args = argparsing()
|
||||
|
||||
model_type = "resnet9"
|
||||
cifar_dataset = args["dataset"]
|
||||
save = args["save"]
|
||||
weight_name = 'weights/%s_%s' % (model_type, cifar_dataset)
|
||||
kwargs = {
|
||||
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
|
||||
"scale_out": 0.125
|
||||
}
|
||||
|
||||
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
# Configurable parameters
|
||||
ema_update_freq = 5
|
||||
params = {
|
||||
"dataset": cifar_dataset,
|
||||
"epochs": args["epochs"],
|
||||
"batch_size": args["batch"],
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 0.256,
|
||||
"weight_decay_bias": 0.004,
|
||||
"ema_update_freq": ema_update_freq,
|
||||
"ema_rho": 0.99 ** ema_update_freq,
|
||||
"kwargs": kwargs,
|
||||
"use_TTA": False
|
||||
}
|
||||
|
||||
log = {
|
||||
"weights": weight_name,
|
||||
"model_type": model_type,
|
||||
"kwargs": kwargs,
|
||||
"params": params
|
||||
}
|
||||
|
||||
nruns = args["nruns"]
|
||||
|
||||
accuracies = []
|
||||
for run in range(nruns):
|
||||
weight_name_seed = weight_name + "_run%d.pt" % run
|
||||
|
||||
best_acc, best_model = train(**params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
seed=run)
|
||||
accuracies.append(best_acc)
|
||||
print("Best Run Accuracy: %1.4f" % best_acc)
|
||||
log["run%s" % run] = best_acc
|
||||
|
||||
if save:
|
||||
print("Saving %s" % weight_name_seed)
|
||||
torch.save(best_model.state_dict(), weight_name_seed)
|
||||
|
||||
mean = sum(accuracies) / len(accuracies)
|
||||
variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies)
|
||||
std = variance ** 0.5
|
||||
print("Accuracy: %1.4f +/- %1.4f" % (mean, std))
|
||||
log["accuracy"] = [mean, std]
|
||||
|
||||
if save:
|
||||
with open("logs/logs_resnet9_%s.json" % cifar_dataset, 'w') as fp:
|
||||
json.dump(log, fp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
354
palisade_he_cnn/training/train_resnetN.py
Normal file
354
palisade_he_cnn/training/train_resnetN.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from optuna_params import get_optuna_params
|
||||
from palisade_he_cnn.training.utils.utils_dataloading import random_crop
|
||||
from palisade_he_cnn.training.utils.utils_kurtosis import *
|
||||
from palisade_he_cnn.training.utils.utils_resnetN import (
|
||||
get_model, update_nesterov, update_ema, label_smoothing_loss
|
||||
)
|
||||
|
||||
# training time augmentation
|
||||
use_TTA = False
|
||||
|
||||
|
||||
class EarlyStopper:
|
||||
def __init__(self, patience=1, min_delta=0):
|
||||
self.patience = patience
|
||||
self.min_delta = min_delta
|
||||
self.counter = 0
|
||||
self.max_accuracy = 0.0
|
||||
|
||||
def early_stop(self, accuracy):
|
||||
if accuracy > self.max_accuracy:
|
||||
self.max_accuracy = accuracy
|
||||
self.counter = 0
|
||||
elif accuracy <= (self.max_accuracy + self.min_delta):
|
||||
self.counter += 1
|
||||
if self.counter >= self.patience:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def argparsing():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-n', '--nlayers',
|
||||
help='ResNet model depth',
|
||||
type=int,
|
||||
choices=[20, 32, 44, 56, 110],
|
||||
required=True)
|
||||
parser.add_argument('-bs', '--batch',
|
||||
help='Batch size',
|
||||
type=int,
|
||||
required=False,
|
||||
default=256)
|
||||
parser.add_argument('-e', '--epochs',
|
||||
help='Number of epochs',
|
||||
type=int,
|
||||
required=False,
|
||||
default=100)
|
||||
parser.add_argument('-d', '--debug',
|
||||
help='Debugging mode',
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False)
|
||||
parser.add_argument('-c', '--cuda',
|
||||
help='CUDA device number',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0)
|
||||
parser.add_argument('-dataset', '--dataset',
|
||||
help='CIFAR10 or CIFAR100',
|
||||
type=str,
|
||||
choices=['CIFAR10', 'CIFAR100'],
|
||||
required=False,
|
||||
default='CIFAR10')
|
||||
parser.add_argument('-s', '--save',
|
||||
help='Save model and log files',
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True)
|
||||
|
||||
return vars(parser.parse_args())
|
||||
|
||||
|
||||
def train(
|
||||
dataset,
|
||||
epochs,
|
||||
batch_size,
|
||||
lr,
|
||||
lr_bias,
|
||||
momentum,
|
||||
weight_decay,
|
||||
weight_decay_bias,
|
||||
ema_update_freq,
|
||||
ema_rho,
|
||||
device,
|
||||
dtype,
|
||||
model_type,
|
||||
kwargs,
|
||||
seed=0
|
||||
):
|
||||
# Load dataset
|
||||
if dataset == "CIFAR10":
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
|
||||
else:
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
|
||||
|
||||
train_data = torch.cat(
|
||||
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
|
||||
valid_data = torch.cat(
|
||||
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
|
||||
|
||||
N = int(len(train_data) / batch_size) # 50k / 256, now below is organized by epoch #
|
||||
lr_schedule = torch.cat([
|
||||
torch.linspace(0.0, lr, N),
|
||||
# torch.linspace(lr, lr, 2*N),
|
||||
torch.linspace(lr, 1e-4, 3 * N),
|
||||
torch.linspace(1e-4, 1e-4, 50 * N),
|
||||
torch.linspace(1e-5, 1e-5, 25 * N),
|
||||
torch.linspace(1e-6, 1e-6, 25 * N),
|
||||
])
|
||||
lr_schedule_bias = lr_bias * lr_schedule
|
||||
|
||||
kurt_schedule = torch.cat([
|
||||
torch.linspace(0, 0, 10 * N),
|
||||
torch.linspace(0.05, 0.05, 2 * N),
|
||||
torch.linspace(0.1, 0.1, 200 * N),
|
||||
])
|
||||
# Print information about hardware on first run
|
||||
if seed == 0:
|
||||
if device.type == "cuda":
|
||||
print("Device :", torch.cuda.get_device_name(device.index))
|
||||
|
||||
print("Dtype :", dtype)
|
||||
print()
|
||||
|
||||
# Start measuring time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Set random seed to increase chance of reproducability
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Setting cudnn.benchmark to True hampers reproducability, but is faster
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
# Convert model weights to half precision
|
||||
train_model = get_model(model_type, kwargs).to(device)
|
||||
train_model.to(dtype)
|
||||
|
||||
# Convert BatchNorm back to single precision for better accuracy
|
||||
for module in train_model.modules():
|
||||
if isinstance(module, nn.BatchNorm2d):
|
||||
module.float()
|
||||
|
||||
# Collect weights and biases and create nesterov velocity values
|
||||
weights = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) > 1
|
||||
]
|
||||
biases = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) <= 1
|
||||
]
|
||||
|
||||
# Copy the model for validation
|
||||
valid_model = copy.deepcopy(train_model)
|
||||
|
||||
# Patience:
|
||||
early_stopper = EarlyStopper(patience=120, min_delta=0.001) # this is %
|
||||
|
||||
# Testing non-SGD optimizer
|
||||
optimizer = torch.optim.Adam(train_model.parameters(), lr=0.001)
|
||||
|
||||
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
|
||||
print("\nepoch batch train time [sec] validation accuracy")
|
||||
|
||||
train_time = 0.0
|
||||
batch_count = 0
|
||||
best_acc = 0.0
|
||||
best_model = None
|
||||
for epoch in range(1, epochs + 1):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Randomly shuffle training data
|
||||
indices = torch.randperm(len(train_data), device=device)
|
||||
data = train_data[indices]
|
||||
targets = train_targets[indices]
|
||||
|
||||
# Crop random 32x32 patches from 40x40 training data
|
||||
data = [
|
||||
random_crop(data[i: i + batch_size], crop_size=(32, 32))
|
||||
for i in range(0, len(data), batch_size)
|
||||
]
|
||||
data = torch.cat(data)
|
||||
|
||||
# Randomly flip half the training data
|
||||
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
|
||||
|
||||
for i in range(0, len(data), batch_size):
|
||||
# Discard partial batches
|
||||
if i + batch_size > len(data):
|
||||
break
|
||||
|
||||
# Slice batch from data
|
||||
inputs = data[i: i + batch_size]
|
||||
target = targets[i: i + batch_size]
|
||||
batch_count += 1
|
||||
|
||||
# Compute new gradients
|
||||
train_model.zero_grad()
|
||||
train_model.train(True)
|
||||
|
||||
# kurtosis setup
|
||||
remove_all_hooks(train_model)
|
||||
activations = get_intermediate_output(train_model)
|
||||
|
||||
logits = train_model(inputs)
|
||||
loss = label_smoothing_loss(logits, target, alpha=0.2)
|
||||
|
||||
# kurtosis scheduler
|
||||
kurt_index = min(batch_count, len(kurt_schedule) - 1)
|
||||
kurt_scale = kurt_schedule[kurt_index]
|
||||
|
||||
# kurtosis calculation and cleanup
|
||||
remove_all_hooks(train_model)
|
||||
activations = {k: torch.cat(v) for k, v in activations.items()}
|
||||
loss_means, loss_stds, loss_kurts = get_statistics(activations)
|
||||
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
|
||||
|
||||
loss.sum().backward()
|
||||
|
||||
lr_index = min(batch_count, len(lr_schedule) - 1)
|
||||
lr = lr_schedule[lr_index]
|
||||
lr_bias = lr_schedule_bias[lr_index]
|
||||
|
||||
# Update weights and biases of training model
|
||||
update_nesterov(weights, lr, weight_decay, momentum)
|
||||
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
|
||||
|
||||
# Update validation model with exponential moving averages
|
||||
if (i // batch_size % ema_update_freq) == 0:
|
||||
update_ema(train_model, valid_model, ema_rho)
|
||||
|
||||
# Add training time
|
||||
train_time += time.perf_counter() - start_time
|
||||
|
||||
valid_correct = []
|
||||
for i in range(0, len(valid_data), batch_size):
|
||||
valid_model.train(False)
|
||||
regular_inputs = valid_data[i: i + batch_size]
|
||||
logits = valid_model(regular_inputs).detach()
|
||||
|
||||
if use_TTA:
|
||||
flipped_inputs = torch.flip(regular_inputs, [-1])
|
||||
logits2 = valid_model(flipped_inputs).detach()
|
||||
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
|
||||
|
||||
# Compute correct predictions
|
||||
correct = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
|
||||
|
||||
valid_correct.append(correct.detach().type(torch.float64))
|
||||
|
||||
# Accuracy is average number of correct predictions
|
||||
valid_acc = torch.mean(torch.cat(valid_correct)).item()
|
||||
if valid_acc > best_acc:
|
||||
best_acc = valid_acc
|
||||
best_model = train_model
|
||||
|
||||
if early_stopper.early_stop(valid_acc):
|
||||
print("Early stopping")
|
||||
break
|
||||
|
||||
print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {valid_acc:22.4f}")
|
||||
|
||||
return best_acc, best_model
|
||||
|
||||
|
||||
def main():
|
||||
args = argparsing()
|
||||
|
||||
model_type = "resnet%s" % args["nlayers"]
|
||||
cifar_dataset = args["dataset"]
|
||||
save = args["save"]
|
||||
weight_name = 'weights/%s_%s' % (model_type, cifar_dataset)
|
||||
|
||||
print("ResNet%s" % args["nlayers"])
|
||||
print("Weight file:", weight_name)
|
||||
|
||||
kwargs = {
|
||||
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
|
||||
"debug": args["debug"]
|
||||
}
|
||||
|
||||
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
# Optuna:
|
||||
optuna_params = get_optuna_params(model_type, cifar_dataset)
|
||||
lr = optuna_params["lr"]
|
||||
lr_bias = optuna_params["lr_bias"]
|
||||
momentum = optuna_params["momentum"]
|
||||
weight_decay = optuna_params["weight_decay"]
|
||||
|
||||
# Configurable parameters
|
||||
ema_update_freq = 5
|
||||
params = {
|
||||
"dataset": cifar_dataset,
|
||||
"epochs": args["epochs"],
|
||||
"batch_size": args["batch"],
|
||||
"lr": lr,
|
||||
"lr_bias": lr_bias,
|
||||
"momentum": momentum,
|
||||
"weight_decay": weight_decay,
|
||||
"weight_decay_bias": 0.004,
|
||||
"ema_update_freq": ema_update_freq,
|
||||
"ema_rho": 0.99 ** ema_update_freq,
|
||||
"model_type": model_type,
|
||||
"kwargs": kwargs
|
||||
}
|
||||
|
||||
nruns = 5
|
||||
log = {
|
||||
"weights": weight_name,
|
||||
"model_type": model_type,
|
||||
"kwargs": kwargs,
|
||||
"params": params
|
||||
}
|
||||
accuracies = []
|
||||
for run in range(nruns):
|
||||
weight_name_seed = weight_name + "_run%d.pt" % run
|
||||
|
||||
best_acc, best_model = train(**params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
seed=run)
|
||||
accuracies.append(best_acc)
|
||||
print("Best Run Accuracy: %1.4f" % best_acc)
|
||||
log["run%s" % run] = best_acc
|
||||
|
||||
if save:
|
||||
print("Saving %s" % weight_name_seed)
|
||||
torch.save(best_model.state_dict(), weight_name_seed)
|
||||
|
||||
mean = sum(accuracies) / len(accuracies)
|
||||
variance = sum((acc - mean) ** 2 for acc in accuracies) / len(accuracies)
|
||||
std = variance ** 0.5
|
||||
print("Accuracy: %1.4f +/- %1.4f" % (mean, std))
|
||||
log["accuracy"] = [mean, std]
|
||||
|
||||
if save:
|
||||
with open("logs/logs_resnet%s_%s.json" % (args["nlayers"], cifar_dataset), 'w') as fp:
|
||||
json.dump(log, fp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
326
palisade_he_cnn/training/train_resnetN_optuna.py
Normal file
326
palisade_he_cnn/training/train_resnetN_optuna.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import time
|
||||
|
||||
import optuna
|
||||
import joblib
|
||||
from optuna.trial import TrialState
|
||||
|
||||
from palisade_he_cnn.training.utils.utils_dataloading import *
|
||||
from palisade_he_cnn.training.utils.utils_kurtosis import *
|
||||
from palisade_he_cnn.training.utils.utils_resnetN import (
|
||||
get_model, update_nesterov, update_ema, label_smoothing_loss
|
||||
)
|
||||
|
||||
use_TTA = False
|
||||
|
||||
|
||||
def argparsing():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-n', '--nlayers',
|
||||
help='ResNet model depth',
|
||||
type=int,
|
||||
choices=[20, 32, 44, 56, 110],
|
||||
required=True)
|
||||
parser.add_argument('-bs', '--batch',
|
||||
help='Batch size',
|
||||
type=int,
|
||||
required=False,
|
||||
default=256)
|
||||
parser.add_argument('-e', '--epochs',
|
||||
help='Number of epochs',
|
||||
type=int,
|
||||
required=False,
|
||||
default=100)
|
||||
parser.add_argument('-d', '--debug',
|
||||
help='Debugging mode',
|
||||
type=bool,
|
||||
required=False,
|
||||
default=False)
|
||||
parser.add_argument('-c', '--cuda',
|
||||
help='CUDA device number',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0)
|
||||
parser.add_argument('-dataset', '--dataset',
|
||||
help='CIFAR10 or CIFAR100',
|
||||
type=str,
|
||||
choices=['CIFAR10', 'CIFAR100'],
|
||||
required=False,
|
||||
default='CIFAR10')
|
||||
parser.add_argument('-s', '--save',
|
||||
help='Save model and log files',
|
||||
type=bool,
|
||||
required=False,
|
||||
default=True)
|
||||
|
||||
return vars(parser.parse_args())
|
||||
|
||||
|
||||
def train(
|
||||
trial,
|
||||
dataset,
|
||||
epochs,
|
||||
batch_size,
|
||||
weight_decay_bias,
|
||||
ema_update_freq,
|
||||
ema_rho,
|
||||
device,
|
||||
dtype,
|
||||
model_type,
|
||||
kwargs,
|
||||
seed=0
|
||||
):
|
||||
# Print information about hardware on first run
|
||||
if seed == 0:
|
||||
if device.type == "cuda":
|
||||
print("Device :", torch.cuda.get_device_name(device.index))
|
||||
|
||||
print("Dtype :", dtype)
|
||||
print()
|
||||
|
||||
# Start measuring time
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Set random seed to increase chance of reproducability
|
||||
torch.manual_seed(seed)
|
||||
|
||||
# Load dataset
|
||||
if dataset == "CIFAR10":
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar10(device, dtype)
|
||||
else:
|
||||
train_data, train_targets, valid_data, valid_targets = load_cifar100(device, dtype)
|
||||
|
||||
train_data = torch.cat(
|
||||
[train_data, torch.zeros(train_data.size(0), 1, train_data.size(2), train_data.size(3)).to(device)], dim=1)
|
||||
valid_data = torch.cat(
|
||||
[valid_data, torch.zeros(valid_data.size(0), 1, valid_data.size(2), valid_data.size(3)).to(device)], dim=1)
|
||||
|
||||
# Convert model weights to half precision
|
||||
train_model = get_model(model_type, kwargs).to(device)
|
||||
train_model.to(dtype)
|
||||
train_model.train()
|
||||
|
||||
# Generate the optimizers.
|
||||
lr = trial.suggest_float("lr", 9e-4, 2e-3)
|
||||
lr_bias = trial.suggest_float("lr_bias", 54, 74)
|
||||
momentum = trial.suggest_float("momentum", 0.7, .99)
|
||||
weight_decay = trial.suggest_float("weight_decay", 0.01, .3)
|
||||
|
||||
N = int(len(train_data) / batch_size)
|
||||
lr_schedule = torch.cat([
|
||||
torch.linspace(0.0, lr, N),
|
||||
# torch.linspace(lr, lr, 2*N),
|
||||
torch.linspace(lr, 1e-4, 3 * N),
|
||||
torch.linspace(1e-4, 1e-4, 50 * N),
|
||||
torch.linspace(1e-5, 1e-5, 25 * N),
|
||||
torch.linspace(1e-6, 1e-6, 25 * N),
|
||||
])
|
||||
lr_schedule_bias = lr_bias * lr_schedule
|
||||
|
||||
kurt_schedule = torch.cat([
|
||||
torch.linspace(0, 0, 10 * N),
|
||||
torch.linspace(0.05, 0.05, 2 * N),
|
||||
torch.linspace(0.1, 0.1, 200 * N),
|
||||
])
|
||||
|
||||
# Convert BatchNorm back to single precision for better accuracy
|
||||
for module in train_model.modules():
|
||||
if isinstance(module, nn.BatchNorm2d):
|
||||
module.float()
|
||||
|
||||
# Collect weights and biases and create nesterov velocity values
|
||||
weights = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) > 1
|
||||
]
|
||||
biases = [
|
||||
(w, torch.zeros_like(w))
|
||||
for w in train_model.parameters()
|
||||
if w.requires_grad and len(w.shape) <= 1
|
||||
]
|
||||
|
||||
# Copy the model for validation
|
||||
valid_model = copy.deepcopy(train_model)
|
||||
|
||||
print(f"Preprocessing: {time.perf_counter() - start_time:.2f} seconds")
|
||||
|
||||
# Train and validate
|
||||
print("\nepoch batch train time [sec] validation accuracy")
|
||||
train_time = 0.0
|
||||
batch_count = 0
|
||||
best_acc = []
|
||||
for epoch in range(1, epochs + 1):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Randomly shuffle training data
|
||||
indices = torch.randperm(len(train_data), device=device)
|
||||
data = train_data[indices]
|
||||
targets = train_targets[indices]
|
||||
|
||||
# Crop random 32x32 patches from 40x40 training data
|
||||
data = [
|
||||
random_crop(data[i: i + batch_size], crop_size=(32, 32))
|
||||
for i in range(0, len(data), batch_size)
|
||||
]
|
||||
data = torch.cat(data)
|
||||
|
||||
# Randomly flip half the training data
|
||||
data[: len(data) // 2] = torch.flip(data[: len(data) // 2], [-1])
|
||||
loss_epoch = 0.0
|
||||
for i in range(0, len(data), batch_size):
|
||||
# discard partial batches
|
||||
if i + batch_size > len(data):
|
||||
break
|
||||
|
||||
# Slice batch from data
|
||||
inputs = data[i: i + batch_size]
|
||||
target = targets[i: i + batch_size]
|
||||
batch_count += 1
|
||||
|
||||
# Compute new gradients
|
||||
train_model.zero_grad()
|
||||
train_model.train(True)
|
||||
|
||||
# kurtosis setup
|
||||
remove_all_hooks(train_model)
|
||||
activations = get_intermediate_output(train_model)
|
||||
|
||||
logits = train_model(inputs)
|
||||
loss = label_smoothing_loss(logits, target, alpha=0.2)
|
||||
|
||||
# kurtosis scheduler
|
||||
kurt_index = min(batch_count, len(kurt_schedule) - 1)
|
||||
kurt_scale = kurt_schedule[kurt_index]
|
||||
|
||||
# kurtosis calculation and cleanup
|
||||
remove_all_hooks(train_model)
|
||||
activations = {k: torch.cat(v) for k, v in activations.items()}
|
||||
loss_means, loss_stds, loss_kurts = get_statistics(activations)
|
||||
loss += (loss_means + loss_stds + loss_kurts) * kurt_scale
|
||||
|
||||
loss.sum().backward()
|
||||
|
||||
loss_epoch += loss.sum().item() / batch_size
|
||||
|
||||
lr_index = min(batch_count, len(lr_schedule) - 1)
|
||||
lr = lr_schedule[lr_index]
|
||||
lr_bias = lr_schedule_bias[lr_index]
|
||||
|
||||
# Update weights and biases of training model
|
||||
update_nesterov(weights, lr, weight_decay, momentum)
|
||||
update_nesterov(biases, lr_bias, weight_decay_bias, momentum)
|
||||
|
||||
# Update validation model with exponential moving averages
|
||||
if (i // batch_size % ema_update_freq) == 0:
|
||||
update_ema(train_model, valid_model, ema_rho)
|
||||
|
||||
# Add training time
|
||||
train_time += time.perf_counter() - start_time
|
||||
|
||||
correct = []
|
||||
with torch.no_grad():
|
||||
for i in range(0, len(valid_data), batch_size):
|
||||
valid_model.train(False)
|
||||
regular_inputs = valid_data[i: i + batch_size]
|
||||
logits = valid_model(regular_inputs).detach()
|
||||
|
||||
if use_TTA:
|
||||
flipped_inputs = torch.flip(regular_inputs, [-1])
|
||||
logits2 = valid_model(flipped_inputs).detach()
|
||||
logits = torch.mean(torch.stack([logits, logits2], dim=0), dim=0)
|
||||
|
||||
# Compute correct predictions
|
||||
temp = logits.max(dim=1)[1] == valid_targets[i: i + batch_size]
|
||||
|
||||
correct.append(temp.detach().type(torch.float64))
|
||||
|
||||
# Accuracy is average number of correct predictions
|
||||
accuracy = torch.mean(torch.cat(correct)).item()
|
||||
best_acc.append(accuracy)
|
||||
print(f"{epoch:5} {batch_count:8d} {train_time:19.2f} {accuracy:22.4f}")
|
||||
|
||||
trial.report(accuracy, epoch)
|
||||
|
||||
# Handle pruning based on the intermediate value.
|
||||
if trial.should_prune():
|
||||
raise optuna.exceptions.TrialPruned()
|
||||
|
||||
return max(best_acc)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
args = argparsing()
|
||||
|
||||
model_type = "resnet%s" % args["nlayers"]
|
||||
cifar_dataset = args["dataset"]
|
||||
save = args["save"]
|
||||
weight_name = 'weights/optuna/%s_%s' % (model_type, cifar_dataset)
|
||||
|
||||
print("ResNet%s" % args["nlayers"])
|
||||
print("Weight file:", weight_name)
|
||||
|
||||
kwargs = {
|
||||
"num_classes": 10 if cifar_dataset == 'CIFAR10' else 100,
|
||||
"debug": args["debug"]
|
||||
}
|
||||
|
||||
device = torch.device("cuda:%s" % args["cuda"] if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
ema_update_freq = 5
|
||||
params = {
|
||||
"trial": trial,
|
||||
"dataset": cifar_dataset,
|
||||
"epochs": args["epochs"],
|
||||
"batch_size": args["batch"],
|
||||
"weight_decay_bias": 0.004,
|
||||
"ema_update_freq": ema_update_freq,
|
||||
"ema_rho": 0.99 ** ema_update_freq,
|
||||
"model_type": model_type,
|
||||
"kwargs": kwargs
|
||||
}
|
||||
accuracy = train(**params,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
seed=0)
|
||||
return accuracy
|
||||
|
||||
|
||||
def main():
|
||||
args = argparsing()
|
||||
model_type = "resnet%s" % args["nlayers"]
|
||||
cifar_dataset = args["dataset"]
|
||||
save = args["save"]
|
||||
|
||||
study = optuna.create_study(direction="maximize",
|
||||
storage="sqlite:///db.sqlite3",
|
||||
pruner=optuna.pruners.PatientPruner(optuna.pruners.MedianPruner(), patience=5,
|
||||
min_delta=0.0))
|
||||
study.optimize(objective, n_trials=10)
|
||||
|
||||
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
|
||||
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
|
||||
|
||||
print("Study statistics: ")
|
||||
print(" Number of finished trials: ", len(study.trials))
|
||||
print(" Number of pruned trials: ", len(pruned_trials))
|
||||
print(" Number of complete trials: ", len(complete_trials))
|
||||
|
||||
print("Best trial:")
|
||||
trial = study.best_trial
|
||||
|
||||
print(" Value: ", trial.value)
|
||||
|
||||
print(" Params: ")
|
||||
for key, value in trial.params.items():
|
||||
print(" {}: {}".format(key, value))
|
||||
|
||||
joblib.dump(study, "study_%s_%s.pkl" % (model_type, cifar_dataset))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
76
palisade_he_cnn/training/train_resnets.py
Normal file
76
palisade_he_cnn/training/train_resnets.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import subprocess
|
||||
import argparse
|
||||
|
||||
#layer = [20,32,44,56,110]
|
||||
#batch = [256,256,256,256,64]
|
||||
#epoch = [200,200,200,200,200]
|
||||
|
||||
#layer = [32]
|
||||
#batch = [256]
|
||||
#epoch = [200]
|
||||
|
||||
def argparsing():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c','--cuda',
|
||||
help='CUDA device number',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0)
|
||||
parser.add_argument('-dataset', '--dataset',
|
||||
help='CIFAR10 or CIFAR100',
|
||||
type=str,
|
||||
choices=['CIFAR10','CIFAR100'],
|
||||
required=False,
|
||||
default='CIFAR10')
|
||||
parser.add_argument('-n', '--nlayers',
|
||||
help='List of layers in string format',
|
||||
type=str,
|
||||
required=True,
|
||||
default='[20,32,44,56,110]')
|
||||
parser.add_argument('-e', '--epochs',
|
||||
help='List of epochs in string format',
|
||||
type=str,
|
||||
required=True,
|
||||
default='[100,100,100,100,100]')
|
||||
parser.add_argument('-bs', '--batch_size',
|
||||
help='List of batch sizes in string format',
|
||||
type=str,
|
||||
required=True,
|
||||
default='[256,256,256,256,64]')
|
||||
return vars(parser.parse_args())
|
||||
|
||||
def str2list(arg):
|
||||
return [int(item) for item in arg.split(',') if item!='']
|
||||
|
||||
def main():
|
||||
|
||||
args = argparsing()
|
||||
layer = str2list(args["nlayers"])
|
||||
batch = str2list(args["epochs"])
|
||||
epoch = str2list(args["batch_size"])
|
||||
|
||||
dataset = args["dataset"]
|
||||
cuda = args["cuda"]
|
||||
|
||||
for i in range(len(layer)):
|
||||
cmd = "python3 train_resnetN.py -n %s -bs %s -e %s -dataset %s -c %d" \
|
||||
% (layer[i], batch[i], epoch[i], dataset, cuda)
|
||||
|
||||
print("\n")
|
||||
print(cmd)
|
||||
subprocess.run(
|
||||
[
|
||||
"python3",
|
||||
"train_resnetN.py",
|
||||
"-n", "%s"%str(layer[i]),
|
||||
"-bs", "%s"%str(batch[i]),
|
||||
"-e", "%s"%str(epoch[i]),
|
||||
"-c", "%s"%str(cuda),
|
||||
"-dataset", dataset
|
||||
],
|
||||
shell=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
52
palisade_he_cnn/training/train_resnets_optuna.py
Normal file
52
palisade_he_cnn/training/train_resnets_optuna.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import subprocess
|
||||
import argparse
|
||||
|
||||
layer = [20, 32, 44, 56, 110]
|
||||
batch = [256, 256, 256, 256, 64]
|
||||
epoch = [50, 50, 50, 50, 50]
|
||||
|
||||
|
||||
def argparsing():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-c', '--cuda',
|
||||
help='CUDA device number',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0)
|
||||
parser.add_argument('-dataset', '--dataset',
|
||||
help='CIFAR10 or CIFAR100',
|
||||
type=str,
|
||||
choices=['CIFAR10', 'CIFAR100'],
|
||||
required=False,
|
||||
default='CIFAR10')
|
||||
return vars(parser.parse_args())
|
||||
|
||||
|
||||
def main():
|
||||
args = argparsing()
|
||||
dataset = args["dataset"]
|
||||
cuda = args["cuda"]
|
||||
|
||||
for i in range(len(layer)):
|
||||
cmd = "python3 train_resnetN_optuna.py -n %s -bs %s -e %s -dataset %s -c %d" \
|
||||
% (layer[i], batch[i], epoch[i], dataset, cuda)
|
||||
|
||||
print("\n")
|
||||
print(cmd)
|
||||
subprocess.run(
|
||||
[
|
||||
"python3",
|
||||
"train_resnetN_optuna.py",
|
||||
"-n", "%s" % str(layer[i]),
|
||||
"-bs", "%s" % str(batch[i]),
|
||||
"-e", "%s" % str(epoch[i]),
|
||||
"-c", "%s" % str(cuda),
|
||||
"-dataset", dataset
|
||||
],
|
||||
shell=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
palisade_he_cnn/training/utils/utils.py
Normal file
146
palisade_he_cnn/training/utils/utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import ast
|
||||
from collections import OrderedDict, defaultdict
|
||||
from typing import Dict, Callable
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as tt
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
|
||||
class PadChannel(object):
|
||||
def __init__(self, npad: int=1):
|
||||
self.n = npad
|
||||
|
||||
def __call__(self, x):
|
||||
_, width, height = x.shape
|
||||
x = torch.cat([x, torch.zeros(self.n, width, height)])
|
||||
return x
|
||||
|
||||
def get_gelu_poly_coeffs(degree, filename='gelu_poly_approx_params.txt'):
|
||||
with open(filename, 'r') as fp:
|
||||
params = []
|
||||
for line in fp:
|
||||
x = line[:-1]
|
||||
params.append(ast.literal_eval(x))
|
||||
|
||||
if degree==2:
|
||||
return params[0]
|
||||
if degree==4:
|
||||
return params[1]
|
||||
elif degree==8:
|
||||
return params[2]
|
||||
elif degree==16:
|
||||
return params[3]
|
||||
elif degree==32:
|
||||
return params[4]
|
||||
else:
|
||||
print("Defaulting to deg8")
|
||||
return params[2]
|
||||
|
||||
def patch_whitening(data, patch_size=(3, 3)):
|
||||
# Compute weights from data such that
|
||||
# torch.std(F.conv2d(data, weights), dim=(2, 3))
|
||||
# is close to 1.
|
||||
h, w = patch_size
|
||||
c = data.size(1)
|
||||
patches = data.unfold(2, h, 1).unfold(3, w, 1)
|
||||
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
|
||||
|
||||
n, c, h, w = patches.shape
|
||||
X = patches.reshape(n, c * h * w)
|
||||
X = X / (X.size(0) - 1) ** 0.5
|
||||
covariance = X.t() @ X
|
||||
|
||||
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
|
||||
|
||||
eigenvalues = eigenvalues.flip(0)
|
||||
|
||||
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
|
||||
|
||||
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
|
||||
|
||||
|
||||
def get_cifar10_dataloader(batch_size,
|
||||
data_dir: str='../../datasets/cifar10/',
|
||||
num_workers: int=4):
|
||||
stats = ((0.4914, 0.4822, 0.4465),
|
||||
(0.2023, 0.1994, 0.2010))
|
||||
|
||||
train_tfms = tt.Compose([
|
||||
tt.RandomCrop(32,padding=4,padding_mode='reflect'),
|
||||
tt.RandomHorizontalFlip(),
|
||||
tt.ToTensor(),
|
||||
tt.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1)
|
||||
])
|
||||
|
||||
val_tfms = tt.Compose([
|
||||
tt.ToTensor(),
|
||||
tt.Normalize(*stats,inplace=True),
|
||||
PadChannel(npad=1)
|
||||
])
|
||||
|
||||
train_ds = ImageFolder(data_dir+'train',transform=train_tfms)
|
||||
val_ds = ImageFolder(data_dir+'test',transform=val_tfms)
|
||||
|
||||
train_dl = DataLoader(train_ds,
|
||||
batch_size,
|
||||
pin_memory = True,
|
||||
num_workers = num_workers,
|
||||
shuffle = True)
|
||||
val_dl = DataLoader(val_ds,
|
||||
batch_size,
|
||||
pin_memory = True,
|
||||
num_workers = num_workers)
|
||||
return train_dl, val_dl
|
||||
|
||||
|
||||
def remove_all_hooks(model: torch.nn.Module) -> None:
|
||||
for name, child in model._modules.items():
|
||||
if child is not None:
|
||||
if hasattr(child, "_forward_hooks"):
|
||||
child._forward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_forward_pre_hooks"):
|
||||
child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_backward_hooks"):
|
||||
child._backward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
remove_all_hooks(child)
|
||||
|
||||
# Given a model and an input, get intermediate layer output
|
||||
def get_intermediate_output(model):
|
||||
activation = defaultdict(list)
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
x = output.detach().cpu()
|
||||
activation[name].append(x)
|
||||
|
||||
return hook
|
||||
|
||||
BatchNorm_layers = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
|
||||
for i, b in enumerate(BatchNorm_layers):
|
||||
b.register_forward_hook(
|
||||
get_activation(f"bn_{i + 1}")
|
||||
)
|
||||
return activation
|
||||
|
||||
|
||||
def get_all_bn_activations(model, val_dl, DEVICE):
|
||||
activation = get_intermediate_output(model)
|
||||
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
for img, label in (val_dl):
|
||||
img, label = img.to(DEVICE), label.to(DEVICE)
|
||||
out = model(img)
|
||||
|
||||
remove_all_hooks(model)
|
||||
|
||||
activation = {k:torch.cat(v) for k,v in activation.items()}
|
||||
|
||||
return activation
|
||||
|
||||
75
palisade_he_cnn/training/utils/utils_dataloading.py
Normal file
75
palisade_he_cnn/training/utils/utils_dataloading.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
|
||||
|
||||
def load_cifar10(device, dtype, data_dir='./datasets/cifar10/'):
|
||||
print("Loading CIFAR10")
|
||||
train = torchvision.datasets.CIFAR10(root=data_dir, download=True)
|
||||
valid = torchvision.datasets.CIFAR10(root=data_dir, train=False)
|
||||
|
||||
train_data = preprocess_cifar10_data(train.data, device, dtype)
|
||||
valid_data = preprocess_cifar10_data(valid.data, device, dtype)
|
||||
|
||||
train_targets = torch.tensor(train.targets).to(device)
|
||||
valid_targets = torch.tensor(valid.targets).to(device)
|
||||
|
||||
# Pad 32x32 to 40x40
|
||||
train_data = nn.ReflectionPad2d(4)(train_data)
|
||||
|
||||
return train_data, train_targets, valid_data, valid_targets
|
||||
|
||||
def load_cifar100(device, dtype, data_dir='./datasets/cifar100/'):
|
||||
print("Loading CIFAR100")
|
||||
train = torchvision.datasets.CIFAR100(root=data_dir, download=True)
|
||||
valid = torchvision.datasets.CIFAR100(root=data_dir, train=False)
|
||||
|
||||
train_data = preprocess_cifar100_data(train.data, device, dtype)
|
||||
valid_data = preprocess_cifar100_data(valid.data, device, dtype)
|
||||
|
||||
train_targets = torch.tensor(train.targets).to(device)
|
||||
valid_targets = torch.tensor(valid.targets).to(device)
|
||||
|
||||
# Pad 32x32 to 40x40
|
||||
train_data = nn.ReflectionPad2d(4)(train_data)
|
||||
|
||||
return train_data, train_targets, valid_data, valid_targets
|
||||
|
||||
def random_crop(data, crop_size):
|
||||
crop_h, crop_w = crop_size
|
||||
h = data.size(2)
|
||||
w = data.size(3)
|
||||
x = torch.randint(w - crop_w, size=(1,))[0]
|
||||
y = torch.randint(h - crop_h, size=(1,))[0]
|
||||
return data[:, :, y : y + crop_h, x : x + crop_w]
|
||||
|
||||
def preprocess_cifar10_data(data, device, dtype):
|
||||
# Convert to torch float16 tensor
|
||||
data = torch.tensor(data, device=device).to(dtype)
|
||||
|
||||
# Normalize
|
||||
mean = torch.tensor([125.31, 122.95, 113.87], device=device).to(dtype)
|
||||
std = torch.tensor([62.99, 62.09, 66.70], device=device).to(dtype)
|
||||
data = (data - mean) / std
|
||||
|
||||
# Permute data from NHWC to NCHW format
|
||||
data = data.permute(0, 3, 1, 2)
|
||||
|
||||
return data
|
||||
|
||||
def preprocess_cifar100_data(data, device, dtype):
|
||||
# Convert to torch float16 tensor
|
||||
data = torch.tensor(data, device=device).to(dtype)
|
||||
|
||||
# Normalize
|
||||
mean = torch.tensor([129.30, 124.07, 112.43], device=device).to(dtype)
|
||||
std = torch.tensor([68.17, 65.39, 70.42], device=device).to(dtype)
|
||||
data = (data - mean) / std
|
||||
|
||||
# Permute data from NHWC to NCHW format
|
||||
data = data.permute(0, 3, 1, 2)
|
||||
|
||||
return data
|
||||
|
||||
63
palisade_he_cnn/training/utils/utils_kurtosis.py
Normal file
63
palisade_he_cnn/training/utils/utils_kurtosis.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
from collections import defaultdict, OrderedDict
|
||||
from typing import Dict, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def get_intermediate_output(model):
|
||||
activations = defaultdict(list)
|
||||
|
||||
def get_activation(name):
|
||||
def hook(model, input, output):
|
||||
x = input[0]
|
||||
activations[name].append(x)
|
||||
|
||||
return hook
|
||||
|
||||
GELU_layers = [m for m in model.modules() if isinstance(m, torch.nn.GELU)]
|
||||
for i, b in enumerate(GELU_layers):
|
||||
b.register_forward_hook(
|
||||
get_activation(f"GELU_{i + 1}")
|
||||
)
|
||||
return activations
|
||||
|
||||
def remove_all_hooks(model: torch.nn.Module) -> None:
|
||||
for name, child in model._modules.items():
|
||||
if child is not None:
|
||||
if hasattr(child, "_forward_hooks"):
|
||||
child._forward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_forward_pre_hooks"):
|
||||
child._forward_pre_hooks: Dict[int, Callable] = OrderedDict()
|
||||
elif hasattr(child, "_backward_hooks"):
|
||||
child._backward_hooks: Dict[int, Callable] = OrderedDict()
|
||||
remove_all_hooks(child)
|
||||
|
||||
def moment(x: torch.Tensor, std: float, mean: float, deg: int=4, eps: float=1e-4) -> torch.Tensor:
|
||||
x = x.double()
|
||||
temp = (x-mean)**deg / x.shape[0]
|
||||
return torch.sum(temp) / (std**deg + eps)
|
||||
|
||||
def get_statistics(activations):
|
||||
n = len(activations)
|
||||
means = torch.zeros(n)
|
||||
stds = torch.zeros(n)
|
||||
kurts = torch.zeros(n)
|
||||
|
||||
for layer_index,name in enumerate(sorted(activations.keys(), key=lambda x:int(x.split('_')[1]))):
|
||||
dist = activations[name]
|
||||
dist = dist.flatten()
|
||||
|
||||
std, mean = torch.std_mean(dist)
|
||||
kurt = moment(dist, std, mean, deg=4)
|
||||
|
||||
means[layer_index] = mean
|
||||
stds[layer_index] = std
|
||||
kurts[layer_index] = kurt
|
||||
|
||||
loss_means = F.mse_loss(means, torch.zeros(n))
|
||||
loss_stds = F.mse_loss(stds, torch.ones(n))
|
||||
loss_kurts = F.mse_loss(kurts, 3*torch.ones(n))
|
||||
|
||||
return loss_means, loss_stds, loss_kurts
|
||||
95
palisade_he_cnn/training/utils/utils_resnetN.py
Normal file
95
palisade_he_cnn/training/utils/utils_resnetN.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# (c) 2021-2024 The Johns Hopkins University Applied Physics Laboratory LLC (JHU/APL).
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
import glob
|
||||
from palisade_he_cnn.training.models.resnetN_multiplexed import *
|
||||
|
||||
def get_model(model_type, kwargs):
|
||||
if model_type=='resnet20':
|
||||
return resnet20(**kwargs)
|
||||
elif model_type=='resnet32':
|
||||
return resnet32(**kwargs)
|
||||
elif model_type=='resnet44':
|
||||
return resnet44(**kwargs)
|
||||
elif model_type=='resnet56':
|
||||
return resnet56(**kwargs)
|
||||
elif model_type=='resnet110':
|
||||
return resnet110(**kwargs)
|
||||
elif model_type=='resnet_test':
|
||||
return resnet_test(**kwargs)
|
||||
else:
|
||||
raise ValueError("Returning None bc you are wrong!")
|
||||
|
||||
def get_best_weights(loc, dataset, model_type):
|
||||
loc = '%s%s/' % (loc,dataset)
|
||||
log_file = None
|
||||
for log in glob.glob(loc+'logs/*.json'):
|
||||
if model_type in log:
|
||||
log_file = log
|
||||
break
|
||||
|
||||
if log_file is None:
|
||||
raise ValueError("model_type number must be resnet9,20,32,44,56, or 110")
|
||||
|
||||
with open(log_file) as f:
|
||||
contents = json.load(f)
|
||||
|
||||
print("Finding the best model according to logs...")
|
||||
print(contents)
|
||||
|
||||
runs = {"run%d"%i : contents["run%d"%i] for i in range(5)}
|
||||
mean, std = contents["accuracy"]
|
||||
accs = [contents["run%d"%i] for i in range(5)]
|
||||
idx = accs.index(max(accs))
|
||||
|
||||
print("\nAverage (5 runs): %1.3f%% +/- %1.3f%%" % (100*mean, 100*std))
|
||||
print("Best (idx %d): %1.3f" % (idx,runs["run%d"%idx]))
|
||||
weight_file = loc + "weights/%s_%s_run%d.pt" % (model_type, dataset, idx)
|
||||
return weight_file
|
||||
|
||||
def num_params(model) -> int:
|
||||
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
|
||||
return sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
def update_ema(train_model, valid_model, rho):
|
||||
# The trained model is not used for validation directly. Instead, the
|
||||
# validation model weights are updated with exponential moving averages.
|
||||
train_weights = train_model.state_dict().values()
|
||||
valid_weights = valid_model.state_dict().values()
|
||||
for train_weight, valid_weight in zip(train_weights, valid_weights):
|
||||
if valid_weight.dtype in [torch.float16, torch.float32]:
|
||||
valid_weight *= rho
|
||||
valid_weight += (1 - rho) * train_weight
|
||||
|
||||
def update_nesterov(weights, lr, weight_decay, momentum):
|
||||
for weight, velocity in weights:
|
||||
if weight.requires_grad:
|
||||
gradient = weight.grad.data
|
||||
weight = weight.data
|
||||
|
||||
gradient.add_(weight, alpha=weight_decay).mul_(-lr)
|
||||
velocity.mul_(momentum).add_(gradient)
|
||||
weight.add_(gradient.add_(velocity, alpha=momentum))
|
||||
|
||||
def label_smoothing_loss(inputs, targets, alpha):
|
||||
log_probs = torch.nn.functional.log_softmax(inputs, dim=1, _stacklevel=5)
|
||||
kl = -log_probs.mean(dim=1)
|
||||
xent = torch.nn.functional.nll_loss(log_probs, targets, reduction="none")
|
||||
loss = (1 - alpha) * xent + alpha * kl
|
||||
return loss
|
||||
|
||||
def patch_whitening(data, patch_size=(3, 3)):
|
||||
h, w = patch_size
|
||||
c = data.size(1)
|
||||
patches = data.unfold(2, h, 1).unfold(3, w, 1)
|
||||
patches = patches.transpose(1, 3).reshape(-1, c, h, w).to(torch.float32)
|
||||
n, c, h, w = patches.shape
|
||||
X = patches.reshape(n, c * h * w)
|
||||
X = X / (X.size(0) - 1) ** 0.5
|
||||
covariance = X.t() @ X
|
||||
eigenvalues, eigenvectors = torch.linalg.eigh(covariance)
|
||||
eigenvalues = eigenvalues.flip(0)
|
||||
eigenvectors = eigenvectors.t().reshape(c * h * w, c, h, w).flip(0)
|
||||
return eigenvectors / torch.sqrt(eigenvalues + 1e-2).view(-1, 1, 1, 1)
|
||||
1069
poetry.lock
generated
Normal file
1069
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
29
pyproject.toml
Normal file
29
pyproject.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[tool.poetry]
|
||||
name = "palisade_he_cnn"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = [
|
||||
"Vikram Saraph <vikram.saraph@jhuapl.edu>",
|
||||
"Vivian Maloney <vivian.maloney@jhuapl.edu>",
|
||||
"Freddy Obrecht <freddy.obrecht@jhuapl.edu>",
|
||||
"Kate Tallaksen <kate.tallaksen@jhuapl.edu>",
|
||||
"Prathibha Rama <prathibha.rama@jhuapl.edu>"
|
||||
]
|
||||
readme = "README.md"
|
||||
packages = [
|
||||
{include = "palisade_he_cnn"}
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
torch = "^1.13.1"
|
||||
torchvision = "^0.14.1"
|
||||
optuna = "^3.2.0"
|
||||
joblib = "^1.3.2"
|
||||
pytest = "^7.4.0"
|
||||
openfhe = {path = "../palisade-python/wheelhouse/OpenFHE-1.0.5a0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"}
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
Reference in New Issue
Block a user