stream blobs directly from file to the wire

enforce same cwk/civ inputs for c4 and c6
catch panics and terminate session gracefully
only input labels should be encrypted to disk
move ghash, key_manager, metadata into a package
mark with references to steps in the write-up
themighty1
2022-01-20 10:13:42 +03:00
parent 33bc276031
commit f7f0ad5a24
15 changed files with 1815 additions and 1854 deletions

README

@@ -8,7 +8,7 @@ It is primarily intended to be run inside a sandboxed AWS EC2 instance (https://
git clone --recurse-submodules https://github.com/tlsnotary/server
3. Compile:
cd server
cd server/src
go mod init notary
go get github.com/bwesterb/go-ristretto@b51b4774df9150ea7d7616f76e77f745a464bbe3
go get github.com/roasbeef/go-go-gadget-paillier@14f1f86b60008ece97b6233ed246373e555fc79f


@@ -1,369 +1,123 @@
package evaluator
import (
"bytes"
"encoding/binary"
"log"
"math"
"math/rand"
"notary/garbler"
"notary/meta"
u "notary/utils"
"time"
"github.com/bwesterb/go-ristretto"
)
type Evaluator struct {
g *garbler.Garbler
// fixed inputs for each circuit (circuit count starts at 1)
FixedInputs [][]int
// OT for bits 0/1 format: ({k:[]byte, B:[]byte})
OT0 []OTmap
OT1 []OTmap
A []byte // client-garbler's A
fixedLabels [][][]byte
OTFixedK [][]byte
// the total amount of c6 circuit executions for this session
C6Count int
// all circuits, count starts with 1 to avoid confusion
// they are meant to be read-only for evaluator
meta []*meta.Circuit
ttBlobs [][]byte // truth table blobs for each circuit
olBlobs [][]byte // output labels blobs for each circuit
nonFixedOTBits [][]OTmap
Salt [][]byte // commitment salt for each circuit
CommitHash [][]byte // hash of output for each circuit
}
type OTmap struct {
K []byte
B []byte
idx int
func (e *Evaluator) Init(circuits []*meta.Circuit, c6Count int) {
e.C6Count = c6Count
e.meta = circuits
e.ttBlobs = make([][]byte, len(e.meta))
e.olBlobs = make([][]byte, len(e.meta))
}
func (e *Evaluator) Init(g *garbler.Garbler) {
e.g = g
e.FixedInputs = make([][]int, len(g.Cs))
e.fixedLabels = make([][][]byte, len(g.Cs))
e.ttBlobs = make([][]byte, len(g.Cs))
e.olBlobs = make([][]byte, len(g.Cs))
e.Salt = make([][]byte, len(g.Cs))
e.CommitHash = make([][]byte, len(g.Cs))
e.nonFixedOTBits = make([][]OTmap, len(g.Cs))
}
// SetFixedInputs is called after we know the number of c6 circuit executions.
// Consult the .casm file for each circuit for an explanation of what each mask does.
func (e *Evaluator) SetFixedInputs() {
for i := 1; i < len(e.g.Cs); i++ {
c := e.g.Cs[i]
if i == 1 {
e.FixedInputs[1] = u.BytesToBits(c.Masks[1])
log.Println("e.FixedInputs[1] ", len(e.FixedInputs[1]))
}
if i == 2 {
e.FixedInputs[2] = u.BytesToBits(c.Masks[1])
log.Println("e.FixedInputs[2] ", len(e.FixedInputs[2]))
}
if i == 3 {
var allMasks []byte
allMasks = append(allMasks, c.Masks[6]...)
allMasks = append(allMasks, c.Masks[5]...)
allMasks = append(allMasks, c.Masks[4]...)
allMasks = append(allMasks, c.Masks[3]...)
allMasks = append(allMasks, c.Masks[2]...)
allMasks = append(allMasks, c.Masks[1]...)
e.FixedInputs[3] = u.BytesToBits(allMasks)
log.Println("e.FixedInputs[3] ", len(e.FixedInputs[3]))
}
if i == 4 {
var allMasks []byte
allMasks = append(allMasks, c.Masks[2]...)
allMasks = append(allMasks, c.Masks[1]...)
e.FixedInputs[4] = u.BytesToBits(allMasks)
log.Println("e.FixedInputs[4] ", len(e.FixedInputs[4]))
}
if i == 5 {
var allMasks []byte
allMasks = append(allMasks, e.g.Cs[3].Masks[4]...) // civ mask
allMasks = append(allMasks, e.g.Cs[3].Masks[2]...) // cwk mask
e.FixedInputs[5] = u.BytesToBits(allMasks)
log.Println("e.FixedInputs[5] ", len(e.FixedInputs[5]))
}
if i == 6 {
var allMasks []byte
for i := e.g.C6Count; i > 0; i-- {
allMasks = append(allMasks, e.g.Cs[6].Masks[i]...)
}
allMasks = append(allMasks, e.g.Cs[3].Masks[4]...) // civ mask
allMasks = append(allMasks, e.g.Cs[3].Masks[2]...) // cwk mask
e.FixedInputs[6] = u.BytesToBits(allMasks)
log.Println("e.FixedInputs[6] ", len(e.FixedInputs[6]))
}
}
}
// client's A for OT must be available at this point
func (e *Evaluator) PreComputeOT() []byte {
var allFixedInputs []int
allNonFixedInputsSize := 0
for i := 1; i < len(e.g.Cs); i++ {
allFixedInputs = append(allFixedInputs, e.FixedInputs[i]...)
allNonFixedInputsSize += e.g.Cs[i].NotaryNonFixedInputSize
}
log.Println("len(allFixedInputs)", len(allFixedInputs))
log.Println("allNonFixedInputsSize", allNonFixedInputsSize)
var buf [32]byte
copy(buf[:], e.A[:])
A := new(ristretto.Point)
A.SetBytes(&buf)
e.OTFixedK = nil
var OTFixedB [][]byte
for i := 0; i < len(allFixedInputs); i++ {
bit := allFixedInputs[i]
b := new(ristretto.Scalar).Rand()
B := new(ristretto.Point).ScalarMultBase(b)
if bit == 1 {
B = new(ristretto.Point).Add(A, B)
}
k := u.Generichash(16, new(ristretto.Point).ScalarMult(A, b).Bytes())
e.OTFixedK = append(e.OTFixedK, k)
OTFixedB = append(OTFixedB, B.Bytes())
}
// we prepare OT for ~60% of 1s and ~60% of 0s (plus a fixed safety margin) of all
// non-fixed inputs, because we don't know in advance exactly how many 1s and 0s
// the non-fixed inputs will contain
e.OT0 = nil
e.OT1 = nil
for i := 0; i < int(math.Ceil(float64(allNonFixedInputsSize/2)*1.2))+3000; i++ {
b := new(ristretto.Scalar).Rand()
B := new(ristretto.Point).ScalarMultBase(b)
k := u.Generichash(16, new(ristretto.Point).ScalarMult(A, b).Bytes())
var m OTmap
m.K = k
m.B = B.Bytes()
e.OT0 = append(e.OT0, m)
}
for i := 0; i < int(math.Ceil(float64(allNonFixedInputsSize/2)*1.2))+3000; i++ {
b := new(ristretto.Scalar).Rand()
B := new(ristretto.Point).ScalarMultBase(b)
B = new(ristretto.Point).Add(A, B)
k := u.Generichash(16, new(ristretto.Point).ScalarMult(A, b).Bytes())
var m OTmap
m.K = k
m.B = B.Bytes()
e.OT1 = append(e.OT1, m)
}
log.Println("e.OT0/1 len is", len(e.OT0), len(e.OT1))
//send remaining OT in random sequence but remember the index in that sequence.
var OTNonFixedToSend []byte = nil
allOTLen := len(e.OT0) + len(e.OT1)
var idxSeen []int
for i := 0; i < allOTLen; i++ {
var ot *[]OTmap
rand.Seed(time.Now().UnixNano())
randIdx := rand.Intn(allOTLen)
if isIntInArray(randIdx, idxSeen) {
// this index was already seen, try again
i--
continue
}
idxSeen = append(idxSeen, randIdx)
if randIdx >= len(e.OT0) {
ot = &e.OT1
// adjust the index to become an OT1 index
randIdx = randIdx - len(e.OT0)
} else {
ot = &e.OT0
}
(*ot)[randIdx].idx = i
OTNonFixedToSend = append(OTNonFixedToSend, (*ot)[randIdx].B...)
}
var payload []byte
for i := 0; i < len(OTFixedB); i++ {
payload = append(payload, OTFixedB[i]...)
}
payload = append(payload, OTNonFixedToSend...)
log.Println("returning payload for garbler, size ", len(payload))
return payload
}
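
The loop above is the receiver side of a "simplest OT"-style protocol: for each choice bit the notary sends B = b*G (bit 0) or B = A + b*G (bit 1) and derives k = H(b*A). The sender side lives in the client code and is not part of this diff; below is a minimal, purely illustrative sketch of how a sender would derive both keys, assuming the same go-ristretto primitives and a SHA-256 hash in place of u.Generichash (the function name and KDF are assumptions, not the client's actual code):

package otsketch

import (
	"crypto/sha256"

	"github.com/bwesterb/go-ristretto"
)

// otSenderKeys derives the two OT keys on the sender side from the sender's
// secret scalar a (A = a*G was already sent to the receiver) and the receiver's
// point B. k0 decrypts the message for choice bit 0, k1 the one for bit 1.
func otSenderKeys(a *ristretto.Scalar, A *ristretto.Point, bBytes [32]byte) (k0, k1 []byte) {
	B := new(ristretto.Point)
	if !B.SetBytes(&bBytes) {
		panic("invalid receiver point")
	}
	// k0 = H(a*B) equals the receiver's H(b*A) when B = b*G
	h0 := sha256.Sum256(new(ristretto.Point).ScalarMult(B, a).Bytes())
	// k1 = H(a*(B-A)) equals H(b*A) when B = A + b*G
	BminusA := new(ristretto.Point).Sub(B, A)
	h1 := sha256.Sum256(new(ristretto.Point).ScalarMult(BminusA, a).Bytes())
	return h0[:16], h1[:16]
}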
func (e *Evaluator) SetA(A []byte) {
e.A = A
}
func (e *Evaluator) ProcessEncryptedLabels(labelsBlob []byte) {
allFICount := 0 //count of all fixed inputs from all circuits
for i := 1; i < len(e.g.Cs); i++ {
allFICount += len(e.FixedInputs[i])
}
if len(labelsBlob) != allFICount*32 {
log.Println(len(labelsBlob), allFICount)
panic("len(labelsBlob) != allFICount*32")
}
idx := 0
for i := 1; i < len(e.g.Cs); i++ {
e.fixedLabels[i] = make([][]byte, len(e.FixedInputs[i]))
for j := 0; j < len(e.FixedInputs[i]); j++ {
bit := e.FixedInputs[i][j]
if bit != 0 && bit != 1 {
panic("bit != 0 || bit != 1")
}
e_ := labelsBlob[idx*32+16*bit : idx*32+16*bit+16]
inputLabel := u.Decrypt_generic(e_, e.OTFixedK[idx], 0)
idx += 1
e.fixedLabels[i][j] = inputLabel
}
}
}
func (e *Evaluator) SetBlob(blob []byte) {
offset := 0
for i := 1; i < len(e.g.Cs); i++ {
ttLen := e.g.Cs[i].Circuit.AndGateCount * 48
olLen := e.g.Cs[i].Circuit.OutputSize * 32
if i == 5 {
ttLen = e.g.C5Count * ttLen
olLen = e.g.C5Count * olLen
}
if i == 6 {
ttLen = e.g.C6Count * ttLen
olLen = e.g.C6Count * olLen
}
e.ttBlobs[i] = blob[offset : offset+ttLen]
offset += ttLen
e.olBlobs[i] = blob[offset : offset+olLen]
offset += olLen
}
if len(blob) != offset {
panic("len(blob) != offset")
}
}
func (e *Evaluator) GetNonFixedIndexes(cNo int) []byte {
c := &e.g.Cs[cNo]
nonFixedBits := c.InputBits[:c.NotaryNonFixedInputSize]
//get OT indexes for bits in the non-fixed inputs
idxArray, otArray := e.DoGetNonFixedIndexes(nonFixedBits)
e.nonFixedOTBits[cNo] = otArray
return idxArray
}
// return indexes from the OT pool as well as OTmap for each OT
func (e *Evaluator) DoGetNonFixedIndexes(bits []int) ([]byte, []OTmap) {
var idxArray []byte //flat array of 2-byte indexes
otArray := make([]OTmap, len(bits))
for i := 0; i < len(bits); i++ {
bit := bits[i]
if bit == 0 {
// take element from the end of slice and shrink slice
ot0 := e.OT0[len(e.OT0)-1]
e.OT0 = e.OT0[:len(e.OT0)-1]
idx := make([]byte, 2)
binary.BigEndian.PutUint16(idx, uint16(ot0.idx))
idxArray = append(idxArray, idx...)
otArray[i] = ot0
} else {
// take element from the end of slice and shrink slice
ot1 := e.OT1[len(e.OT1)-1]
e.OT1 = e.OT1[:len(e.OT1)-1]
idx := make([]byte, 2)
binary.BigEndian.PutUint16(idx, uint16(ot1.idx))
idxArray = append(idxArray, idx...)
otArray[i] = ot1
}
}
if len(e.OT0) < 1 || len(e.OT1) < 1 {
panic("len(e.OT0) < 1 || len(e.OT1) < 1")
}
return idxArray, otArray
}
func (e *Evaluator) Evaluate(cNo int, notaryLabelsBlob, clientLabelsBlob, ttBlob, olBlob []byte) []byte {
type batchType struct {
ga *[][]byte
// Evaluate evaluates a circuit number cNo
func (e *Evaluator) Evaluate(cNo int, notaryLabels, clientLabels,
truthTables, decodingTable []byte) []byte {
type batch_t struct {
// wl is wire labels
wl *[][]byte
// tt is truth tables
tt *[]byte
// dt is decoding table
dt *[]byte
}
c := &e.g.Cs[cNo]
nlBatch := u.SplitIntoChunks(notaryLabelsBlob, c.NotaryInputSize*16)
clBatch := u.SplitIntoChunks(clientLabelsBlob, c.ClientInputSize*16)
ttBatch := u.SplitIntoChunks(ttBlob, c.Circuit.AndGateCount*48)
c := (e.meta)[cNo]
// split into a batch for multiple executions
nlBatch := u.SplitIntoChunks(notaryLabels, c.NotaryInputSize*16)
clBatch := u.SplitIntoChunks(clientLabels, c.ClientInputSize*16)
ttBatch := u.SplitIntoChunks(truthTables, c.AndGateCount*48)
dtBatch := u.SplitIntoChunks(decodingTable, int(math.Ceil(float64(c.OutputSize)/8)))
// exeCount is how many executions of this circuit we need
exeCount := []int{0, 1, 1, 1, 1, e.g.C5Count, 1, e.g.C6Count}[cNo]
batch := make([]batchType, exeCount)
exeCount := []int{0, 1, 1, 1, 1, 1, e.C6Count, 1}[cNo]
batch := make([]batch_t, exeCount)
for r := 0; r < exeCount; r++ {
// put all labels into garbling assignment
ga := make([][]byte, c.Circuit.WireCount)
copy(ga, u.SplitIntoChunks(u.Concat(nlBatch[r], clBatch[r]), 16))
batch[r] = batchType{&ga, &ttBatch[r]}
}
batchOutputLabels := make([][][]byte, exeCount)
for r := 0; r < exeCount; r++ {
evaluate(c.Circuit, batch[r].ga, batch[r].tt)
outputLabels := (*batch[r].ga)[len((*batch[r].ga))-c.Circuit.OutputSize:]
batchOutputLabels[r] = outputLabels
// put all input labels into wire labels
wireLabels := make([][]byte, c.WireCount)
copy(wireLabels, u.SplitIntoChunks(u.Concat(nlBatch[r], clBatch[r]), 16))
batch[r] = batch_t{&wireLabels, &ttBatch[r], &dtBatch[r]}
}
var output []byte
for r := 0; r < exeCount; r++ {
outputLabels := batchOutputLabels[r]
outBits := make([]int, c.Circuit.OutputSize)
outputSizeBytes := c.Circuit.OutputSize * 32
allOutputLabelsBlob := olBlob[r*outputSizeBytes : (r+1)*outputSizeBytes]
for i := 0; i < len(outBits); i++ {
out := outputLabels[i]
if bytes.Equal(out, allOutputLabelsBlob[i*32:i*32+16]) {
outBits[i] = 0
} else if bytes.Equal(out, allOutputLabelsBlob[i*32+16:i*32+32]) {
outBits[i] = 1
} else {
log.Println("incorrect output label")
}
}
outBytes := u.BitsToBytes(outBits)
plaintext := evaluate(c, batch[r].wl, batch[r].tt, batch[r].dt)
// plaintext is padded in the MSB to a multiple of 8 bits. We
// decompose it into bits and drop the padding
outBits := u.BytesToBits(plaintext)[0:c.OutputSize]
// rearrange the output bits so that the output values are placed in
// the same order as they appear in the *.casm files
outBytes := e.parseOutputBits(cNo, outBits)
output = append(output, outBytes...)
}
c.Output = output
e.CommitHash[cNo] = u.Sha256(c.Output)
e.Salt[cNo] = u.GetRandom(32)
return u.Sha256(u.Concat(e.CommitHash[cNo], e.Salt[cNo]))
return output
}
func evaluate(c *garbler.Circuit, garbledAssignment *[][]byte, tt *[]byte) {
// parseOutputBits converts the output bits of the circuit into a flat slice
// of bytes so that output values are in the same order as they appear in the *.casm files
func (e *Evaluator) parseOutputBits(cNo int, outBits []int) []byte {
o := 0 // offset
var outBytes []byte
for _, v := range (e.meta)[cNo].OutputsSizes {
output := u.BitsToBytes(outBits[o : o+v])
outBytes = append(outBytes, output...)
o += v
}
if o != (e.meta)[cNo].OutputSize {
panic("o != e.g.Cs[cNo].OutputSize")
}
return outBytes
}
func evaluate(c *meta.Circuit, wireLabels *[][]byte, truthTables *[]byte,
decodingTable *[]byte) []byte {
andGateIdx := 0
// gate type XOR==0 AND==1 INV==2
for i := 0; i < len(c.Gates); i++ {
g := c.Gates[i]
if g.Operation == 1 {
evaluateAnd(g, garbledAssignment, tt, andGateIdx)
evaluateAnd(g, wireLabels, truthTables, andGateIdx)
andGateIdx += 1
} else if g.Operation == 0 {
evaluateXor(g, garbledAssignment)
evaluateXor(g, wireLabels)
} else if g.Operation == 2 {
evaluateInv(g, garbledAssignment)
evaluateInv(g, wireLabels)
} else {
panic("Unknown gate")
}
}
// decode output labels: take the point bit (LSB) of each computed output label
// and XOR it with the garbler's decoding table (the LSBs of label0) to recover the plaintext bits
outLSBs := make([]int, c.OutputSize)
for i := 0; i < c.OutputSize; i++ {
outLSBs[i] = int((*wireLabels)[c.WireCount-c.OutputSize+i][15]) & 1
}
encodings := u.BitsToBytes(outLSBs)
plaintext := u.XorBytes(*decodingTable, encodings)
return plaintext
}
func evaluateAnd(g garbler.Gate, ga *[][]byte, tt *[]byte, andGateIdx int) {
func evaluateAnd(g meta.Gate, wireLabels *[][]byte, truthTables *[]byte, andGateIdx int) {
// get wire numbers
in1 := g.InputWires[0]
in2 := g.InputWires[1]
out := g.OutputWire
label1 := (*ga)[in1]
label2 := (*ga)[in2]
label1 := (*wireLabels)[in1]
label2 := (*wireLabels)[in2]
var cipher []byte
point := 2*getPoint(label1) + getPoint(label2)
@@ -373,45 +127,24 @@ func evaluateAnd(g garbler.Gate, ga *[][]byte, tt *[]byte, andGateIdx int) {
cipher = make([]byte, 16)
} else {
offset := andGateIdx*48 + 16*point
cipher = (*tt)[offset : offset+16]
cipher = (*truthTables)[offset : offset+16]
}
(*ga)[out] = u.Decrypt(label1, label2, g.Id, cipher)
(*wireLabels)[out] = u.Decrypt(label1, label2, g.Id, cipher)
}
func evaluateXor(g garbler.Gate, ga *[][]byte) {
func evaluateXor(g meta.Gate, wireLabels *[][]byte) {
in1 := g.InputWires[0]
in2 := g.InputWires[1]
out := g.OutputWire
(*ga)[out] = xorBytes((*ga)[in1], (*ga)[in2])
(*wireLabels)[out] = u.XorBytes((*wireLabels)[in1], (*wireLabels)[in2])
}
func evaluateInv(g garbler.Gate, ga *[][]byte) {
func evaluateInv(g meta.Gate, wireLabels *[][]byte) {
in1 := g.InputWires[0]
out := g.OutputWire
(*ga)[out] = (*ga)[in1]
(*wireLabels)[out] = (*wireLabels)[in1]
}
func getPoint(arr []byte) int {
return int(arr[15]) & 0x01
}
func xorBytes(a, b []byte) []byte {
if len(a) != len(b) {
panic("len(a) != len(b)")
}
c := make([]byte, len(a))
for i := 0; i < len(a); i++ {
c[i] = a[i] ^ b[i]
}
return c
}
func isIntInArray(a int, arr []int) bool {
for _, b := range arr {
if b == a {
return true
}
}
return false
}


@@ -1,35 +1,45 @@
package garbled_pool
import (
"encoding/binary"
"io/ioutil"
"log"
"notary/garbler"
"notary/meta"
u "notary/utils"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
)
// gc describes a garbled circuit file
// id is the name of the file
// keyIdx is the index of a key in g.key used to encrypt the gc
// keyIdx is the index of a key in g.keys used to encrypt this gc
type gc struct {
id string
keyIdx int
}
// Blob is what is returned when gc is read from disk
type Blob struct {
Il *[]byte
// we don't return the bytes of tt and dt because we will be streaming those files
// directly into the HTTP response to save memory
TtFile *os.File
DtFile *os.File
}
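
Per the commit message, truth tables and decoding tables are now streamed from file to the wire instead of being buffered. A minimal sketch (assuming a plain net/http handler; the real session code that consumes Blob is not shown in this hunk) of streaming one of these file handles into the HTTP response and cleaning up afterwards:

package blobsketch

import (
	"io"
	"net/http"
	"os"
)

// streamFile copies the file straight into the HTTP response writer, then
// closes and deletes it. io.Copy uses a small reusable buffer internally, so
// the truth tables are never held in memory in full.
func streamFile(w http.ResponseWriter, f *os.File) error {
	defer func() {
		name := f.Name()
		f.Close()
		os.Remove(name) // the receiver of the handle owns the file and must delete it
	}()
	w.Header().Set("Content-Type", "application/octet-stream")
	_, err := io.Copy(w, f)
	return err
}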
type GarbledPool struct {
// gPDirPath is full path to the garbled pool dir
gPDirPath string
// AES-GCM keys to encrypt/authenticate garbled circuits
// we need to encrypt them in case we want to store them outside the enclave
// when the encryption key changes, older keys are kept because we still
// have gc on disk encrypted with old keys
// keysCleanup sets old keys which are not used anymore to nil, thus releasing
// the memory
// AES-GCM keys to encrypt/authenticate circuits' labels.
// We need to encrypt them in case we want to store them outside the enclave.
// When the encryption key changes, older keys are kept because we still
// have labels on disk encrypted with old keys.
// monitor() sets old keys which are not used anymore to nil, thus releasing
// the memory.
keys [][]byte
// key is the current key in use. It is always keys[len(keys)-1]
key []byte
@@ -38,23 +48,16 @@ type GarbledPool struct {
encryptedSoFar int
// we change key after rekeyAfter bytes were encrypted
rekeyAfter int
// c5 subdirs' names are "50, 100, 150 ..." indicating how many garblings of
// a circuit there are in the dir
c5subdirs []string
// pool contains all non-c5 circuits
// pool contains metadata of all circuits. key is circuit number.
pool map[string][]gc
// poolc5 is like pool except map's <key> is one of g.c5subdirs and gc.id
// is a dir containing <key> amount of garblings
poolc5 map[string][]gc
// poolSize is how many concurrent TLSNotary sessions we want to support
// the server will maintain a pool of garbled circuits depending on this value
// the amount of c5 circuits will be poolSize*100 because on average one
// session needs that many garbled c5 circuits
poolSize int
Circuits []*garbler.Circuit
// Circuits' count starts from 1
Circuits []*meta.Circuit
grb garbler.Garbler
// all circuits, count starts with 1 to avoid confusion
Cs []garbler.CData
// noSandbox is set to true when not running in a sandboxed environment
noSandbox bool
sync.Mutex
@@ -65,29 +68,22 @@ func (g *GarbledPool) Init(noSandbox bool) {
g.encryptedSoFar = 0
g.rekeyAfter = 1024 * 1024 * 1024 * 64 // 64GB
g.poolSize = 1
g.pool = make(map[string][]gc, 6)
for _, v := range []string{"1", "2", "3", "4", "5", "6"} {
g.pool = make(map[string][]gc, 7)
for _, v := range []string{"1", "2", "3", "4", "5", "6", "7"} {
g.pool[v] = []gc{}
}
g.Circuits = make([]*garbler.Circuit, 7)
for _, idx := range []int{1, 2, 3, 4, 5, 6} {
g.Circuits[idx] = g.grb.ParseCircuit(idx)
g.Circuits = make([]*meta.Circuit, 8)
for _, idx := range []int{1, 2, 3, 4, 5, 6, 7} {
g.Circuits[idx] = g.parseCircuit(idx)
g.Circuits[idx].OutputsSizes = meta.GetOutputSizes(idx)
}
g.Cs = make([]garbler.CData, 7)
g.Cs[1].Init(512, 512, 512)
g.Cs[2].Init(512, 640, 512)
g.Cs[3].Init(832, 1568, 800)
g.Cs[4].Init(672, 960, 480)
g.Cs[5].Init(160, 308, 128)
g.Cs[6].Init(288, 304, 128)
curDir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
panic(err)
}
g.gPDirPath = filepath.Join(filepath.Dir(curDir), "garbledPool")
if g.noSandbox {
g.key = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6}
} else {
if !g.noSandbox {
// running in an enclave, need to encrypt input labels
g.key = u.GetRandom(16)
}
g.keys = append(g.keys, g.key)
@@ -98,7 +94,7 @@ func (g *GarbledPool) Init(noSandbox bool) {
if err != nil {
panic(err)
}
for _, idx := range []string{"1", "2", "3", "4", "5", "6"} {
for _, idx := range []string{"1", "2", "3", "4", "5", "6", "7"} {
err = os.Mkdir(filepath.Join(g.gPDirPath, "c"+idx), 0755)
if err != nil {
panic(err)
@@ -115,19 +111,19 @@ func (g *GarbledPool) Init(noSandbox bool) {
go g.monitor()
}
// returns Blobs struct for each circuit
func (g *GarbledPool) GetBlobs(c5Count int) []garbler.Blobs {
if c5Count > 1026 {
panic("c5Count > 1026")
// returns 1 garbling of each circuit and c6Count garblings for circuit 6
func (g *GarbledPool) GetBlobs(c6Count int) []Blob {
if c6Count > 1026 {
panic("c6Count > 1026")
}
allBlobs := make([]garbler.Blobs, len(g.Cs))
var allBlobs []Blob
// fetch blobs
for i := 1; i < len(allBlobs); i++ {
for i := 1; i < len(g.Circuits); i++ {
iStr := strconv.Itoa(i)
var count int
if i == 5 {
count = c5Count
if i == 6 {
count = c6Count
} else {
count = 1
}
@@ -144,54 +140,30 @@ func (g *GarbledPool) GetBlobs(c5Count int) []garbler.Blobs {
g.pool[iStr] = g.pool[iStr][1:]
g.Unlock()
blob := g.fetchBlob(iStr, gc)
il, tt, ol := g.deBlob(blob)
allBlobs[i].Il = append(allBlobs[i].Il, il...)
allBlobs[i].Tt = append(allBlobs[i].Tt, tt...)
allBlobs[i].Ol = append(allBlobs[i].Ol, ol...)
allBlobs = append(allBlobs, blob)
}
}
return allBlobs
}
func (g *GarbledPool) loadPoolFromDisk() {
for _, idx := range []string{"1", "2", "3", "4", "5", "6"} {
for _, idx := range []string{"1", "2", "3", "4", "5", "6", "7"} {
files, err := ioutil.ReadDir(filepath.Join(g.gPDirPath, "c"+idx))
if err != nil {
panic(err)
}
var gcs []gc
for _, file := range files {
gcs = append(gcs, gc{id: file.Name(), keyIdx: 0})
if strings.HasSuffix(file.Name(), "_il") {
nameNoSuffix := strings.Split(file.Name(), "_")[0]
gcs = append(gcs, gc{id: nameNoSuffix, keyIdx: 0})
}
}
g.pool[idx] = gcs
log.Println("loaded ", len(g.pool[idx]), " garbled circuits for circuit ", idx)
}
}
// garbles a circuit and returns a blob
func (g *GarbledPool) garbleCircuit(cNo int) []byte {
tt, il, ol, _ := g.grb.OfflinePhase(g.grb.ParseCircuit(cNo), nil, nil, nil)
return g.makeBlob(il, tt, ol)
}
// garbles a batch of count c5 circuits and return the garbled blobs
func (g *GarbledPool) garbleC5Circuits(count int) [][]byte {
var blobs [][]byte
tt, il, ol, R := g.grb.OfflinePhase(g.Circuits[5], nil, nil, nil)
labels := g.grb.SeparateLabels(il, g.Cs[5])
blobs = append(blobs, g.makeBlob(il, tt, ol))
// for all other circuits we only need ClientFixed input labels
ilReused := u.Concat(labels.NotaryFixed, labels.ClientNonFixed)
reuseIndexes := u.ExpandRange(0, 320)
for i := 2; i <= count; i++ {
tt, il, ol, _ := g.grb.OfflinePhase(g.Circuits[5], R, ilReused, reuseIndexes)
labels := g.grb.SeparateLabels(il, g.Cs[5])
blobs = append(blobs, g.makeBlob(labels.ClientFixed, tt, ol))
}
return blobs
}
// monitor replenishes the garbled pool when needed
// and re-keys the encryption key
func (g *GarbledPool) monitor() {
@@ -230,7 +202,7 @@ func (g *GarbledPool) monitor() {
var k string
var v []gc
for k, v = range g.pool {
if k != "5" {
if k != "6" {
if len(v) >= g.poolSize {
continue
} else {
@@ -238,7 +210,8 @@ func (g *GarbledPool) monitor() {
break
}
} else {
// we need at least 1026 garblings for a max TLS record size
// for circuit 6 we need at least 1026 garblings for a max possible
// TLS record size of 16KB
max := u.Max(g.poolSize*100, 1026)
if len(v) >= max {
continue
@@ -249,15 +222,14 @@ func (g *GarbledPool) monitor() {
}
}
// golang doesn't allow modifying a map while iterating over it
// that's why we break the iteration
// that's why we broke the iteration and got here
if diff > 0 {
// need to replenish the pool
for i := 0; i < diff; i++ {
//log.Println("in monitorPool adding c", k)
kInt, _ := strconv.Atoi(k)
blob := g.garbleCircuit(kInt)
il, tt, dt := g.grb.Garble(g.Circuits[kInt])
randName := u.RandString()
g.saveBlob(filepath.Join(g.gPDirPath, "c"+k, randName), blob)
g.saveBlob(filepath.Join(g.gPDirPath, "c"+k, randName), il, tt, dt)
g.Lock()
g.pool[k] = append(g.pool[k], gc{id: randName, keyIdx: len(g.keys) - 1})
g.Unlock()
@@ -270,87 +242,115 @@ func (g *GarbledPool) monitor() {
}
}
// packs data into a blob with length prefixes
func (g *GarbledPool) makeBlob(il []byte, tt *[]byte, ol []byte) []byte {
ilSize := make([]byte, 4)
binary.BigEndian.PutUint32(ilSize, uint32(len(il)))
ttSize := make([]byte, 4)
binary.BigEndian.PutUint32(ttSize, uint32(len(*tt)))
olSize := make([]byte, 4)
binary.BigEndian.PutUint32(olSize, uint32(len(ol)))
return u.Concat(ilSize, il, ttSize, *tt, olSize, ol)
func (g *GarbledPool) saveBlob(path string, il *[]byte, tt *[]byte, dt *[]byte) {
var ilToWrite *[]byte
// we only encrypt input labels
if !g.noSandbox {
ilEnc := u.AESGCMencrypt(g.key, *il)
ilToWrite = &ilEnc
} else {
ilToWrite = il
}
func (g *GarbledPool) deBlob(blob []byte) ([]byte, []byte, []byte) {
offset := 0
ilSize := int(binary.BigEndian.Uint32(blob[offset : offset+4]))
offset += 4
il := blob[offset : offset+ilSize]
offset += ilSize
ttSize := int(binary.BigEndian.Uint32(blob[offset : offset+4]))
offset += 4
tt := blob[offset : offset+ttSize]
offset += ttSize
olSize := int(binary.BigEndian.Uint32(blob[offset : offset+4]))
offset += 4
ol := blob[offset : offset+olSize]
return il, tt, ol
err := os.WriteFile(path+"_il", *ilToWrite, 0644)
if err != nil {
panic(err)
}
func (g *GarbledPool) saveBlob(path string, blob []byte) {
enc := u.AESGCMencrypt(g.key, blob)
g.encryptedSoFar += len(blob)
err := os.WriteFile(path, enc, 0644)
err = os.WriteFile(path+"_tt", *tt, 0644)
if err != nil {
panic(err)
}
err = os.WriteFile(path+"_dt", *dt, 0644)
if err != nil {
panic(err)
}
}
// fetches the blob from disk and deletes it
func (g *GarbledPool) fetchBlob(circuitNo string, c gc) []byte {
func (g *GarbledPool) fetchBlob(circuitNo string, c gc) Blob {
fullPath := filepath.Join(g.gPDirPath, "c"+circuitNo, c.id)
data, err := os.ReadFile(fullPath)
il, err := os.ReadFile(fullPath + "_il")
if err != nil {
panic(err)
}
err = os.Remove(fullPath)
err = os.Remove(fullPath + "_il")
if err != nil {
panic(err)
}
return u.AESGCMdecrypt(g.keys[c.keyIdx], data)
// only the file handles of the truth tables and the decoding table are returned,
// so that the files can be streamed (avoiding a full copy into memory).
// The session which receives these handles is responsible for
// deleting the files
ttFile, err3 := os.Open(fullPath + "_tt")
if err3 != nil {
panic(err3)
}
dtFile, err4 := os.Open(fullPath + "_dt")
if err4 != nil {
panic(err4)
}
var ilToReturn = &il
if !g.noSandbox {
// decrypt data from disk when in a sandbox
ilDec := u.AESGCMdecrypt(g.keys[c.keyIdx], il)
ilToReturn = &ilDec
}
return Blob{ilToReturn, ttFile, dtFile}
}
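
u.AESGCMencrypt and u.AESGCMdecrypt are defined in notary/utils and are not part of this diff. Below is a sketch of what such helpers commonly look like with the standard library; the nonce layout (nonce prepended to the ciphertext) is an assumption for illustration, not a fact about notary/utils:

package aessketch

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
)

// aesGCMEncrypt seals plaintext with AES-GCM and prepends the random nonce,
// so that decryption can recover it from the blob itself.
func aesGCMEncrypt(key, plaintext []byte) []byte {
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		panic(err)
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		panic(err)
	}
	return gcm.Seal(nonce, nonce, plaintext, nil)
}

// aesGCMDecrypt splits off the nonce and opens the ciphertext; a failed
// authentication means the stored labels were tampered with.
func aesGCMDecrypt(key, data []byte) []byte {
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		panic(err)
	}
	nonce, ciphertext := data[:gcm.NonceSize()], data[gcm.NonceSize():]
	plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		panic(err)
	}
	return plaintext
}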
// fetches count blobs from folder and then removes it
func (g *GarbledPool) fetchC5Blobs(subdir string, c gc, count int) [][]byte {
var rawBlobs [][]byte
dirPath := filepath.Join(g.gPDirPath, "c5", subdir, c.id)
for i := 0; i < count; i++ {
iStr := strconv.Itoa(i + 1)
data, err := os.ReadFile(filepath.Join(dirPath, iStr))
// parseCircuit converts a circuit from the "Bristol fashion" format into a compact
// representation which can be loaded into RAM and processed gate-by-gate
func (g *GarbledPool) parseCircuit(cNo_ int) *meta.Circuit {
cNo := strconv.Itoa(cNo_)
curDir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
panic(err)
}
rawBlobs = append(rawBlobs, u.AESGCMdecrypt(g.keys[c.keyIdx], data))
}
err := os.RemoveAll(dirPath)
baseDir := filepath.Dir(curDir)
jiggDir := filepath.Join(baseDir, "circuits")
cBytes, err := ioutil.ReadFile(filepath.Join(jiggDir, "c"+cNo+".out"))
if err != nil {
panic(err)
}
return rawBlobs
}
text := string(cBytes)
lines := strings.Split(text, "\n")
c := meta.Circuit{}
wireCount, _ := strconv.ParseInt(strings.Split(lines[0], " ")[1], 10, 32)
gi, _ := strconv.ParseInt(strings.Split(lines[1], " ")[1], 10, 32)
ei, _ := strconv.ParseInt(strings.Split(lines[1], " ")[2], 10, 32)
out, _ := strconv.ParseInt(strings.Split(lines[2], " ")[1], 10, 32)
func (g *GarbledPool) saveC5Blobs(path string, blobs [][]byte) {
err := os.Mkdir(path, 0755)
if err != nil {
panic(err)
c.WireCount = int(wireCount)
c.NotaryInputSize = int(gi)
c.ClientInputSize = int(ei)
c.OutputSize = int(out)
gates := make([]meta.Gate, len(lines)-3)
andGateCount := 0
opBytes := map[string]byte{"XOR": 0, "AND": 1, "INV": 2}
for i, line := range lines[3:] {
items := strings.Split(line, " ")
var g meta.Gate
g.Operation = opBytes[items[len(items)-1]]
g.Id = uint32(i)
if g.Operation == 0 || g.Operation == 1 {
inp1, _ := strconv.ParseInt(items[2], 10, 32)
inp2, _ := strconv.ParseInt(items[3], 10, 32)
out, _ := strconv.ParseInt(items[4], 10, 32)
g.InputWires = []uint32{uint32(inp1), uint32(inp2)}
g.OutputWire = uint32(out)
if g.Operation == 1 {
andGateCount += 1
}
for i := 0; i < len(blobs); i++ {
fileName := strconv.Itoa(i + 1)
enc := u.AESGCMencrypt(g.key, blobs[i])
g.encryptedSoFar += len(blobs[i])
err := os.WriteFile(filepath.Join(path, fileName), enc, 0644)
if err != nil {
panic(err)
} else { // INV gate
inp1, _ := strconv.ParseInt(items[2], 10, 32)
out, _ := strconv.ParseInt(items[3], 10, 32)
g.InputWires = []uint32{uint32(inp1)}
g.OutputWire = uint32(out)
}
gates[i] = g
}
c.Gates = gates
c.AndGateCount = int(andGateCount)
return &c
}


@@ -1,335 +1,186 @@
package garbler
import (
"crypto/rand"
"io/ioutil"
"math/big"
"notary/meta"
u "notary/utils"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/bwesterb/go-ristretto"
)
// Garbler implements the role of the notary as the garbler of the binary circuit.
// This is the fixed-key-cipher garbling scheme of BHKR13:
// https://eprint.iacr.org/2013/426.pdf
type Garbler struct {
P1_vd []byte // client verify data
Server_verify_data []byte // server verify data
Server_iv, Client_iv []byte
R, One, Zero *big.Int // will be set in Preprocess, used by ghash
// the total amount of c5 circuits for this session
C5Count int
// the total amount of c6 circuits for this session
// the total amount of c6 circuit executions for this session
C6Count int
SwkMaskedByClient []byte
Ot_a *ristretto.Scalar
A *ristretto.Point
AllNonFixedOT [][][]byte
// this is the mask that we apply before sending cwk masked twice to the client
// this is done so that the client could change the mask
cwkSecondMask []byte
CwkMaskedByClient []byte //this is notary's input to c5
// all circuits, count starts with 1 to avoid confusion
Cs []CData
}
// CData is circuit's data
type CData struct {
OT []OTstruct // parsed OT
Ol []byte // output labels
Il []byte // input labels
Tt []byte // truth table
NotaryInputSize int // in bits
NotaryNonFixedInputSize int
NotaryFixedInputSize int
ClientInputSize int // in bits
ClientNonFixedInputSize int
ClientFixedInputSize int
OutputSize int // in bits
Output []byte // garbler+evaluator output of circuit
// InputBits start with least input bit at index [0]
InputBits []int // notary's input for this circuit
PmsOuterHash []byte // only for c1
MsOuterHash []byte // output from c2
Masks [][]byte
Circuit *Circuit
TagSharesBlob []byte
FixedInputs []int // array of 0 and 1 for evaluator's fixed inputs
Meta *meta.Circuit
}
func (p *CData) Init(nis, cis, os int) {
p.NotaryInputSize = nis
p.ClientInputSize = cis
p.OutputSize = os
}
type OTstruct struct {
Ot_a *ristretto.Scalar
A *ristretto.Point
Ot_b *ristretto.Scalar
B *ristretto.Point
AplusB *ristretto.Point
K *ristretto.Point
M0 []byte
M1 []byte
C int
}
type Gate struct {
Id uint32
Operation uint8
InputWires []uint32
OutputWire uint32
}
type Circuit struct {
WireCount int
GarblerInputSize int
EvaluatorInputSize int
OutputSize int
AndGateCount int
Gates []Gate
}
type Labels struct {
NotaryNonFixed []byte
NotaryFixed []byte
ClientNonFixed []byte
ClientFixed []byte
}
type Blobs struct {
Il []byte // input labels
Tt []byte // truth tables
Ol []byte // output labels
}
func (g *Garbler) Init(ilBlobs [][]byte, circuits []*Circuit) {
g.Cs = make([]CData, 7)
g.Cs[1].Init(512, 512, 512)
g.Cs[2].Init(512, 640, 512)
g.Cs[3].Init(832, 1568, 800)
g.Cs[4].Init(672, 960, 480)
g.Cs[5].Init(160, 308, 128)
g.Cs[6].Init(288, 304, 128)
// Init puts input labels into the corresponding circuits and creates masks for
// the notary's inputs to the circuits.
// ilBlobs contains slices of input labels for each execution
func (g *Garbler) Init(ilBlobs []*[]byte, circuits []*meta.Circuit, c6Count int) {
g.C6Count = c6Count
g.Cs = make([]CData, len(circuits))
for i := 1; i < len(g.Cs); i++ {
c := &g.Cs[i]
c.Il = ilBlobs[i]
c.Circuit = circuits[i]
if i < 6 {
g.Cs[i].Il = *ilBlobs[i-1]
} else if i == 6 {
g.Cs[i].Il = u.ConcatP(ilBlobs[5 : 5+c6Count]...)
} else if i > 6 {
g.Cs[i].Il = *ilBlobs[c6Count-1+i-1]
}
g.Cs[i].Meta = circuits[i]
// mask numbering starts at 1 for convenience
// consult circuits/*.casm files for what each mask does
if i == 1 {
c.Masks = make([][]byte, 2)
c.Masks[1] = u.GetRandom(32)
g.Cs[i].Masks = make([][]byte, 2)
g.Cs[i].Masks[1] = u.GetRandom(32)
}
if i == 2 {
c.Masks = make([][]byte, 2)
c.Masks[1] = u.GetRandom(32)
g.Cs[i].Masks = make([][]byte, 2)
g.Cs[i].Masks[1] = u.GetRandom(32)
}
if i == 3 {
c.Masks = make([][]byte, 7)
c.Masks[1] = u.GetRandom(16)
c.Masks[2] = u.GetRandom(16)
c.Masks[3] = u.GetRandom(4)
c.Masks[4] = u.GetRandom(4)
c.Masks[5] = u.GetRandom(16)
c.Masks[6] = u.GetRandom(16)
g.Cs[i].Masks = make([][]byte, 5)
g.Cs[i].Masks[1] = u.GetRandom(16)
g.Cs[i].Masks[2] = u.GetRandom(16)
g.Cs[i].Masks[3] = u.GetRandom(4)
g.Cs[i].Masks[4] = u.GetRandom(4)
}
if i == 4 {
c.Masks = make([][]byte, 3)
c.Masks[1] = u.GetRandom(16)
c.Masks[2] = u.GetRandom(16)
g.Cs[i].Masks = make([][]byte, 3)
g.Cs[i].Masks[1] = u.GetRandom(16)
g.Cs[i].Masks[2] = u.GetRandom(16)
}
if i == 6 {
c.Masks = make([][]byte, g.C6Count+1)
for j := 1; j < g.C6Count+1; j++ {
c.Masks[j] = u.GetRandom(16)
if i == 5 {
g.Cs[i].Masks = make([][]byte, 3)
g.Cs[i].Masks[1] = u.GetRandom(16)
g.Cs[i].Masks[2] = u.GetRandom(16)
}
if i == 7 {
g.Cs[i].Masks = make([][]byte, 2)
g.Cs[i].Masks[1] = u.GetRandom(16)
}
}
}
// PrepareA is done before Init so that we could send A to the client as soon as possible
func (g *Garbler) PrepareA() {
g.Ot_a = new(ristretto.Scalar).Rand()
g.A = new(ristretto.Point).ScalarMultBase(g.Ot_a)
}
func (g *Garbler) Ot_GetA() []byte {
return g.A.Bytes()
}
// internal method
func (g *Garbler) separateLabels(blob []byte, cNo int) Labels {
c := g.Cs[cNo]
return g.SeparateLabels(blob, c)
}
// separate one continuous blob of input labels into 4 blobs as in Labels struct
func (g *Garbler) SeparateLabels(blob []byte, c CData) Labels {
if len(blob) != (c.NotaryInputSize+c.ClientInputSize)*32 {
panic("in separateLabels")
}
var labels Labels
offset := 0
labels.NotaryNonFixed = make([]byte, c.NotaryNonFixedInputSize*32)
copy(labels.NotaryNonFixed, blob[offset:offset+c.NotaryNonFixedInputSize*32])
offset += c.NotaryNonFixedInputSize * 32
labels.NotaryFixed = make([]byte, c.NotaryFixedInputSize*32)
copy(labels.NotaryFixed, blob[offset:offset+c.NotaryFixedInputSize*32])
offset += c.NotaryFixedInputSize * 32
labels.ClientNonFixed = make([]byte, c.ClientNonFixedInputSize*32)
copy(labels.ClientNonFixed, blob[offset:offset+c.ClientNonFixedInputSize*32])
offset += c.ClientNonFixedInputSize * 32
labels.ClientFixed = make([]byte, c.ClientFixedInputSize*32)
copy(labels.ClientFixed, blob[offset:offset+c.ClientFixedInputSize*32])
offset += c.ClientFixedInputSize * 32
return labels
}
func (g *Garbler) ParseCircuit(cNo_ int) *Circuit {
cNo := strconv.Itoa(cNo_)
curDir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
panic(err)
}
baseDir := filepath.Dir(curDir)
jiggDir := filepath.Join(baseDir, "circuits")
cBytes, err := ioutil.ReadFile(filepath.Join(jiggDir, "c"+cNo+".out"))
if err != nil {
panic(err)
}
text := string(cBytes)
lines := strings.Split(text, "\n")
c := Circuit{}
wireCount, _ := strconv.ParseInt(strings.Split(lines[0], " ")[1], 10, 32)
gi, _ := strconv.ParseInt(strings.Split(lines[1], " ")[1], 10, 32)
ei, _ := strconv.ParseInt(strings.Split(lines[1], " ")[2], 10, 32)
out, _ := strconv.ParseInt(strings.Split(lines[2], " ")[1], 10, 32)
c.WireCount = int(wireCount)
c.GarblerInputSize = int(gi)
c.EvaluatorInputSize = int(ei)
c.OutputSize = int(out)
gates := make([]Gate, len(lines)-3)
andGateCount := 0
opBytes := map[string]byte{"XOR": 0, "AND": 1, "INV": 2}
for i, line := range lines[3:] {
items := strings.Split(line, " ")
var g Gate
g.Operation = opBytes[items[len(items)-1]]
g.Id = uint32(i)
if g.Operation == 0 || g.Operation == 1 {
inp1, _ := strconv.ParseInt(items[2], 10, 32)
inp2, _ := strconv.ParseInt(items[3], 10, 32)
out, _ := strconv.ParseInt(items[4], 10, 32)
g.InputWires = []uint32{uint32(inp1), uint32(inp2)}
g.OutputWire = uint32(out)
if g.Operation == 1 {
andGateCount += 1
}
} else { // INV gate
inp1, _ := strconv.ParseInt(items[2], 10, 32)
out, _ := strconv.ParseInt(items[3], 10, 32)
g.InputWires = []uint32{uint32(inp1)}
g.OutputWire = uint32(out)
}
gates[i] = g
}
c.Gates = gates
c.AndGateCount = int(andGateCount)
return &c
}
// garble a circuit and optionally reuse 1 ) R values 2) inputs with indexes
func (g *Garbler) OfflinePhase(c *Circuit, rReused []byte, inputsReused []byte, reuseIndexes []int) (*[]byte, []byte, []byte, []byte) {
var R []byte
if rReused != nil {
R = rReused
} else {
R = u.GetRandom(16)
// Garble garbles a circuit. Returns input labels, truth tables, decoding table
func (g *Garbler) Garble(c *meta.Circuit) (*[]byte, *[]byte, *[]byte) {
// R is also called the circuit's delta
R := u.GetRandom(16)
// set the last bit of R to 1 for point-and-permute
// this guarantees that 2 labels of the same wire will have the opposite last bits
R[15] = R[15] | 0x01
inputCount := c.ClientInputSize + c.NotaryInputSize
wireLabels := make([][][]byte, c.WireCount)
// put input labels into wire labels
copy(wireLabels, *generateInputLabels(inputCount, R))
// a truth table contains 3 rows of 16 bytes each
truthTables := make([]byte, c.AndGateCount*48)
garble(c, &wireLabels, &truthTables, &R)
if len(wireLabels) != c.WireCount {
panic("len(wireLabels) != c.WireCount")
}
if len(reuseIndexes) != len(inputsReused)/32 {
panic("len(reuseIndexes) != len(ilReused)/32")
}
inputCount := c.EvaluatorInputSize + c.GarblerInputSize
//garbled assignment
ga := make([][][]byte, c.WireCount)
newInputs := generateInputLabels(inputCount-len(reuseIndexes), R)
// set both new and reused labels into ga
reusedCount := 0 //how many reused inputs were already put into ga
newInputsCount := 0 //how many new inputs were already put into ga
inputLabels := make([]byte, inputCount*32)
for i := 0; i < inputCount; i++ {
if u.Contains(i, reuseIndexes) {
ga[i] = [][]byte{
inputsReused[reusedCount*32 : reusedCount*32+16],
inputsReused[reusedCount*32+16 : reusedCount*32+32]}
reusedCount += 1
} else {
ga[i] = (*newInputs)[newInputsCount]
newInputsCount += 1
copy(inputLabels[i*32:i*32+16], wireLabels[i][0])
copy(inputLabels[i*32+16:i*32+32], wireLabels[i][1])
}
}
andGateCount := c.AndGateCount
//log.Println("andGateCount is", andGateCount)
truthTable := make([]byte, andGateCount*48)
garble(c, &ga, R, &truthTable)
if len(ga) != c.WireCount {
panic("len(*ga) != c.wireCount")
}
var inputLabels []byte
for i := 0; i < inputCount; i++ {
inputLabels = append(inputLabels, ga[i][0]...)
inputLabels = append(inputLabels, ga[i][1]...)
}
var outputLabels []byte
// get decoding table: LSB of label0 for each output wire
outLSB := make([]int, c.OutputSize)
for i := 0; i < c.OutputSize; i++ {
outputLabels = append(outputLabels, ga[c.WireCount-c.OutputSize+i][0]...)
outputLabels = append(outputLabels, ga[c.WireCount-c.OutputSize+i][1]...)
outLSB[i] = int(wireLabels[c.WireCount-c.OutputSize+i][0][15]) & 1
}
return &truthTable, inputLabels, outputLabels, R
decodingTable := u.BitsToBytes(outLSB)
return &inputLabels, &truthTables, &decodingTable
}
// Client's inputs always come after the Notary's inputs in the circuit
func (g *Garbler) GetClientLabels(cNo int) []byte {
// exeCount is how many executions of this circuit we need
exeCount := []int{0, 1, 1, 1, 1, 1, g.C6Count, 1}[cNo]
c := g.Cs[cNo]
// chunkSize is the bytesize of input labels for one circuit execution
chunkSize := (c.Meta.NotaryInputSize + c.Meta.ClientInputSize) * 32
if chunkSize*exeCount != len(c.Il) {
panic("(chunkSize * exeCount != len(c.Il))")
}
var allIl []byte
for i := 0; i < exeCount; i++ {
allIl = append(allIl, c.Il[i*chunkSize+c.Meta.NotaryInputSize*32:(i+1)*chunkSize]...)
}
return allIl
}
// GetNotaryLabels returns notary's input labels for the circuit
func (g *Garbler) GetNotaryLabels(cNo int) []byte {
// exeCount is how many executions of this circuit we need
exeCount := []int{0, 1, 1, 1, 1, 1, g.C6Count, 1}[cNo]
c := g.Cs[cNo]
// chunkSize is the bytesize of input labels for one circuit execution
chunkSize := (c.Meta.NotaryInputSize + c.Meta.ClientInputSize) * 32
if chunkSize*exeCount != len(c.Il) {
panic("(chunkSize * exeCount != len(c.Il))")
}
var inputLabelBlob []byte
for i := 0; i < exeCount; i++ {
inputLabelBlob = append(inputLabelBlob,
c.Il[i*chunkSize:i*chunkSize+c.Meta.NotaryInputSize*32]...)
}
if len(inputLabelBlob) != len(c.InputBits)*32 {
panic("len(inputLabelBlob) != len(c.InputBits)*32")
}
// pick either label0 or label1 depending on our input bit
var inputLabels []byte
for i := 0; i < len(c.InputBits); i++ {
var label []byte
if c.InputBits[i] == 0 {
label = inputLabelBlob[i*32 : i*32+16]
} else {
label = inputLabelBlob[i*32+16 : i*32+32]
}
inputLabels = append(inputLabels, label...)
}
return inputLabels
}
func generateInputLabels(count int, R []byte) *[][][]byte {
newLabels := make([][][]byte, count)
for i := 0; i < count; i++ {
label1 := make([]byte, 16)
rand.Read(label1)
label1 := u.GetRandom(16)
label2 := u.XorBytes(label1, R)
newLabels[i] = [][]byte{label1, label2}
}
return &newLabels
}
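
Because the last bit of R is forced to 1, the two labels of any wire always differ in their last (point-and-permute) bit, which is exactly the bit that getPoint extracts. A tiny illustrative self-check of that invariant over freshly generated label pairs (not part of the repository):

// checkPointAndPermute panics if any generated label pair violates the
// point-and-permute invariant: the last bits of label0 and label1 must differ.
func checkPointAndPermute(pairs [][][]byte) {
	for _, pair := range pairs {
		if int(pair[0][15])&1 == int(pair[1][15])&1 {
			panic("point-and-permute invariant violated")
		}
	}
}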
func garble(c *Circuit, garbledAssignment *[][][]byte, R []byte, truthTable *[]byte) {
func garble(c *meta.Circuit, wireLabels *[][][]byte, truthTables *[]byte, R *[]byte) {
var andGateIdx int = 0
// gate type XOR==0 AND==1 INV==2
for i := 0; i < len(c.Gates); i++ {
gate := c.Gates[i]
if gate.Operation == 1 {
tt := garbleAnd(gate, R, garbledAssignment)
copy((*truthTable)[andGateIdx*48:(andGateIdx+1)*48], tt[0:48])
tt := garbleAnd(gate, wireLabels, R)
copy((*truthTables)[andGateIdx*48:(andGateIdx+1)*48], tt[0:48])
andGateIdx += 1
} else if gate.Operation == 0 {
garbleXor(gate, R, garbledAssignment)
garbleXor(gate, wireLabels, R)
} else if gate.Operation == 2 {
garbleInv(gate, garbledAssignment)
garbleInv(gate, wireLabels)
}
}
}
@@ -338,17 +189,17 @@ func getPoint(arr []byte) int {
return int(arr[15]) & 0x01
}
func garbleAnd(g Gate, R []byte, ga *[][][]byte) []byte {
func garbleAnd(g meta.Gate, wireLabels *[][][]byte, R *[]byte) []byte {
// get wire numbers
in1 := g.InputWires[0]
in2 := g.InputWires[1]
out := g.OutputWire
// get labels of each wire
in1_0 := (*ga)[in1][0]
in1_1 := (*ga)[in1][1]
in2_0 := (*ga)[in2][0]
in2_1 := (*ga)[in2][1]
in1_0 := (*wireLabels)[in1][0]
in1_1 := (*wireLabels)[in1][1]
in2_0 := (*wireLabels)[in2][0]
in2_1 := (*wireLabels)[in2][1]
// output wires will be assigned labels later
var out_0, out_1 []byte
@@ -374,16 +225,16 @@ func garbleAnd(g Gate, R []byte, ga *[][][]byte) []byte {
outWire := u.Encrypt(*rows[i][0], *rows[i][1], g.Id, zeroWire)
if i == 3 {
out_1 = outWire
out_0 = u.XorBytes(outWire, R)
out_0 = u.XorBytes(outWire, *R)
} else {
out_0 = outWire
out_1 = u.XorBytes(outWire, R)
out_1 = u.XorBytes(outWire, *R)
}
idxToReduce = i
break
}
}
(*ga)[out] = [][]byte{out_0, out_1}
(*wireLabels)[out] = [][]byte{out_0, out_1}
if idxToReduce == -1 {
panic(idxToReduce == -1)
}
@@ -402,19 +253,18 @@ func garbleAnd(g Gate, R []byte, ga *[][][]byte) []byte {
return u.Flatten(truthTable)
}
func garbleXor(g Gate, R []byte, ga *[][][]byte) {
func garbleXor(g meta.Gate, wireLabels *[][][]byte, R *[]byte) {
in1 := g.InputWires[0]
in2 := g.InputWires[1]
out := g.OutputWire
out1 := u.XorBytes((*ga)[in1][0], (*ga)[in2][0])
out2 := u.XorBytes(u.XorBytes((*ga)[in1][1], (*ga)[in2][1]), R)
(*ga)[out] = [][]byte{out1, out2}
out1 := u.XorBytes((*wireLabels)[in1][0], (*wireLabels)[in2][0])
out2 := u.XorBytes(u.XorBytes((*wireLabels)[in1][1], (*wireLabels)[in2][1]), *R)
(*wireLabels)[out] = [][]byte{out1, out2}
}
func garbleInv(g Gate, ga *[][][]byte) {
func garbleInv(g meta.Gate, wireLabels *[][][]byte) {
in1 := g.InputWires[0]
out := g.OutputWire
(*ga)[out] = [][]byte{(*ga)[in1][1], (*ga)[in1][0]}
(*wireLabels)[out] = [][]byte{(*wireLabels)[in1][1], (*wireLabels)[in1][0]}
}

src/ghash/ghash.go (new file)

@@ -0,0 +1,321 @@
package ghash
import (
"log"
"math/big"
u "notary/utils"
)
// Protocol to compute AES-GCM's GHASH function in 2PC using Oblivious Transfer
// https://tlsnotary.org/how_it_works#section4
// (4. Computing MAC of the request using Oblivious Transfer.)
// GHASH implements the 2PC protocol to compute GHASH using OT
type GHASH struct {
// P (p stands for Powers) contains notary's share for each power of H.
// These are powers used to compute MACs/tags on client's requests.
P [][]byte
// maxPowerNeeded is the max power of H that client needs. This equals the
// amount of AES blocks + 2
maxPowerNeeded int
// if we compute all sequential shares of powers from 1 up to and including
// maxOddPowerNeeded, we can start computing the MAC using the Block
// Aggregation method.
maxOddPowerNeeded int
// maxHTable and strategies are initialized in Init(). See comments there.
maxHTable []int
strategy1 [][]int
strategy2 [][]int
}
func (g *GHASH) Init() {
g.P = make([][]byte, 1027) //starting with 1, 1026 is the max that we'll ever need
// maxHTable's <value> shows how many GHASH blocks can be processed
// with Block Aggregation if we have all the sequential shares
// starting with 1 up to and including <key>.
// e.g. {5:29} means that if we have shares of H^1,H^2,H^3,H^4,H^5,
// then we can process 29 GHASH blocks.
// max TLS record size of 16KB requires 1026 GHASH blocks
g.maxHTable = []int{
0: 0, 3: 19, 5: 29, 7: 71, 9: 89, 11: 107, 13: 125, 15: 271, 17: 305, 19: 339, 21: 373,
23: 407, 25: 441, 27: 475, 29: 509, 31: 1023, 33: 1025, 35: 1027}
// shows which shares of powers we will be multiplying to obtain other odd shares of powers.
// The max sequential odd power that we can obtain during the first round of communication is 19.
// Note that we compute the cross-terms N_x*C_y and C_x*N_y. These are not yet shares of powers:
// we must add N_x*N_y and C_x*C_y to the respective cross-terms in order to get shares of powers
g.strategy1 = [][]int{
5: {4, 1},
7: {4, 3},
9: {8, 1},
11: {8, 3},
13: {12, 1},
15: {12, 3},
17: {16, 1},
19: {16, 3}}
g.strategy2 = [][]int{
21: {17, 4},
23: {17, 6},
25: {17, 8},
27: {19, 8},
29: {17, 12},
31: {19, 12},
33: {17, 16},
35: {19, 16}}
}
// CountPowersToBeMultiplied computes how many consecutive odd powers we need.
// Returns how many block multiplications are needed to obtain those odd powers.
func (g *GHASH) CountPowersToBeMultiplied() int {
totalBlockMult := 0
for k, v := range g.strategy1 {
if v == nil {
continue
}
if k > g.maxOddPowerNeeded {
break
}
totalBlockMult += 2
}
log.Println("totalBlockMult", totalBlockMult)
return totalBlockMult
}
// stepCommon is common to Step1 and Step2; they only differ in the strategy
// used. The notary returns masked xTables for shares of powers based on the strategy
func (g *GHASH) stepCommon(strategy *[][]int) []byte {
var allEntries []byte
for k, v := range *strategy {
if v == nil {
continue
}
if k > g.maxOddPowerNeeded {
break
}
entries1, maskSum1 := GetMaskedXTable(g.P[v[1]])
entries2, maskSum2 := GetMaskedXTable(g.P[v[0]])
allEntries = append(allEntries, entries1...)
allEntries = append(allEntries, entries2...)
// get notary's N_x*N_y and then get the final share of power
NxNy := BlockMult(g.P[v[0]], g.P[v[1]])
g.P[k] = u.XorBytes(u.XorBytes(maskSum1, maskSum2), NxNy)
}
FreeSquare(&g.P, g.maxPowerNeeded)
return allEntries
}
func (g *GHASH) Step1() []byte {
// perform free squaring on powers 2 and 3, which we already have from the Client Finished step
FreeSquare(&g.P, g.maxPowerNeeded)
return g.stepCommon(&g.strategy1)
}
func (g *GHASH) Step2() []byte {
return g.stepCommon(&g.strategy2)
}
// In Step3 we multiply each GHASH input block by the share of the corresponding power of H,
// if we have that share. For powers which we don't have, we perform Block Aggregation:
// a missing H^n is decomposed into H^a*H^b (with a+b == n and both shares present), the
// N_a*N_b*X term is computed locally, and the remaining cross-terms are handled via OT
// using masked xTables.
// Returns 1) Notary's share of the GHASH output 2) masked xTables 3) the count of block
// multiplications which we performed during Block Aggregation.
func (g *GHASH) Step3(ghashInputs [][]byte) ([]byte, []byte, int) {
u.Assert(len(ghashInputs) == g.maxPowerNeeded)
res := make([]byte, 16)
// compute direct powers
// L is the total count of GHASH blocks. n is the index of the input block
// starting from 0. We multiply GHASH input block X[n] by power H^(L-n).
for i := 1; i < len(g.P); i++ {
if i > g.maxPowerNeeded {
break
}
if g.P[i] == nil {
continue
}
x := ghashInputs[len(ghashInputs)-i]
h := g.P[i]
res = u.XorBytes(res, BlockMult(h, x))
}
// Block Aggregation
// aggregated <key> -> small power, <value> -> aggregated value for that small power
aggregated := make([][]byte, 36) //starting with 1, 35 is the max that we'll ever need
for i := 1; i < len(g.P); i++ {
if i > g.maxPowerNeeded {
break
}
if g.P[i] != nil {
continue
}
// found a hole in our sparse array, need block aggregation
// a is the smaller power
a, b := FindSum(&g.P, i)
x := ghashInputs[len(ghashInputs)-i]
// locally compute a*b*x
res = u.XorBytes(res, BlockMult(BlockMult(g.P[a], g.P[b]), x))
if aggregated[a] == nil {
aggregated[a] = make([]byte, 16) //set to zero
}
aggregated[a] = u.XorBytes(aggregated[a], BlockMult(g.P[b], x))
}
ghashOutputShare := res
// arrange masked Xtable entries for each entry in aggregated:
// first the Xtable for share of the small power,
// then the Xtable for the aggregated value.
var allEntries []byte
maskSum := make([]byte, 16) //starting with zeroed mask
for i := 0; i < len(aggregated); i++ {
if aggregated[i] == nil {
continue
}
entries1, maskSum1 := GetMaskedXTable(g.P[i])
entries2, maskSum2 := GetMaskedXTable(aggregated[i])
allEntries = append(allEntries, entries1...)
allEntries = append(allEntries, entries2...)
maskSum = u.XorBytes(maskSum, u.XorBytes(maskSum1, maskSum2))
}
ghashOutputShare = u.XorBytes(ghashOutputShare, maskSum)
nonNilItemsCount := 0
for i := 0; i < len(aggregated); i++ {
if aggregated[i] != nil {
nonNilItemsCount += 1
}
}
return ghashOutputShare, allEntries, nonNilItemsCount * 2
}
func (g *GHASH) GetMaxPowerNeeded() int {
return g.maxPowerNeeded
}
func (g *GHASH) GetMaxOddPowerNeeded() int {
return g.maxOddPowerNeeded
}
// set max power of H that is needed and calculate max odd power needed based
// on g.maxHTable
func (g *GHASH) SetMaxPowerNeeded(max int) {
g.maxPowerNeeded = max
for k, v := range g.maxHTable {
if v >= g.maxPowerNeeded {
g.maxOddPowerNeeded = k
log.Println("maxPowerNeeded", g.maxPowerNeeded)
log.Println("maxOddPowerNeeded", g.maxOddPowerNeeded)
break
}
}
}
// FreeSquare locally squares all powers found in powersOfH up to and including
// maxPowerNeeded, modifying powersOfH in place. Squaring is "free" (needs no interaction)
// because in GF(2^128) squaring is linear: (N ⊕ C)^2 = N^2 ⊕ C^2, so each party can
// square its own share locally.
func FreeSquare(powersOfH *[][]byte, maxPowerNeeded int) {
for i := 0; i < len(*powersOfH); i++ {
if (*powersOfH)[i] == nil || i%2 == 0 {
continue
}
if i > maxPowerNeeded {
return
}
power := i
for power < maxPowerNeeded {
power = power * 2
if (*powersOfH)[power] != nil {
continue
}
prevPower := (*powersOfH)[power/2]
(*powersOfH)[power] = BlockMult(prevPower, prevPower)
}
}
}
// Galois field multiplication of two 128-bit blocks reduced by the GCM polynomial
func BlockMult(x_, y_ []byte) []byte {
x := new(big.Int).SetBytes(x_)
y := new(big.Int).SetBytes(y_)
res := big.NewInt(0)
_1 := big.NewInt(1)
R, ok := new(big.Int).SetString("E1000000000000000000000000000000", 16)
if !ok {
panic("SetString")
}
for i := 127; i >= 0; i-- {
tmp1 := new(big.Int).Rsh(y, uint(i))
tmp2 := new(big.Int).And(tmp1, _1)
res.Xor(res, new(big.Int).Mul(x, tmp2))
tmp3 := new(big.Int).And(x, _1)
tmp4 := new(big.Int).Mul(tmp3, R)
tmp5 := new(big.Int).Rsh(x, 1)
x = new(big.Int).Xor(tmp5, tmp4)
}
return u.To16Bytes(res)
}
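
The whole 2PC GHASH construction rests on multiplication in GF(2^128) distributing over XOR: if H = N ⊕ C, then X·H = X·N ⊕ X·C, so a product involving a secret-shared value can be assembled share by share. An illustrative property check using this package's own BlockMult and the XorBytes helper from notary/utils (assumes the standard bytes package is imported):

// checkDistributivity verifies that BlockMult(n^c, x) == BlockMult(n, x) ^ BlockMult(c, x)
// for 16-byte blocks n, c, x, i.e. that multiplying a shared value can be done per share.
func checkDistributivity(n, c, x []byte) bool {
	left := BlockMult(u.XorBytes(n, c), x)
	right := u.XorBytes(BlockMult(n, x), BlockMult(c, x))
	return bytes.Equal(left, right)
}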
// GetXTable returns a table with the value of x at the start of each of the 128 rounds of BlockMult
func GetXTable(xBytes []byte) [][]byte {
x := new(big.Int).SetBytes(xBytes)
_1 := big.NewInt(1)
R, ok := new(big.Int).SetString("E1000000000000000000000000000000", 16)
if !ok {
panic("SetString")
}
xTable := make([][]byte, 128)
for i := 0; i < 128; i++ {
xTable[i] = u.To16Bytes(x)
tmp3 := new(big.Int).And(x, _1)
tmp4 := new(big.Int).Mul(tmp3, R)
tmp5 := new(big.Int).Rsh(x, 1)
x = new(big.Int).Xor(tmp5, tmp4)
}
return xTable
}
// FindSum decomposes a sum into non-zero summands. The first summand is repeatedly
// incremented until a suitable second summand is found. Both summands must be
// in the array.
func FindSum(array *[][]byte, sum int) (int, int) {
for i := 0; i < len(*array); i++ {
if (*array)[i] == nil {
continue
}
for j := 0; j < len(*array); j++ {
if (*array)[j] == nil {
continue
}
if i+j == sum {
return i, j
}
}
}
// this should never happen because we always call
// findSum() knowing that the sum can be found
panic("sum not found")
}
// GetMaskedXTable returns a masked xTable from which the OT response will
// be constructed, as well as the XOR-sum of all masks. A masked xTable replaces
// each entry of the xTable with two 16-byte values: 1) a mask and 2) the xTable
// entry XORed with that mask.
func GetMaskedXTable(powerShare []byte) ([]byte, []byte) {
xTable := GetXTable(powerShare)
// maskSum is the xor sum of all masks
maskSum := make([]byte, 16)
var allMessages []byte
for i := 0; i < 128; i++ {
mask := u.GetRandom(16)
maskSum = u.XorBytes(maskSum, mask)
m0 := mask
m1 := u.XorBytes(xTable[i], mask)
allMessages = append(allMessages, m0...)
allMessages = append(allMessages, m1...)
}
return allMessages, maskSum
}
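
GetMaskedXTable turns one 128-bit block multiplication into 128 1-out-of-2 OTs: for bit i of the client's factor the client receives either the mask (bit 0) or the xTable entry XORed with the mask (bit 1); the XOR of everything received equals maskSum ⊕ BlockMult(notaryShare, clientFactor), and maskSum stays with the notary, so the two parties end up holding XOR shares of the product. A local simulation of that bookkeeping (illustrative only, no real OT; bits are read MSB-first, matching BlockMult; assumes the bytes package is imported):

// simulateMaskedXTableOT plays both sides of the 128 OTs locally and checks
// that the notary's share (maskSum) and the client's share (XOR of the chosen
// entries) XOR together to the real product of the two factors.
func simulateMaskedXTableOT(notaryShare, clientFactor []byte) bool {
	allMessages, maskSum := GetMaskedXTable(notaryShare)
	clientShare := make([]byte, 16)
	for i := 0; i < 128; i++ {
		bit := int(clientFactor[i/8]>>(7-uint(i%8))) & 1
		// each OT pair occupies 32 bytes: m0 (the mask), then m1 (xTable entry ^ mask)
		chosen := allMessages[i*32+16*bit : i*32+16*bit+16]
		clientShare = u.XorBytes(clientShare, chosen)
	}
	product := BlockMult(notaryShare, clientFactor)
	return bytes.Equal(u.XorBytes(clientShare, maskSum), product)
}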


@@ -13,6 +13,12 @@ import (
"time"
)
// KeyManager generates an ephemeral key used by the notary to sign the session and also
// to derive symmetric keys for client<->notary communication.
// The client only accepts notarization sessions signed by an eph. key whose validity
// interval corresponds to the timestamp of the session.
// We start generating a new eph. key a few minutes before the previous key is set to expire.
type KeyManager struct {
sync.Mutex
// Blob contains validFrom|validUntil|pubkey|signature
@@ -34,6 +40,8 @@ func (k *KeyManager) Init() {
go k.rotateEphemeralKeys()
}
// generateMasterKey generates a P-256 master key. The corresponding public key
// in PEM format is written to disk
func (k *KeyManager) generateMasterKey() {
// masterKey is only used to sign ephemeral keys
var err error
@@ -56,10 +64,13 @@ func (k *KeyManager) generateMasterKey() {
// sign it with the master key
func (k *KeyManager) rotateEphemeralKeys() {
k.validMins = 20
// initially setting to zero to immediately trigger a key rotation
nextKeyRotationTime := time.Unix(0, 0)
for {
time.Sleep(time.Second * 1)
now := time.Now()
// start key rotation no sooner than 2 mins before the current eph. key
// is set to expire
if nextKeyRotationTime.Sub(now) > time.Minute*2 {
continue
}

src/meta/meta.go (new file)

@@ -0,0 +1,53 @@
// Package meta contains various structures with circuit metadata
package meta
// Gate represents a circuit's gate
type Gate struct {
// Id is gate number, Ids start with 0 and increment
Id uint32
// Operation is 0 for XOR, 1 for AND, 2 for INV
Operation uint8
// InputWires are the sequence numbers of the input wires. Each gate has 1
// (for INV) or 2 (for XOR or AND) wires going into it
InputWires []uint32
// OutputWire is the sequence number of the output wire of this gate.
OutputWire uint32
}
// Circuit contains read-only information for each circuit and is used both by
// the garbler and the evaluator
type Circuit struct {
// WireCount is total amount of wires in the circuit
WireCount int
// NotaryInputSize is the count of bits in notary's input
NotaryInputSize int
// ClientInputSize is the count of bits in client's input
ClientInputSize int
// OutputSize is the count of bits in the circuit's output
OutputSize int
// AndGateCount is the count of AND gates in the circuit
AndGateCount int
// Gates is an array of all gates of the circuit
Gates []Gate
// The output of a circuit is actually multiple concatenated values. We need
// to know how many bits each output value has in order to parse the output.
// Of all the members of this struct, OutputsSizes is the only one which
// cannot be obtained by parsing the raw circuit; we input this value manually
OutputsSizes []int
}
// GetOutputSizes takes the number of a circuit and returns a slice with
// the bit length of each of the circuit's output variables.
func GetOutputSizes(idx int) []int {
outputSizes := [][]int{
nil,
[]int{256, 256},
[]int{256, 256},
[]int{128, 128, 32, 32},
[]int{128, 128, 128},
[]int{128, 128, 128, 96},
[]int{128},
[]int{128}}
return outputSizes[idx]
}


@@ -3,23 +3,26 @@ package main
import (
"context"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"sync"
"runtime/debug"
"net/http"
_ "net/http/pprof"
"notary/garbled_pool"
"notary/key_manager"
"notary/session"
"notary/session_manager"
"time"
)
var sm *SessionManager
var sm *session_manager.SessionManager
var gp *garbled_pool.GarbledPool
var km *key_manager.KeyManager
@@ -28,94 +31,23 @@ var km *key_manager.KeyManager
// It contains AWS HTTP API requests with Amazon's attestation
var URLFetcherDoc []byte
type smItem struct {
session *session.Session
lastSeen int64 // timestamp of last activity
creationTime int64 // timestamp
}
type SessionManager struct {
// string looks like 123.123.44.44:23409
sessions map[string]*smItem
sync.Mutex
}
func (sm *SessionManager) Init() {
sm.sessions = make(map[string]*smItem)
go sm.monitorSessions()
}
func (sm *SessionManager) addSession(key string) *session.Session {
if _, ok := sm.sessions[key]; ok {
log.Println(key)
panic("session already exists")
}
s := new(session.Session)
now := int64(time.Now().UnixNano() / 1e9)
sm.Lock()
defer sm.Unlock()
sm.sessions[key] = &smItem{s, now, now}
return s
}
// get an already-existing session associated with the key
// and update the last-seen time
func (sm *SessionManager) getSession(key string) *session.Session {
val, ok := sm.sessions[key]
if !ok {
log.Println(key)
panic("session does not exist")
}
val.lastSeen = int64(time.Now().UnixNano() / 1e9)
return val.session
}
func (sm *SessionManager) removeSession(key string) {
s, ok := sm.sessions[key]
if !ok {
log.Println(key)
panic("cannot remove: session does not exist")
}
err := os.RemoveAll(s.session.StorageDir)
if err != nil {
panic(err)
}
sm.Lock()
defer sm.Unlock()
delete(sm.sessions, key)
}
// remove sessions which have been inactive or running for too long
func (sm *SessionManager) monitorSessions() {
for {
time.Sleep(time.Second)
now := int64(time.Now().UnixNano() / 1e9)
for k, v := range sm.sessions {
if now-v.lastSeen > 120 || now-v.creationTime > 300 {
log.Println("deleting session from monitorSessions")
sm.removeSession(k)
}
}
}
}
// read request body
// readBody extracts the HTTP request's body
func readBody(req *http.Request) []byte {
defer req.Body.Close()
log.Println("begin ReadAll")
body, err := ioutil.ReadAll(req.Body)
log.Println("finished ReadAll ", len(body))
if err != nil {
panic("can't read request body")
}
return body
}
// writeResponse appends the CORS headers needed to keep the browser happy
// and writes data to the wire
func writeResponse(resp []byte, w http.ResponseWriter) {
//w.Header().Set("Connection", "close")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Content-Type", "application/octet-stream")
w.Write(resp)
log.Println("wrote response of size: ", len(resp))
}
func getURLFetcherDoc(w http.ResponseWriter, req *http.Request) {
@@ -123,38 +55,24 @@ func getURLFetcherDoc(w http.ResponseWriter, req *http.Request) {
writeResponse(URLFetcherDoc, w)
}
func step1(w http.ResponseWriter, req *http.Request) {
log.Println("in step1", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Step1(body)
writeResponse(out, w)
// destroyOnPanic will be called on panic(). It will destroy the session which
// caused the panic
func destroyOnPanic(s *session.Session) {
r := recover()
if r == nil {
return // there was no panic
}
func step2(w http.ResponseWriter, req *http.Request) {
log.Println("in step2", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Step2(body)
writeResponse(out, w)
}
func step3(w http.ResponseWriter, req *http.Request) {
log.Println("in step3", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Step3(body)
writeResponse(out, w)
}
func step4(w http.ResponseWriter, req *http.Request) {
log.Println("in step4", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Step4(body)
writeResponse(out, w)
fmt.Println("caught a panic message: ", r)
debug.PrintStack()
s.DestroyChan <- s.Sid
}
func init1(w http.ResponseWriter, req *http.Request) {
log.Println("in init1", req.RemoteAddr)
s := sm.AddSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
s := sm.addSession(string(req.URL.RawQuery))
s.Gp = gp
// copying data so that it doesn't change from under us if
// the ephemeral key happens to change while this session is running
km.Lock()
@@ -162,205 +80,337 @@ func init1(w http.ResponseWriter, req *http.Request) {
copy(blob, km.Blob)
key := *km.PrivKey
km.Unlock()
out := s.Init1(body, blob, key, gp)
out := s.Init1(body, blob, key)
writeResponse(out, w)
}
func init2(w http.ResponseWriter, req *http.Request) {
log.Println("in init2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Init2(body)
out := s.Init2(body)
writeResponse(out, w)
}
func getBlobChunk(w http.ResponseWriter, req *http.Request) {
log.Println("in getBlobChunk", req.RemoteAddr)
func getBlob(w http.ResponseWriter, req *http.Request) {
log.Println("in getBlob", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).GetBlobChunk(body)
tt, dt := s.GetBlob(body)
// send headers first
writeResponse(nil, w)
// stream decoding table directly from file
for _, f := range dt {
_, err := io.Copy(w, f)
if err != nil {
panic("err != nil")
}
}
// stream truth table directly from file
for _, f := range tt {
_, err := io.Copy(w, f)
if err != nil {
panic("err != nil")
}
}
}
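The streaming pattern used above boils down to handing the open file to io.Copy, which shuttles data from the file to the wire in small fixed-size chunks instead of buffering the whole blob in memory. A minimal standalone sketch (illustrative only, not the notary's code; it relies on the io, net/http and os imports already present in this file):
// streamFile is an illustrative helper: it streams a file to an HTTP response
// without ever holding the full contents in memory.
func streamFile(w http.ResponseWriter, path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()
	_, err = io.Copy(w, f) // copies in small chunks until EOF
	return err
}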
func setBlob(w http.ResponseWriter, req *http.Request) {
log.Println("in setBlob", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
out := s.SetBlob(req.Body)
writeResponse(out, w)
}
func setBlobChunk(w http.ResponseWriter, req *http.Request) {
log.Println("in setBlobChunk", req.RemoteAddr)
func getUploadProgress(w http.ResponseWriter, req *http.Request) {
log.Println("in getUploadProgress", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
out := s.GetUploadProgress()
writeResponse(out, w)
}
func step1(w http.ResponseWriter, req *http.Request) {
log.Println("in step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).SetBlobChunk(body)
out := s.Step1(body)
writeResponse(out, w)
}
func step2(w http.ResponseWriter, req *http.Request) {
log.Println("in step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.Step2(body)
writeResponse(out, w)
}
func step3(w http.ResponseWriter, req *http.Request) {
log.Println("in step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.Step3(body)
writeResponse(out, w)
}
func step4(w http.ResponseWriter, req *http.Request) {
log.Println("in step4", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.Step4(body)
writeResponse(out, w)
}
func c1_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c1_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C1_step1(body)
out := s.C1_step1(body)
writeResponse(out, w)
}
func c1_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c1_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C1_step2(body)
out := s.C1_step2(body)
writeResponse(out, w)
}
func c1_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in c1_step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C1_step3(body)
out := s.C1_step3(body)
writeResponse(out, w)
}
func c1_step4(w http.ResponseWriter, req *http.Request) {
log.Println("in c1_step4", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C1_step4(body)
out := s.C1_step4(body)
writeResponse(out, w)
}
func c1_step5(w http.ResponseWriter, req *http.Request) {
log.Println("in c1_step5", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C1_step5(body)
out := s.C1_step5(body)
writeResponse(out, w)
}
func c2_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c2_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C2_step1(body)
out := s.C2_step1(body)
writeResponse(out, w)
}
func c2_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c2_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C2_step2(body)
out := s.C2_step2(body)
writeResponse(out, w)
}
func c2_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in c2_step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C2_step3(body)
out := s.C2_step3(body)
writeResponse(out, w)
}
func c2_step4(w http.ResponseWriter, req *http.Request) {
log.Println("in c2_step4", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C2_step4(body)
out := s.C2_step4(body)
writeResponse(out, w)
}
func c3_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c3_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C3_step1(body)
out := s.C3_step1(body)
writeResponse(out, w)
}
func c3_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c3_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C3_step2(body)
writeResponse(out, w)
}
func c3_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in c3_step3", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C3_step3(body)
writeResponse(out, w)
}
func c4_pre1(w http.ResponseWriter, req *http.Request) {
log.Println("in c4_pre1", req.RemoteAddr)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C4_pre1(body)
out := s.C3_step2(body)
writeResponse(out, w)
}
func c4_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c4_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C4_step1(body)
out := s.C4_step1(body)
writeResponse(out, w)
}
func c4_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c4_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C4_step2(body)
out := s.C4_step2(body)
writeResponse(out, w)
}
func c4_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in c4_step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C4_step3(body)
out := s.C4_step3(body)
writeResponse(out, w)
}
func c5_pre1(w http.ResponseWriter, req *http.Request) {
log.Println("in c5_pre1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.C5_pre1(body)
writeResponse(out, w)
}
func c5_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c5_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C5_step1(body)
out := s.C5_step1(body)
writeResponse(out, w)
}
func c5_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c5_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C5_step2(body)
out := s.C5_step2(body)
writeResponse(out, w)
}
func c5_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in c5_step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.C5_step3(body)
writeResponse(out, w)
}
func c6_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c6_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C6_step1(body)
out := s.C6_step1(body)
writeResponse(out, w)
}
func c6_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c6_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).C6_step2(body)
out := s.C6_step2(body)
writeResponse(out, w)
}
func checkC6Commit(w http.ResponseWriter, req *http.Request) {
log.Println("in checkC6Commit", req.RemoteAddr)
func c7_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in c7_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).CheckC6Commit(body)
out := s.C7_step1(body)
writeResponse(out, w)
}
func c7_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in c7_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.C7_step2(body)
writeResponse(out, w)
}
func checkC7Commit(w http.ResponseWriter, req *http.Request) {
log.Println("in checkC7Commit", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := s.CheckC7Commit(body)
writeResponse(out, w)
}
func ghash_step1(w http.ResponseWriter, req *http.Request) {
log.Println("in ghash_step1", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Ghash_step1(body)
out := s.Ghash_step1(body)
writeResponse(out, w)
}
func ghash_step2(w http.ResponseWriter, req *http.Request) {
log.Println("in ghash_step2", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Ghash_step2(body)
out := s.Ghash_step2(body)
writeResponse(out, w)
}
func ghash_step3(w http.ResponseWriter, req *http.Request) {
log.Println("in ghash_step3", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).Ghash_step3(body)
out := s.Ghash_step3(body)
writeResponse(out, w)
}
func commitHash(w http.ResponseWriter, req *http.Request) {
log.Println("in commitHash", req.RemoteAddr)
s := sm.GetSession(string(req.URL.RawQuery))
defer destroyOnPanic(s)
body := readBody(req)
out := sm.getSession(string(req.URL.RawQuery)).CommitHash(body)
out := s.CommitHash(body)
writeResponse(out, w)
sm.removeSession(string(req.URL.RawQuery))
s.DestroyChan <- s.Sid
}
// when notary starts we expect the admin to upload a URLFetcher document
@@ -392,10 +442,13 @@ func getPubKey(w http.ResponseWriter, req *http.Request) {
writeResponse(km.MasterPubKeyPEM, w)
}
// initially the circuits are in the human-readable c*.casm format; assemble.js
// converts them into a "Bristol fashion" format and writes c*.out files to disk
func assembleCircuits() {
curDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
baseDir := filepath.Dir(curDir)
circuitsDir := filepath.Join(baseDir, "circuits")
// if c1.out does not exist, proceed to assemble
if _, err := os.Stat(filepath.Join(circuitsDir, "c1.out")); os.IsNotExist(err) {
cmd := exec.Command("node", "assemble.js")
cmd.Dir = circuitsDir
@@ -413,6 +466,7 @@ func main() {
// go func() {
// http.ListenAndServe(":8080", nil)
// }()
noSandbox := flag.Bool("no-sandbox", false, "Must be set when not running in a sandboxed environment.")
flag.Parse()
log.Println("noSandbox", *noSandbox)
@@ -420,7 +474,7 @@ func main() {
km = new(key_manager.KeyManager)
km.Init()
assembleCircuits()
sm = new(SessionManager)
sm = new(session_manager.SessionManager)
sm.Init()
gp = new(garbled_pool.GarbledPool)
gp.Init(*noSandbox)
@@ -435,41 +489,57 @@ func main() {
http.HandleFunc("/init1", init1)
http.HandleFunc("/init2", init2)
http.HandleFunc("/getBlobChunk", getBlobChunk)
http.HandleFunc("/setBlobChunk", setBlobChunk)
http.HandleFunc("/getBlob", getBlob)
http.HandleFunc("/setBlob", setBlob)
http.HandleFunc("/getUploadProgress", getUploadProgress)
// step1 thru step4 deal with Paillier 2PC
http.HandleFunc("/step1", step1)
http.HandleFunc("/step2", step2)
http.HandleFunc("/step3", step3)
http.HandleFunc("/step4", step4)
// c1_step1 thru c1_step5 deal with TLS Handshake
http.HandleFunc("/c1_step1", c1_step1)
http.HandleFunc("/c1_step2", c1_step2)
http.HandleFunc("/c1_step3", c1_step3)
http.HandleFunc("/c1_step4", c1_step4)
http.HandleFunc("/c1_step5", c1_step5)
// c2_step1 thru c2_step4 deal with TLS Handshake
http.HandleFunc("/c2_step1", c2_step1)
http.HandleFunc("/c2_step2", c2_step2)
http.HandleFunc("/c2_step3", c2_step3)
http.HandleFunc("/c2_step4", c2_step4)
// c3_step1 thru c4_step3 deal with TLS Handshake and also prepare data
// needed to send Client Finished
http.HandleFunc("/c3_step1", c3_step1)
http.HandleFunc("/c3_step2", c3_step2)
http.HandleFunc("/c3_step3", c3_step3)
http.HandleFunc("/c4_pre1", c4_pre1)
http.HandleFunc("/c4_step1", c4_step1)
http.HandleFunc("/c4_step2", c4_step2)
http.HandleFunc("/c4_step3", c4_step3)
// c5_pre1 thru c5_step3 check Server Finished
http.HandleFunc("/c5_pre1", c5_pre1)
http.HandleFunc("/c5_step1", c5_step1)
http.HandleFunc("/c5_step2", c5_step2)
http.HandleFunc("/c5_step3", c5_step3)
// c6_step1 thru c6_step2 prepare encrypted counter blocks for the
// client's request to the webserver
http.HandleFunc("/c6_step1", c6_step1)
http.HandleFunc("/c6_step2", c6_step2)
http.HandleFunc("/checkC6Commit", checkC6Commit)
// c7_step1 thru c7_step2 prepare the GCTR block needed to compute the MAC
// for the client's request
http.HandleFunc("/c7_step1", c7_step1)
http.HandleFunc("/c7_step2", c7_step2)
http.HandleFunc("/checkC7Commit", checkC7Commit)
// steps ghash_step1 thru ghash_step3 compute the GHASH output needed to
// compute the MAC for the client's request
http.HandleFunc("/ghash_step1", ghash_step1)
http.HandleFunc("/ghash_step2", ghash_step2)
http.HandleFunc("/ghash_step3", ghash_step3)

View File

@@ -1,7 +1,6 @@
package ot
import (
"log"
u "notary/utils"
"github.com/bwesterb/go-ristretto"
@@ -40,7 +39,6 @@ func (o *OTReceiver) SetupStep1() ([]byte, []byte) {
return o.A.Bytes(), seedCommit
}
// Step 3
func (o *OTReceiver) SetupStep2(allBsBlob, senderSeedShare []byte) ([]byte, []byte, []byte, []byte) {
// compute key_0 and key_1 for each B of the base OT
if (len(allBsBlob) != 128*32) || (len(senderSeedShare) != 16) {
@@ -95,15 +93,13 @@ func (o *OTReceiver) SetupStep2(allBsBlob, senderSeedShare []byte) ([]byte, []by
// now we have instances of Random OT where depending on r's bit,
// each row in RT0 equals to a row either in RQ0 or RQ1
log.Println("done 2")
// use Beaver Derandomization [Beaver91] to convert randomOT into standardOT
return u.Concat(encryptedColumns...), o.seedShare, x, t
}
// Step 5
// request Oblivious Transfer from the Sender for the choice bits
func (o *OTReceiver) RequestMaskedOT(bitsArr []int) []byte {
if o.receivedSoFar+len(bitsArr) > o.otCount {
// CreateRequest creates a request for OT for the choice bits.
func (o *OTReceiver) CreateRequest(choiceBits []int) []byte {
if o.receivedSoFar+len(choiceBits) > o.otCount {
panic("o.receivedSoFar + len(bitsArr) > o.otCount")
}
if o.expectingResponseSize != 0 {
@@ -113,41 +109,42 @@ func (o *OTReceiver) RequestMaskedOT(bitsArr []int) []byte {
// no flip needed, 1 means a flip is needed
// pad the bitcount to a multiple of 8
dropCount := 0
if len(bitsArr)%8 > 0 {
dropCount = 8 - len(bitsArr)%8
if len(choiceBits)%8 > 0 {
dropCount = 8 - len(choiceBits)%8
}
bitsToFlip := make([]int, len(bitsArr)+dropCount)
for i := 0; i < len(bitsArr); i++ {
bitsToFlip[i] = bitsArr[i] ^ o.rBits[o.receivedSoFar+i]
bitsToFlip := make([]int, len(choiceBits)+dropCount)
for i := 0; i < len(choiceBits); i++ {
bitsToFlip[i] = choiceBits[i] ^ o.rBits[o.receivedSoFar+i]
}
for i := 0; i < dropCount; i++ {
bitsToFlip[len(bitsArr)+i] = 0
bitsToFlip[len(choiceBits)+i] = 0
}
o.expectingResponseSize = len(bitsArr)
o.expectingResponseSize = len(choiceBits)
// prefix with the number of bits that the Sender needs to drop
// in cases when len(choiceBits) is not a multiple of 8
return u.Concat([]byte{byte(dropCount)}, u.BitsToBytes(bitsToFlip))
}
// Step 7
// for every choice bit in bitsArr, unmask one of the two 16-byte messages
func (o *OTReceiver) UnmaskOT(bitsArr []int, encodedOT []byte) []byte {
if (o.expectingResponseSize != len(bitsArr)) ||
// ParseResponse parses (i.e. decodes) the OT response from the OT sender and
// returns the plaintext result of OT.
// For every choice bit, it unmasks one of the two 16-byte messages.
func (o *OTReceiver) ParseResponse(choiceBits []int, encodedOT []byte) []byte {
if (o.expectingResponseSize != len(choiceBits)) ||
(o.expectingResponseSize*32 != len(encodedOT)) {
panic("o.expectingResponseSize issue")
}
decodedArr := make([][]byte, len(bitsArr))
for i := 0; i < len(bitsArr); i++ {
decodedArr := make([][]byte, len(choiceBits))
for i := 0; i < len(choiceBits); i++ {
mask := o.RT0[(o.receivedSoFar+i)*16 : (o.receivedSoFar+i)*16+16]
m0 := encodedOT[i*32 : i*32+16]
m1 := encodedOT[i*32+16 : i*32+32]
if bitsArr[i] == 0 {
if choiceBits[i] == 0 {
decodedArr[i] = u.XorBytes(m0, mask)
} else {
decodedArr[i] = u.XorBytes(m1, mask)
}
}
o.receivedSoFar += len(bitsArr)
o.receivedSoFar += len(choiceBits)
o.expectingResponseSize = 0
return u.Concat(decodedArr...)
}
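To make the derandomization concrete, here is a standalone toy program (single bytes stand in for the 16-byte messages and masks; it does not call the actual OT code) showing that the flip bit sent in CreateRequest, combined with the mask-swapping rule the sender applies, lets ParseResponse recover exactly the message selected by the real choice bit.
package main

import "fmt"

// Toy illustration of the Beaver derandomization algebra behind
// CreateRequest/ParseResponse (and the sender's masking rule): whatever the
// random OT bit r is, the receiver ends up with the message selected by its
// real choice bit.
func main() {
	m0, m1 := byte(0xAA), byte(0xBB) // sender's two messages
	r0, r1 := byte(0x31), byte(0x74) // sender's random-OT masks
	r := 1                           // receiver's random-OT bit; it only knows mask r1
	choice := 0                      // receiver's real choice bit

	flip := choice ^ r // this is what the receiver sends in its OT request
	e0, e1 := m0^r0, m1^r1
	if flip == 1 { // the sender swaps the masks when the flip bit is set
		e0, e1 = m0^r1, m1^r0
	}
	// the receiver unmasks the slot matching its choice bit using the one mask it knows
	known := r1
	got := e1 ^ known
	if choice == 0 {
		got = e0 ^ known
	}
	fmt.Println(got == m0) // true: the receiver recovered m_choice
}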

View File

@@ -8,6 +8,8 @@ import (
"github.com/bwesterb/go-ristretto"
)
// OTSender implements the sender of the Oblivious Transfer according to
// the KOS15 protocol
type OTSender struct {
extraOT int
otCount int
@@ -38,7 +40,7 @@ func (o *OTSender) SetupStep1(A_, hisCommit []byte) ([]byte, []byte) {
}
o.hisCommit = hisCommit
o.seedShare = u.GetRandom(16)
// Alice computes her Bs and decryption keys based on each bit in S
// compute Bs and decryption keys of the base OT for each bit in S
o.s = u.GetRandom(16)
o.sBits = u.Reverse(u.BytesToBits(o.s))
allBs := make([][]byte, len(o.sBits))
@@ -112,19 +114,21 @@ func (o *OTSender) SetupStep2(encryptedColumnsBlob, receiverSeedShare, x, t []by
o.rQ1 = breakCorrelation(Q1[0 : len(Q1)-o.extraOT])
// now we have instances of Random OT where depending on r's bit,
// each row in RT0 equals to a row either in RQ0 or RQ1
log.Println("done 2")
// in Steps 5,6,7 we will use Beaver Derandomization to convert
// when processing the OT request, we will use Beaver Derandomization to convert
// randomOT into standardOT
}
// Step 6
// for every bit in bitsToFlip, the Sender has two 16-byte messages for 1-of-2 OT and
// two random masks (from the KOS15 protocol) r0 and r1
// if the bit is 0, the Sender sends (m0 xor r0) and (m1 xor r1),
// if the bit is 1, the Sender sends (m0 xor r1) and (m1 xor r0)
func (o *OTSender) GetMaskedOT(bitsBlob, messages []byte) []byte {
dropCount := int(bitsBlob[0])
bitsToFlipWithRem := u.BytesToBits(bitsBlob[1:])
// ProcessRequest processes a request for OT from the OT receiver.
// otRequest contains bits which need to be flipped according to the Beaver derandomization
// method. The Sender has two 16-byte messages for 1-of-2 OT and
// two random masks (from the KOS15 protocol) r0 and r1.
// If the bit to flip is 0, the Sender sends (m0 xor r0) and (m1 xor r1).
// If the bit to flip is 1, the Sender sends (m0 xor r1) and (m1 xor r0).
// Returns an OT response.
func (o *OTSender) ProcessRequest(otRequest, messages []byte) []byte {
dropCount := int(otRequest[0])
bitsToFlipWithRem := u.BytesToBits(otRequest[1:])
bitsToFlip := bitsToFlipWithRem[:len(bitsToFlipWithRem)-dropCount]
if o.sentSoFar+len(bitsToFlip) > o.otCount {
panic("o.sentSoFar + len(bitsToFlip) > o.otCount")

View File

@@ -14,14 +14,21 @@ import (
// Protocol to compute EC point addition in Paillier as described here:
// https://tlsnotary.org/how_it_works#section1
// The code uses the same notation as in the link above.
// The code uses the same notation as in the link above. The code must be read
// alongside the writeup.
// Paillier2PC implements the notary's side of computing an EC point
// addition in 2PC
type Paillier2PC struct {
p256 ec.Curve
d_n, Q_nx, Q_ny *big.Int
// d_n is notary's share of the EC private key
d_n *big.Int
// Q_nx, Q_ny are notary's shares of the EC public key
Q_nx, Q_ny *big.Int
// paillierPrivKey is used to decrypt 2PC messages from client
// it also contains a public key used to encrypt 2PC message to the client
paillierPrivKey *paillier.PrivateKey
// constant numbers
Zero, One, Two, Three *big.Int
// P is curve P-256's Field prime
P *big.Int
@@ -141,6 +148,7 @@ func (p *Paillier2PC) Step3(payload []byte) []byte {
return []byte(json)
}
// final step
func (p *Paillier2PC) Step4(payload []byte) []byte {
type Step4 struct {
E135 string

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,114 @@
package session_manager
import (
"log"
"notary/session"
"os"
"sync"
"time"
)
// smItem is stored internally by SessionManager
type smItem struct {
session *session.Session
lastSeen int64 // timestamp of last activity
creationTime int64 // timestamp
}
// SessionManager manages TLSNotary sessions from multiple users. When a user
// sends a request, SessionManager extracts the unique id of the user from the
// request, and calls the matching session.
type SessionManager struct {
// string looks like 123.123.44.44:23409
sessions map[string]*smItem
destroyChan chan string
sync.Mutex
}
func (sm *SessionManager) Init() {
sm.sessions = make(map[string]*smItem)
go sm.monitorSessions()
sm.destroyChan = make(chan string)
go sm.monitorDestroyChan()
}
// AddSession creates a new session and sets its creation time
func (sm *SessionManager) AddSession(key string) *session.Session {
if _, ok := sm.sessions[key]; ok {
log.Println("Error: session already exists ", key)
}
s := new(session.Session)
s.Sid = key
s.DestroyChan = sm.destroyChan
now := int64(time.Now().UnixNano() / 1e9)
sm.Lock()
defer sm.Unlock()
sm.sessions[key] = &smItem{s, now, now}
return s
}
// GetSession returns an already-existing session associated with the key
// and updates the last-seen time
func (sm *SessionManager) GetSession(key string) *session.Session {
val, ok := sm.sessions[key]
if !ok {
log.Println("Error: the requested session does not exist ", key)
return nil
}
val.lastSeen = int64(time.Now().UnixNano() / 1e9)
return val.session
}
// removeSession removes the session and associated storage data
func (sm *SessionManager) removeSession(key string) {
s, ok := sm.sessions[key]
if !ok {
log.Println("Cannot remove: session does not exist ", key)
return
}
err := os.RemoveAll(s.session.StorageDir)
if err != nil {
log.Println("Error while removing session ", key)
log.Println(err)
}
for _, f := range s.session.Tt {
err = os.Remove(f.Name())
if err != nil {
log.Println("Error while removing session ", key)
log.Println(err)
}
}
for _, f := range s.session.Dt {
err = os.Remove(f.Name())
if err != nil {
log.Println("Error while removing session ", key)
log.Println(err)
}
}
sm.Lock()
defer sm.Unlock()
delete(sm.sessions, key)
}
// monitorSessions removes sessions which have been inactive or which have
// been too long-running
func (sm *SessionManager) monitorSessions() {
for {
time.Sleep(time.Second)
now := int64(time.Now().UnixNano() / 1e9)
for k, v := range sm.sessions {
if now-v.lastSeen > 120 || now-v.creationTime > 300 {
log.Println("will remove stale session ", k)
sm.removeSession(k)
}
}
}
}
// monitorDestroyChan waits on a chan for a signal from a session to destroy it
func (sm *SessionManager) monitorDestroyChan() {
for {
sid := <-sm.destroyChan
log.Println("monitorDestroyChan will destroy sid: ", sid)
sm.removeSession(sid)
}
}
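A minimal usage sketch (a hypothetical caller, not part of this diff; the session key is a made-up example value) showing how the pieces fit together: the manager hands every new session its sid and the shared destroy channel, and a session asks to be torn down by sending its sid back into that channel.
package main

import "notary/session_manager"

// Hypothetical caller tying AddSession, GetSession and the destroy channel together.
func main() {
	sm := new(session_manager.SessionManager)
	sm.Init()
	s := sm.AddSession("123.123.44.44:23409")
	_ = sm.GetSession("123.123.44.44:23409") // refreshes the session's lastSeen time
	s.DestroyChan <- s.Sid                   // monitorDestroyChan removes the session asynchronously
}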

View File

@@ -39,6 +39,12 @@ func SplitIntoChunks(data []byte, chunkSize int) [][]byte {
return chunks
}
func Assert(condition bool) {
if !condition {
panic("assert failed")
}
}
// port of sodium.crypto_generichash
func Generichash(length int, msg []byte) []byte {
h, err := blake2b.New(length, nil)
@@ -144,7 +150,7 @@ func BytesToBits(b []byte) []int {
return bits
}
// convert an array of 0/1 into bytes
// convert an array of 0/1, with the least significant bit at index 0, into bytes
func BitsToBytes(b []int) []byte {
bigint := new(big.Int)
for i := 0; i < len(b); i++ {
@@ -182,9 +188,24 @@ func Concat(slices ...[]byte) []byte {
return newSlice
}
// concatenate slices of bytes pointed to by pointers into a new slice with
// a new underlying array
func ConcatP(pointers ...*[]byte) []byte {
totalSize := 0
for _, v := range pointers {
totalSize += len(*v)
}
newSlice := make([]byte, totalSize)
copiedSoFar := 0
for _, v := range pointers {
copy(newSlice[copiedSoFar:copiedSoFar+len(*v)], *v)
copiedSoFar += len(*v)
}
return newSlice
}
// FinishHash finishes a sha256 hash from a previous mid-state
func FinishHash(outerState []byte, data []byte) []byte {
digest := sha256.New()
digestUnmarshaler, ok := digest.(encoding.BinaryUnmarshaler)
if !ok {
@@ -207,105 +228,7 @@ func FinishHash(outerState []byte, data []byte) []byte {
return digest.Sum(nil)
}
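FinishHash relies on Go's sha256 digest implementing encoding.BinaryMarshaler/BinaryUnmarshaler, which lets a hash be resumed from a serialized mid-state. A standalone sketch of just that mechanism (illustrative only, not this repo's code):
package main

import (
	"bytes"
	"crypto/sha256"
	"encoding"
	"fmt"
)

// The sha256 digest can be serialized with MarshalBinary after absorbing the
// first part of the input; a fresh digest restored from that state continues
// hashing as if it had seen the whole input.
func main() {
	full := sha256.Sum256([]byte("hello world"))

	h1 := sha256.New()
	h1.Write([]byte("hello "))
	state, _ := h1.(encoding.BinaryMarshaler).MarshalBinary()

	h2 := sha256.New()
	_ = h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state)
	h2.Write([]byte("world"))

	fmt.Println(bytes.Equal(h2.Sum(nil), full[:])) // true
}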
// GF block multiplication
func BlockMultOld(val, encZero *big.Int) *big.Int {
res := big.NewInt(0)
_255 := big.NewInt(255)
R, ok := new(big.Int).SetString("E1000000000000000000000000000000", 16)
if !ok {
panic("SetString")
}
j := new(big.Int)
for i := 0; i < 16; i++ {
j.And(val, _255)
j.Lsh(j, uint(8*i))
res.Xor(res, gf_2_128_mul(encZero, j, R))
val.Rsh(val, 8) // val >>= 8n
}
return res
}
// Galois field multiplication of two 128-bit blocks reduced by the GCM polynomial
func BlockMult(x_, y_ []byte) []byte {
x := new(big.Int).SetBytes(x_)
y := new(big.Int).SetBytes(y_)
res := big.NewInt(0)
_1 := big.NewInt(1)
R, ok := new(big.Int).SetString("E1000000000000000000000000000000", 16)
if !ok {
panic("SetString")
}
for i := 127; i >= 0; i-- {
tmp1 := new(big.Int).Rsh(y, uint(i))
tmp2 := new(big.Int).And(tmp1, _1)
res.Xor(res, new(big.Int).Mul(x, tmp2))
tmp3 := new(big.Int).And(x, _1)
tmp4 := new(big.Int).Mul(tmp3, R)
tmp5 := new(big.Int).Rsh(x, 1)
x = new(big.Int).Xor(tmp5, tmp4)
}
return To16Bytes(res)
}
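A quick sanity check for BlockMult (an illustrative helper assumed to sit in the same package as BlockMult; the 16-byte test value is arbitrary): in GHASH's bit ordering the block 0x80 00...00 represents the field element 1, so multiplying any block by it must return the block unchanged.
// checkBlockMultIdentity is illustrative only: it verifies that the GHASH
// identity element (first bit set, all others zero) leaves a block unchanged
// under BlockMult.
func checkBlockMultIdentity() bool {
	one := make([]byte, 16)
	one[0] = 0x80 // GHASH representation of the field element 1
	h := []byte{0x66, 0xe9, 0x4b, 0xd4, 0xef, 0x8a, 0x2c, 0x3b,
		0x88, 0x4c, 0xfa, 0x59, 0xca, 0x34, 0x2b, 0x2e} // arbitrary test block
	res := BlockMult(h, one)
	for i := range h {
		if res[i] != h[i] {
			return false
		}
	}
	return true
}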
// return a table of byte values of x after each of the 128 rounds of BlockMult
func GetXTable(xBytes []byte) [][]byte {
x := new(big.Int).SetBytes(xBytes)
_1 := big.NewInt(1)
R, ok := new(big.Int).SetString("E1000000000000000000000000000000", 16)
if !ok {
panic("SetString")
}
xTable := make([][]byte, 128)
for i := 0; i < 128; i++ {
xTable[i] = To16Bytes(x)
tmp3 := new(big.Int).And(x, _1)
tmp4 := new(big.Int).Mul(tmp3, R)
tmp5 := new(big.Int).Rsh(x, 1)
x = new(big.Int).Xor(tmp5, tmp4)
}
return xTable
}
func FindSum(powersOfH *[][]byte, sum int) (int, int) {
for i := 0; i < len(*powersOfH); i++ {
if (*powersOfH)[i] == nil {
continue
}
for j := 0; j < len(*powersOfH); j++ {
if (*powersOfH)[j] == nil {
continue
}
if i+j == sum {
return i, j
}
}
}
// this should never happen because we always call
// findSum() knowing that the sum can be found
panic("sum not found")
}
// FreeSquare modifies powersOfH in place, filling in powers obtainable by squaring
func FreeSquare(powersOfH *[][]byte, maxPowerNeeded int) {
for i := 0; i < len(*powersOfH); i++ {
if (*powersOfH)[i] == nil || i%2 == 0 {
continue
}
if i > maxPowerNeeded {
return
}
power := i
for power < maxPowerNeeded {
power = power * 2
if (*powersOfH)[power] != nil {
continue
}
prevPower := (*powersOfH)[power/2]
(*powersOfH)[power] = BlockMult(prevPower, prevPower)
}
}
}
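The point of FreeSquare is that H^(2k) = (H^k)^2 can be computed locally, with no interaction, from any power already known. An illustrative caller (assumed to sit next to FreeSquare and BlockMult; not part of this diff):
// fillEvenPowers is illustrative only: starting from H^1 and H^3 (which in the
// protocol come from 2PC), free squaring fills in H^2, H^4, H^6, H^8 and H^12
// without any interaction with the other party.
func fillEvenPowers(h []byte) [][]byte {
	powers := make([][]byte, 20)
	powers[1] = h
	powers[3] = BlockMult(BlockMult(h, h), h)
	FreeSquare(&powers, 12)
	// powers[2], [4], [6], [8], [12] (and [16]) are now populated
	return powers
}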
// GetRandom returns a random slice of specified size
func GetRandom(size int) []byte {
randomBytes := make([]byte, size)
_, err := rand.Read(randomBytes)