mirror of
https://github.com/JHUAPL/lantern-smt.git
synced 2026-01-09 18:17:56 -05:00
651 lines
25 KiB
HTML
651 lines
25 KiB
HTML
<!doctype html>
|
|
<html lang="en">
|
|
<head>
|
|
<meta charset="utf-8">
|
|
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
|
|
<meta name="generator" content="pdoc 0.8.1" />
|
|
<title>lantern API documentation</title>
|
|
<meta name="description" content="Lantern: safer than a torch …" />
|
|
<link href='https://cdnjs.cloudflare.com/ajax/libs/normalize/8.0.0/normalize.min.css' rel='stylesheet'>
|
|
<link href='https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/8.0.0/sanitize.min.css' rel='stylesheet'>
|
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/github.min.css" rel="stylesheet">
|
|
<style>.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance 
em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
|
|
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
|
|
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
|
|
</head>
|
|
<body>
|
|
<main>
|
|
<article id="content">
|
|
<header>
|
|
<h1 class="title">Package <code>lantern</code></h1>
|
|
</header>
|
|
<section id="section-intro">
|
|
<p>Lantern: safer than a torch</p>
|
|
<p>The Lantern package contains utility functions to support formal
|
|
verification of PyTorch modules by encoding the behavior of (certain)
|
|
neural networks as Z3 constraints.</p>
|
|
<p>The 'public' API includes:</p>
|
|
<ul>
|
|
<li>round_model(model, sbits)</li>
|
|
<li>as_z3(model, sort, prefix)</li>
|
|
</ul>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">"""
|
|
Lantern: safer than a torch
|
|
|
|
The Lantern package contains utility functions to support formal
|
|
verification of PyTorch modules by encoding the behavior of (certain)
|
|
neural networks as Z3 constraints.
|
|
|
|
The 'public' API includes:
|
|
|
|
- round_model(model, sbits)
|
|
- as_z3(model, sort, prefix)
|
|
"""
|
|
|
|
# Copyright 2020 The Johns Hopkins University Applied Physics Laboratory LLC
|
|
# All rights reserved.
|
|
#
|
|
# Licensed under the 3-Clause BSD License (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://opensource.org/licenses/BSD-3-Clause
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import copy
|
|
import struct
|
|
from collections import OrderedDict
|
|
from functools import reduce
|
|
|
|
import torch.nn as nn
|
|
import z3
|
|
|
|
|
|
def truncate_double(f, sbits=52):
|
|
"""
|
|
Truncate the significand/mantissa precision of f to number of sbits.
|
|
|
|
Note that f is expected to be a Python float (double precision).
|
|
|
|
sbits=52 is a no-op
|
|
"""
|
|
assert((sbits <= 52) and (sbits >= 0))
|
|
|
|
original = float(f)
|
|
int_cast = struct.unpack(">q", struct.pack(">d", original))[0]
|
|
truncated_int = ((int_cast >> (52 - sbits)) << (52 - sbits))
|
|
truncated = float(struct.unpack(">d", struct.pack(">q", truncated_int))[0])
|
|
|
|
return truncated
|
|
|
|
|
|
def round_model(model, sbits=52):
|
|
"""
|
|
Return a new model where every value in the original state dict has
|
|
had its fractional precision reduced to number of sbits. Exponent
|
|
part remains the same (11 bits) so the result can be returned as
|
|
a Python float.
|
|
|
|
Note that sbits=52 is a no-op. Single precision sbits=23; half=10
|
|
"""
|
|
new_model = copy.deepcopy(model)
|
|
|
|
for t in new_model.state_dict().values():
|
|
t.apply_(lambda f: truncate_double(f, sbits))
|
|
|
|
return new_model
|
|
|
|
|
|
def encode_relu(x, y):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == relu(x)
|
|
|
|
Where: x, y are lists of z3 variables
|
|
"""
|
|
assert len(x) == len(y)
|
|
|
|
constraints = []
|
|
for x_i, y_i in zip(x, y):
|
|
lhs = y_i
|
|
rhs = z3.If(x_i >= 0, x_i, 0)
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints
|
|
|
|
|
|
def encode_hardtanh(x, y, min_val=-1, max_val=1):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == hardtanh(x, min_val=-1, max_val=1)
|
|
|
|
Where: x, y are lists of z3 variables
|
|
"""
|
|
assert len(x) == len(y)
|
|
assert min_val < max_val
|
|
|
|
constraints = []
|
|
for x_i, y_i in zip(x, y):
|
|
lhs = y_i
|
|
rhs = z3.If(x_i <= min_val,
|
|
min_val,
|
|
z3.If(x_i <= max_val,
|
|
x_i,
|
|
max_val))
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints
|
|
|
|
|
|
def hacky_sum(coll):
|
|
"""
|
|
Because z3.Sum() doesn't work on FP sorts
|
|
"""
|
|
if len(coll) == 0:
|
|
return 0
|
|
elif len(coll) == 1:
|
|
return coll[0]
|
|
else:
|
|
return reduce(lambda x, y: x + y, coll)
|
|
|
|
|
|
def encode_linear(W, b, x, y):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == W * x + b
|
|
|
|
Where: x, y are lists of z3 variables,
|
|
W, b are pytorch tensors
|
|
"""
|
|
m, n = W.size()
|
|
assert m == len(b)
|
|
assert n == len(x)
|
|
assert m == len(y)
|
|
assert m >= 1 and n >= 1
|
|
|
|
constraints = []
|
|
for i in range(m):
|
|
lhs = y[i]
|
|
rhs = hacky_sum([W[i, j].item() * x[j] for j in range(n)]) + b[i].item()
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints
|
|
|
|
|
|
def const_vector(prefix, length, sort=z3.RealSort()):
|
|
"""
|
|
Returns a list of z3 constants of given sort.
|
|
|
|
e.g. const_vector("foo", 5, z3.FloatSingle())
|
|
Returns a list of 5 FP (FloatSingle) constants
|
|
"""
|
|
names = [prefix + "__" + str(i) for i in range(length)]
|
|
return z3.Consts(names, sort)
|
|
|
|
|
|
def as_z3(model, sort=z3.RealSort(), prefix=""):
|
|
"""
|
|
Calculate z3 constraints from a torch.nn.Sequential model.
|
|
|
|
Returns (constraints, z3_input, z3_output) where:
|
|
|
|
- constraints is a list of z3 constraints for the entire network
|
|
- z3_input is a list of z3 constants (of the given sort) representing the input to the network
|
|
- z3_output is a list of z3 constants (of the given sort) representing the output of the network
|
|
|
|
There are several caveats:
|
|
|
|
- The model must be a torch Sequential
|
|
- The first layer must be Linear
|
|
- Dropout layers are ignored
|
|
- Identity layers are ignored
|
|
- Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity
|
|
- A ValueError is raised on any other type of layer
|
|
|
|
sort defaults to z3.RealSort(), but floating point sorts are
|
|
permitted; note that z3.FloatSingle() matches the default behavior
|
|
of PyTorch more accurately (but has different performance
|
|
characteristics compared to a real arithmetic theory)
|
|
|
|
prefix is an optional string prefix for the generated z3 variables
|
|
"""
|
|
assert isinstance(model, nn.Sequential)
|
|
|
|
modules = OrderedDict(model.named_modules())
|
|
|
|
# named_modules() has ("" -> the entire net) as first key/val pair; remove
|
|
modules.pop("")
|
|
|
|
constraints = []
|
|
first_vector = None
|
|
previous_vector = None
|
|
for name in modules:
|
|
module = modules[name]
|
|
|
|
if isinstance(module, nn.Linear):
|
|
W, b = module.parameters()
|
|
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
in_vector = const_vector("{}_lin{}_in".format(prefix, name),
|
|
module.in_features, sort)
|
|
first_vector = in_vector
|
|
|
|
out_vector = const_vector("{}_lin{}_out".format(prefix, name),
|
|
module.out_features, sort)
|
|
|
|
constraints.extend(encode_linear(W, b, in_vector, out_vector))
|
|
|
|
elif isinstance(module, nn.ReLU):
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
raise ValueError("First layer must be linear")
|
|
|
|
out_vector = const_vector("{}_relu{}_out".format(prefix, name),
|
|
len(in_vector), sort)
|
|
|
|
constraints.extend(encode_relu(in_vector, out_vector))
|
|
|
|
elif isinstance(module, nn.Hardtanh):
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
raise ValueError("First layer must be linear")
|
|
|
|
out_vector = const_vector("{}_tanh{}_out".format(prefix, name),
|
|
len(in_vector), sort)
|
|
|
|
constraints.extend(encode_hardtanh(in_vector, out_vector,
|
|
module.min_val, module.max_val))
|
|
|
|
elif isinstance(module, nn.Dropout):
|
|
pass
|
|
elif isinstance(module, nn.Identity):
|
|
pass
|
|
else:
|
|
raise ValueError("Don't know how to convert module: {}".format(module))
|
|
|
|
previous_vector = out_vector
|
|
|
|
|
|
# previous_vector is vector associated with last layer output
|
|
return (constraints, first_vector, previous_vector)</code></pre>
|
|
</details>
|
|
</section>
|
|
<section>
|
|
</section>
|
|
<section>
|
|
</section>
|
|
<section>
|
|
<h2 class="section-title" id="header-functions">Functions</h2>
|
|
<dl>
|
|
<dt id="lantern.as_z3"><code class="name flex">
|
|
<span>def <span class="ident">as_z3</span></span>(<span>model, sort=Real, prefix='')</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Calculate z3 constraints from a torch.nn.Sequential model.</p>
|
|
<p>Returns (constraints, z3_input, z3_output) where:</p>
|
|
<ul>
|
|
<li>constraints is a list of z3 constraints for the entire network</li>
|
|
<li>z3_input is a list of z3 constants (of the given sort) representing the input to the network</li>
|
|
<li>z3_output is a list of z3 constants (of the given sort) representing the output of the network</li>
|
|
</ul>
|
|
<p>There are several caveats:</p>
|
|
<ul>
|
|
<li>The model must be a torch Sequential</li>
|
|
<li>The first layer must be Linear</li>
|
|
<li>Dropout layers are ignored</li>
|
|
<li>Identity layers are ignored</li>
|
|
<li>Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity</li>
|
|
<li>A ValueError is raised on any other type of layer</li>
|
|
</ul>
|
|
<p>sort defaults to z3.RealSort(), but floating point sorts are
|
|
permitted; note that z3.FloatSingle() matches the default behavior
|
|
of PyTorch more accurately (but has different performance
|
|
characteristics compared to a real arithmetic theory)</p>
|
|
<p>prefix is an optional string prefix for the generated z3 variables</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def as_z3(model, sort=z3.RealSort(), prefix=""):
|
|
"""
|
|
Calculate z3 constraints from a torch.nn.Sequential model.
|
|
|
|
Returns (constraints, z3_input, z3_output) where:
|
|
|
|
- constraints is a list of z3 constraints for the entire network
|
|
- z3_input is a list of z3 constants (of the given sort) representing the input to the network
|
|
- z3_output is a list of z3 constants (of the given sort) representing the output of the network
|
|
|
|
There are several caveats:
|
|
|
|
- The model must be a torch Sequential
|
|
- The first layer must be Linear
|
|
- Dropout layers are ignored
|
|
- Identity layers are ignored
|
|
- Supported layers are: Linear, ReLU, Hardtanh, Dropout, Identity
|
|
- A ValueError is raised on any other type of layer
|
|
|
|
sort defaults to z3.RealSort(), but floating point sorts are
|
|
permitted; note that z3.FloatSingle() matches the default behavior
|
|
of PyTorch more accurately (but has different performance
|
|
characteristics compared to a real arithmetic theory)
|
|
|
|
prefix is an optional string prefix for the generated z3 variables
|
|
"""
|
|
assert isinstance(model, nn.Sequential)
|
|
|
|
modules = OrderedDict(model.named_modules())
|
|
|
|
# named_modules() has ("" -> the entire net) as first key/val pair; remove
|
|
modules.pop("")
|
|
|
|
constraints = []
|
|
first_vector = None
|
|
previous_vector = None
|
|
for name in modules:
|
|
module = modules[name]
|
|
|
|
if isinstance(module, nn.Linear):
|
|
W, b = module.parameters()
|
|
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
in_vector = const_vector("{}_lin{}_in".format(prefix, name),
|
|
module.in_features, sort)
|
|
first_vector = in_vector
|
|
|
|
out_vector = const_vector("{}_lin{}_out".format(prefix, name),
|
|
module.out_features, sort)
|
|
|
|
constraints.extend(encode_linear(W, b, in_vector, out_vector))
|
|
|
|
elif isinstance(module, nn.ReLU):
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
raise ValueError("First layer must be linear")
|
|
|
|
out_vector = const_vector("{}_relu{}_out".format(prefix, name),
|
|
len(in_vector), sort)
|
|
|
|
constraints.extend(encode_relu(in_vector, out_vector))
|
|
|
|
elif isinstance(module, nn.Hardtanh):
|
|
in_vector = previous_vector
|
|
if in_vector is None:
|
|
raise ValueError("First layer must be linear")
|
|
|
|
out_vector = const_vector("{}_tanh{}_out".format(prefix, name),
|
|
len(in_vector), sort)
|
|
|
|
constraints.extend(encode_hardtanh(in_vector, out_vector,
|
|
module.min_val, module.max_val))
|
|
|
|
elif isinstance(module, nn.Dropout):
|
|
pass
|
|
elif isinstance(module, nn.Identity):
|
|
pass
|
|
else:
|
|
raise ValueError("Don't know how to convert module: {}".format(module))
|
|
|
|
previous_vector = out_vector
|
|
|
|
|
|
# previous_vector is vector associated with last layer output
|
|
return (constraints, first_vector, previous_vector)</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.const_vector"><code class="name flex">
|
|
<span>def <span class="ident">const_vector</span></span>(<span>prefix, length, sort=Real)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Returns a list of z3 constants of given sort.</p>
|
|
<p>e.g. const_vector("foo", 5, z3.FloatSingle())
|
|
Returns a list of 5 FP (FloatSingle) constants</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def const_vector(prefix, length, sort=z3.RealSort()):
|
|
"""
|
|
Returns a list of z3 constants of given sort.
|
|
|
|
e.g. const_vector("foo", 5, z3.FloatSingle())
|
|
Returns a list of 5 FP (FloatSingle) constants
|
|
"""
|
|
names = [prefix + "__" + str(i) for i in range(length)]
|
|
return z3.Consts(names, sort)</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.encode_hardtanh"><code class="name flex">
|
|
<span>def <span class="ident">encode_hardtanh</span></span>(<span>x, y, min_val=-1, max_val=1)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Returns a list of z3 constraints corresponding to:</p>
|
|
<p>y == hardtanh(x, min_val=-1, max_val=1)</p>
|
|
<p>Where: x, y are lists of z3 variables</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def encode_hardtanh(x, y, min_val=-1, max_val=1):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == hardtanh(x, min_val=-1, max_val=1)
|
|
|
|
Where: x, y are lists of z3 variables
|
|
"""
|
|
assert len(x) == len(y)
|
|
assert min_val < max_val
|
|
|
|
constraints = []
|
|
for x_i, y_i in zip(x, y):
|
|
lhs = y_i
|
|
rhs = z3.If(x_i <= min_val,
|
|
min_val,
|
|
z3.If(x_i <= max_val,
|
|
x_i,
|
|
max_val))
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.encode_linear"><code class="name flex">
|
|
<span>def <span class="ident">encode_linear</span></span>(<span>W, b, x, y)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Returns a list of z3 constraints corresponding to:</p>
|
|
<p>y == W * x + b</p>
|
|
<p>Where: x, y are lists of z3 variables,
|
|
W, b are pytorch tensors</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def encode_linear(W, b, x, y):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == W * x + b
|
|
|
|
Where: x, y are lists of z3 variables,
|
|
W, b are pytorch tensors
|
|
"""
|
|
m, n = W.size()
|
|
assert m == len(b)
|
|
assert n == len(x)
|
|
assert m == len(y)
|
|
assert m >= 1 and n >= 1
|
|
|
|
constraints = []
|
|
for i in range(m):
|
|
lhs = y[i]
|
|
rhs = hacky_sum([W[i, j].item() * x[j] for j in range(n)]) + b[i].item()
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.encode_relu"><code class="name flex">
|
|
<span>def <span class="ident">encode_relu</span></span>(<span>x, y)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Returns a list of z3 constraints corresponding to:</p>
|
|
<p>y == relu(x)</p>
|
|
<p>Where: x, y are lists of z3 variables</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def encode_relu(x, y):
|
|
"""
|
|
Returns a list of z3 constraints corresponding to:
|
|
|
|
y == relu(x)
|
|
|
|
Where: x, y are lists of z3 variables
|
|
"""
|
|
assert len(x) == len(y)
|
|
|
|
constraints = []
|
|
for x_i, y_i in zip(x, y):
|
|
lhs = y_i
|
|
rhs = z3.If(x_i >= 0, x_i, 0)
|
|
constraint = z3.simplify(lhs == rhs)
|
|
constraints.append(constraint)
|
|
|
|
return constraints</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.hacky_sum"><code class="name flex">
|
|
<span>def <span class="ident">hacky_sum</span></span>(<span>coll)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Because z3.Sum() doesn't work on FP sorts</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def hacky_sum(coll):
|
|
"""
|
|
Because z3.Sum() doesn't work on FP sorts
|
|
"""
|
|
if len(coll) == 0:
|
|
return 0
|
|
elif len(coll) == 1:
|
|
return coll[0]
|
|
else:
|
|
return reduce(lambda x, y: x + y, coll)</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.round_model"><code class="name flex">
|
|
<span>def <span class="ident">round_model</span></span>(<span>model, sbits=52)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Return a new model where every value in the original state dict has
|
|
had its fractional precision reduced to number of sbits. Exponent
|
|
part remains the same (11 bits) so the result can be returned as
|
|
a Python float.</p>
|
|
<p>Note that sbits=52 is a no-op. Single precision sbits=23; half=10</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def round_model(model, sbits=52):
|
|
"""
|
|
Return a new model where every value in the original state dict has
|
|
had its fractional precision reduced to number of sbits. Exponent
|
|
part remains the same (11 bits) so the result can be returned as
|
|
a Python float.
|
|
|
|
Note that sbits=52 is a no-op. Single precision sbits=23; half=10
|
|
"""
|
|
new_model = copy.deepcopy(model)
|
|
|
|
for t in new_model.state_dict().values():
|
|
t.apply_(lambda f: truncate_double(f, sbits))
|
|
|
|
return new_model</code></pre>
|
|
</details>
|
|
</dd>
|
|
<dt id="lantern.truncate_double"><code class="name flex">
|
|
<span>def <span class="ident">truncate_double</span></span>(<span>f, sbits=52)</span>
|
|
</code></dt>
|
|
<dd>
|
|
<div class="desc"><p>Truncate the significand/mantissa precision of f to number of sbits.</p>
|
|
<p>Note that f is expected to be a Python float (double precision).</p>
|
|
<p>sbits=52 is a no-op</p></div>
|
|
<details class="source">
|
|
<summary>
|
|
<span>Expand source code</span>
|
|
</summary>
|
|
<pre><code class="python">def truncate_double(f, sbits=52):
|
|
"""
|
|
Truncate the significand/mantissa precision of f to number of sbits.
|
|
|
|
Note that f is expected to be a Python float (double precision).
|
|
|
|
sbits=52 is a no-op
|
|
"""
|
|
assert((sbits <= 52) and (sbits >= 0))
|
|
|
|
original = float(f)
|
|
int_cast = struct.unpack(">q", struct.pack(">d", original))[0]
|
|
truncated_int = ((int_cast >> (52 - sbits)) << (52 - sbits))
|
|
truncated = float(struct.unpack(">d", struct.pack(">q", truncated_int))[0])
|
|
|
|
return truncated</code></pre>
|
|
</details>
|
|
</dd>
|
|
</dl>
|
|
</section>
|
|
<section>
|
|
</section>
|
|
</article>
|
|
<nav id="sidebar">
|
|
<h1>Index</h1>
|
|
<div class="toc">
|
|
<ul></ul>
|
|
</div>
|
|
<ul id="index">
|
|
<li><h3><a href="#header-functions">Functions</a></h3>
|
|
<ul class="two-column">
|
|
<li><code><a title="lantern.as_z3" href="#lantern.as_z3">as_z3</a></code></li>
|
|
<li><code><a title="lantern.const_vector" href="#lantern.const_vector">const_vector</a></code></li>
|
|
<li><code><a title="lantern.encode_hardtanh" href="#lantern.encode_hardtanh">encode_hardtanh</a></code></li>
|
|
<li><code><a title="lantern.encode_linear" href="#lantern.encode_linear">encode_linear</a></code></li>
|
|
<li><code><a title="lantern.encode_relu" href="#lantern.encode_relu">encode_relu</a></code></li>
|
|
<li><code><a title="lantern.hacky_sum" href="#lantern.hacky_sum">hacky_sum</a></code></li>
|
|
<li><code><a title="lantern.round_model" href="#lantern.round_model">round_model</a></code></li>
|
|
<li><code><a title="lantern.truncate_double" href="#lantern.truncate_double">truncate_double</a></code></li>
|
|
</ul>
|
|
</li>
|
|
</ul>
|
|
</nav>
|
|
</main>
|
|
<footer id="footer">
|
|
<p>Generated by <a href="https://pdoc3.github.io/pdoc"><cite>pdoc</cite> 0.8.1</a>.</p>
|
|
</footer>
|
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/highlight.min.js"></script>
|
|
<script>hljs.initHighlightingOnLoad()</script>
|
|
</body>
|
|
</html> |