Package lantern
Lantern: safer than a torch
The Lantern package contains utility functions to support formal verification of PyTorch modules by encoding the behavior of (certain) neural networks as Z3 constraints.
The 'public' API includes:
- round_model(model, sbits)
- as_z3(model, sort, prefix)
Expand source code
"""
Lantern: safer than a torch
The Lantern package contains utility functions to support formal
verification of PyTorch modules by encoding the behavior of (certain)
neural networks as Z3 constraints.
The 'public' API includes:
- round_model(model, sbits)
- as_z3(model, sort, prefix)
"""
# Copyright 2020 The Johns Hopkins University Applied Physics Laboratory LLC
# All rights reserved.
#
# Licensed under the 3-Clause BSD License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import struct
from collections import OrderedDict
from functools import reduce
import torch.nn as nn
import z3
def truncate_double(f, sbits=52):
    """
    Truncate the significand/mantissa precision of f to number of sbits.

    Note that f is expected to be a Python float (double precision); it
    is converted with float() before use. The low (52 - sbits) bits of
    the big-endian 64-bit encoding are cleared, leaving the sign and
    exponent bits untouched.

    sbits must satisfy 0 <= sbits <= 52; sbits=52 is a no-op.

    NaN is returned unchanged: clearing the significand bits of a NaN
    can leave an all-zero significand under an all-ones exponent, which
    is the encoding of infinity rather than NaN.

    Returns the truncated value as a Python float.
    """
    assert((sbits <= 52) and (sbits >= 0))
    original = float(f)
    # NaN is the only float that compares unequal to itself.
    if original != original:
        return original
    dropped_bits = 52 - sbits
    # Reinterpret the double as a signed 64-bit integer so the low
    # significand bits can be masked off with shifts.
    int_cast = struct.unpack(">q", struct.pack(">d", original))[0]
    truncated_int = (int_cast >> dropped_bits) << dropped_bits
    truncated = float(struct.unpack(">d", struct.pack(">q", truncated_int))[0])
    return truncated
def round_model(model, sbits=52):
    """
    Build a reduced-precision copy of model.

    The model is deep-copied, and every tensor in the copy's state dict
    has each element passed through truncate_double(value, sbits) in
    place. The exponent part remains the same (11 bits), so each result
    is still a Python float.

    Note that sbits=52 is a no-op. Single precision sbits=23; half=10

    Returns the new model; the model passed in is not modified.
    """
    def _truncate(value):
        return truncate_double(value, sbits)

    rounded = copy.deepcopy(model)
    for tensor in rounded.state_dict().values():
        tensor.apply_(_truncate)
    return rounded
def encode_relu(x, y):
    """
    Encode y == relu(x) element-wise as z3 constraints.

    x and y are equal-length lists of z3 variables. For each pair the
    constraint y_i == If(x_i >= 0, x_i, 0) is built and passed through
    z3.simplify.

    Returns the list of simplified constraints, one per element.
    """
    assert len(x) == len(y)
    return [
        z3.simplify(out == z3.If(inp >= 0, inp, 0))
        for inp, out in zip(x, y)
    ]
def encode_hardtanh(x, y, min_val=-1, max_val=1):
    """
    Encode y == hardtanh(x, min_val, max_val) element-wise as z3
    constraints.

    x and y are equal-length lists of z3 variables; min_val must be
    strictly less than max_val. Each output equals min_val when the
    input is <= min_val, the input itself when it is <= max_val, and
    max_val otherwise.

    Returns the list of simplified constraints, one per element.
    """
    assert len(x) == len(y)
    assert min_val < max_val

    def clamp(value):
        # Upper clamp first, then wrap it in the lower clamp.
        upper = z3.If(value <= max_val, value, max_val)
        return z3.If(value <= min_val, min_val, upper)

    return [z3.simplify(out == clamp(inp)) for inp, out in zip(x, y)]
def hacky_sum(coll):
    """
    Add up the items of coll with the + operator, left to right.

    Exists because z3.Sum() doesn't work on FP sorts.

    Returns 0 for an empty collection, the sole item for a collection
    of length one, and otherwise the left fold of + over the items.
    """
    count = len(coll)
    if count == 0:
        return 0
    if count == 1:
        return coll[0]

    terms = iter(coll)
    total = next(terms)
    for term in terms:
        total = total + term
    return total
def encode_linear(W, b, x, y):
    """
    Encode y == W * x + b row by row as z3 constraints.

    x and y are lists of z3 variables; W and b are pytorch tensors.
    W.size() must unpack to (m, n) with m == len(b) == len(y),
    n == len(x), and both m and n at least 1.

    Returns a list of m simplified constraints, one per row of W.
    """
    rows, cols = W.size()
    assert rows == len(b)
    assert cols == len(x)
    assert rows == len(y)
    assert rows >= 1 and cols >= 1

    def affine(i):
        # Weights and bias are pulled out as Python scalars via .item().
        products = [W[i, j].item() * x[j] for j in range(cols)]
        return hacky_sum(products) + b[i].item()

    return [z3.simplify(y[i] == affine(i)) for i in range(rows)]
def const_vector(prefix, length, sort=z3.RealSort()):
    """
    Create `length` z3 constants of the given sort.

    The constants are named prefix__0, prefix__1, ... in order.
    e.g. const_vector("foo", 5, z3.FloatSingle())
    Returns a list of 5 FP
    """
    names = ["__".join((prefix, str(index))) for index in range(length)]
    return z3.Consts(names, sort)
def as_z3(model, sort=z3.RealSort(), prefix=""):
    """
    Calculate z3 constraints from a torch.nn.Sequential model.

    Returns (constraints, z3_input, z3_output) where:
    - constraints is a list of z3 constraints for the entire network
    - z3_input is the list of z3 constants (of the given sort)
      representing the input to the network
    - z3_output is the list of z3 constants representing the output of
      the network

    There are several caveats:
    - The model must be a torch Sequential
    - The first layer that is not ignored must be Linear; otherwise a
      ValueError is raised
    - Dropout layers are ignored
    - Identity layers are ignored
    - Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity
    - A ValueError is raised on any other type of layer
    - A Linear layer without a bias is encoded with an all-zero bias

    sort defaults to z3.RealSort(), but floating point sorts are
    permitted; note that z3.FloatSingle() matches the default behavior
    of PyTorch more accurately (but has different performance
    characteristics compared to a real arithmetic theory).

    prefix is an optional string prefix for the generated z3 variables
    """
    assert isinstance(model, nn.Sequential)
    modules = OrderedDict(model.named_modules())
    # named_modules() has ("" -> the entire net) as first key/val pair; remove
    modules.pop("")

    constraints = []
    first_vector = None
    previous_vector = None

    for name, module in modules.items():
        if isinstance(module, nn.Linear):
            # Read the attributes directly instead of unpacking
            # parameters(), which yields a single item when bias=False.
            W = module.weight
            b = module.bias
            if b is None:
                # encode_linear expects one bias entry per output row.
                b = W.new_zeros((module.out_features,))
            in_vector = previous_vector
            if in_vector is None:
                # First encoded layer: its input is the network input.
                in_vector = const_vector("{}_lin{}_in".format(prefix, name),
                                         module.in_features, sort)
                first_vector = in_vector
            out_vector = const_vector("{}_lin{}_out".format(prefix, name),
                                      module.out_features, sort)
            constraints.extend(encode_linear(W, b, in_vector, out_vector))
        elif isinstance(module, nn.ReLU):
            in_vector = previous_vector
            if in_vector is None:
                raise ValueError("First layer must be linear")
            out_vector = const_vector("{}_relu{}_out".format(prefix, name),
                                      len(in_vector), sort)
            constraints.extend(encode_relu(in_vector, out_vector))
        elif isinstance(module, nn.Hardtanh):
            in_vector = previous_vector
            if in_vector is None:
                raise ValueError("First layer must be linear")
            out_vector = const_vector("{}_tanh{}_out".format(prefix, name),
                                      len(in_vector), sort)
            constraints.extend(encode_hardtanh(in_vector, out_vector,
                                               module.min_val, module.max_val))
        elif isinstance(module, (nn.Dropout, nn.Identity)):
            # Ignored layers produce no output vector. Skip the chaining
            # step below so a leading Dropout/Identity does not read an
            # unassigned out_vector.
            continue
        else:
            raise ValueError("Don't know how to convert module: {}".format(module))

        previous_vector = out_vector

    # previous_vector is vector associated with last layer output
    return (constraints, first_vector, previous_vector)
Functions
def as_z3(model, sort=Real, prefix='')-
Calculate z3 constraints from a torch.nn.Sequential model.
Returns (constraints, z3_input, z3_output) where:
- constraints is a list of z3 constraints for the entire network
- z3_input is z3.RealVector representing the input to the network
- z3_output is a z3.RealVector representing output of the network
There are several caveats:
- The model must be a torch Sequential
- The first layer must be Linear
- Dropout layers are ignored
- Identity layers are ignored
- Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity
- An Exception is raised on any other type of layer
sort defaults to z3.RealSort(), but floating point sorts are permitted; note that z3.FloatSingle() matches the default behavior of PyTorch more accurately (but has different performance characteristics compared to a real arithmetic theory).
prefix is an optional string prefix for the generated z3 variables
Expand source code
def as_z3(model, sort=z3.RealSort(), prefix=""): """ Calculate z3 constraints from a torch.nn.Sequential model. Returns (constraints, z3_input, z3_output) where: - constraints is a list of z3 constraints for the entire network - z3_input is z3.RealVector representing the input to the network - z3_output is a z3.RealVector representing output of the network There are several caveats: - The model must be a torch Sequential - The first layer must be Linear - Dropout layers are ignored - Identity layers are ignored - Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity - An Exception is raised on any other type of layer sort defaults to z3.RealSort(), but floating point sorts are permitted; note that z3.FloatSingle() matches the default behavior of PyTorch more accurately (but has different performance characteristics compared to a real arithmetic theory prefix is an optional string prefix for the generated z3 variables """ assert isinstance(model, nn.Sequential) modules = OrderedDict(model.named_modules()) # named_modules() has ("" -> the entire net) as first key/val pair; remove modules.pop("") constraints = [] first_vector = None previous_vector = None for name in modules: module = modules[name] if isinstance(module, nn.Linear): W, b = module.parameters() in_vector = previous_vector if in_vector is None: in_vector = const_vector("{}_lin{}_in".format(prefix, name), module.in_features, sort) first_vector = in_vector out_vector = const_vector("{}_lin{}_out".format(prefix, name), module.out_features, sort) constraints.extend(encode_linear(W, b, in_vector, out_vector)) elif isinstance(module, nn.ReLU): in_vector = previous_vector if in_vector is None: raise ValueError("First layer must be linear") out_vector = const_vector("{}_relu{}_out".format(prefix, name), len(in_vector), sort) constraints.extend(encode_relu(in_vector, out_vector)) elif isinstance(module, nn.Hardtanh): in_vector = previous_vector if in_vector is None: raise ValueError("First layer must be 
linear") out_vector = const_vector("{}_tanh{}_out".format(prefix, name), len(in_vector), sort) constraints.extend(encode_hardtanh(in_vector, out_vector, module.min_val, module.max_val)) elif isinstance(module, nn.Dropout): pass elif isinstance(module, nn.Identity): pass else: raise ValueError("Don't know how to convert module: {}".format(module)) previous_vector = out_vector # previous_vector is vector associated with last layer output return (constraints, first_vector, previous_vector) def const_vector(prefix, length, sort=Real)-
Returns a list of z3 constants of given sort.
e.g. const_vector("foo", 5, z3.FloatSingle()) Returns a list of 5 FP
Expand source code
def const_vector(prefix, length, sort=z3.RealSort()): """ Returns a list of z3 constants of given sort. e.g. const_vector("foo", 5, z3.FloatSingle()) Returns a list of 5 FP """ names = [prefix + "__" + str(i) for i in range(length)] return z3.Consts(names, sort) def encode_hardtanh(x, y, min_val=-1, max_val=1)-
Returns a list of z3 constraints corresponding to:
y == hardtanh(x, min_val=-1, max_val=1)
Where: x, y are lists of z3 variables
Expand source code
def encode_hardtanh(x, y, min_val=-1, max_val=1): """ Returns a list of z3 constraints corresponding to: y == hardtanh(x, min_val=-1, max_val=1) Where: x, y are lists of z3 variables """ assert len(x) == len(y) assert min_val < max_val constraints = [] for x_i, y_i in zip(x, y): lhs = y_i rhs = z3.If(x_i <= min_val, min_val, z3.If(x_i <= max_val, x_i, max_val)) constraint = z3.simplify(lhs == rhs) constraints.append(constraint) return constraints def encode_linear(W, b, x, y)-
Returns a list of z3 constraints corresponding to:
y == W * x + b
Where: x, y are lists of z3 variables, W, b are pytorch tensors
Expand source code
def encode_linear(W, b, x, y): """ Returns a list of z3 constraints corresponding to: y == W * x + b Where: x, y are lists of z3 variables, W, b are pytorch tensors """ m, n = W.size() assert m == len(b) assert n == len(x) assert m == len(y) assert m >= 1 and n >= 1 constraints = [] for i in range(m): lhs = y[i] rhs = hacky_sum([W[i, j].item() * x[j] for j in range(n)]) + b[i].item() constraint = z3.simplify(lhs == rhs) constraints.append(constraint) return constraints def encode_relu(x, y)-
Returns a list of z3 constraints corresponding to:
y == relu(x)
Where: x, y are lists of z3 variables
Expand source code
def encode_relu(x, y): """ Returns a list of z3 constraints corresponding to: y == relu(x) Where: x, y are lists of z3 variables """ assert len(x) == len(y) constraints = [] for x_i, y_i in zip(x, y): lhs = y_i rhs = z3.If(x_i >= 0, x_i, 0) constraint = z3.simplify(lhs == rhs) constraints.append(constraint) return constraints def hacky_sum(coll)-
Because z3.Sum() doesn't work on FP sorts
Expand source code
def hacky_sum(coll): """ Because z3.Sum() doesn't work on FP sorts """ if len(coll) == 0: return 0 elif len(coll) == 1: return coll[0] else: return reduce(lambda x, y: x + y, coll) def round_model(model, sbits=52)-
Return a new model where every value in the original state dict has had its fractional precision reduced to number of sbits. Exponent part remains the same (11 bits) so the result can be returned as a Python float.
Note that sbits=52 is a no-op. Single precision sbits=23; half=10
Expand source code
def round_model(model, sbits=52): """ Return a new model where every value in the original state dict has had its fractional precision reduced to number of sbits. Exponent part remains the same (11 bits) so the result can be returned as a Python float. Note that sbits=52 is a no-op. Single precision sbits=23; half=10 """ new_model = copy.deepcopy(model) for t in new_model.state_dict().values(): t.apply_(lambda f: truncate_double(f, sbits)) return new_model def truncate_double(f, sbits=52)-
Truncate the significand/mantissa precision of f to number of sbits.
Note that f is expected to be a Python float (double precision).
sbits=52 is a no-op
Expand source code
def truncate_double(f, sbits=52): """ Truncate the significand/mantissa precision of f to number of sbits. Note that f is expected to be a Python float (double precision). sbits=52 is a no-op """ assert((sbits <= 52) and (sbits >= 0)) original = float(f) int_cast = struct.unpack(">q", struct.pack(">d", original))[0] truncated_int = ((int_cast >> (52 - sbits)) << (52 - sbits)) truncated = float(struct.unpack(">d", struct.pack(">q", truncated_int))[0]) return truncated