mirror of https://github.com/tlsnotary/tlsn.git
synced 2026-01-11 07:37:58 -05:00

Compare commits
101 Commits
v0.1.0-alp ... interactiv
| SHA1 |
|---|
| 8dd5f8c9a2 |
| 75f5690f74 |
| 3b233c18c2 |
| ef6180c313 |
| 1a80ef75f8 |
| f2ff4ba792 |
| 9bf3371873 |
| ba0056c8db |
| c790b2482a |
| 0db2e6b48f |
| 9d853eb496 |
| 0c2cc6e466 |
| 6923ceefd3 |
| 5239c2328a |
| 2c048e92ed |
| 6c8cf8a182 |
| 59781d1293 |
| 43fd4d34b5 |
| 6a7c5384a9 |
| 7e469006c0 |
| 4a604c98ce |
| 55091b5e94 |
| 126ba26648 |
| 2e571b0684 |
| 4e0141f993 |
| c9ca87f4b4 |
| bc1eba18c9 |
| c128ab16ce |
| a87125ff88 |
| 0933d711d2 |
| 7e631de84a |
| 79c230f2fa |
| 345d5d45ad |
| 02cdbb8130 |
| 55a26aad77 |
| 1132d441e1 |
| fa2fdfd601 |
| 24e10d664f |
| c0e084c1ca |
| b6845dfc5c |
| 31def9ea81 |
| 878fe7e87d |
| 3348ac34b6 |
| 82767ca2d5 |
| c9aaf2e0fa |
| 241ed3b5a3 |
| 56f088db7d |
| f5250479bd |
| 0e2eabb833 |
| ad530ca500 |
| 8b1cac6fe0 |
| 555f65e6b2 |
| 046485188c |
| db53814ee7 |
| d924bd6deb |
| b3558bef9c |
| 33c4b9d16f |
| edc2a1783d |
| c2a6546deb |
| 2dfa386415 |
| 5a188e75c7 |
| a8bf1026ca |
| f900fc51cd |
| 6ccf102ec8 |
| 2c500b13bd |
| 2da0c242cb |
| 798c22409a |
| 3b5ac20d5b |
| a063f8cc14 |
| 6f6b24e76c |
| a28718923b |
| 19447aabe5 |
| 8afb7a4c11 |
| 43c6877ec0 |
| 39e14949a0 |
| 31f62982b5 |
| 6623734ca0 |
| 41e215f912 |
| 9e0f79125b |
| 7bdd3a724b |
| baa486ccfd |
| de7a47de5b |
| 3a57134b3a |
| 86fed1a90c |
| 82964c273b |
| 81aaa338e6 |
| f331a7a3c5 |
| adb407d03b |
| 3e54119867 |
| 71aa90de88 |
| 93535ca955 |
| a34dd57926 |
| 92d7b59ee8 |
| c8e9cb370e |
| 4dc5570a31 |
| 198e24c5e4 |
| f16d7238e5 |
| 9253adaaa4 |
| 8c889ac498 |
| f0e2200d22 |
| 224e41a186 |
@@ -1,10 +1,9 @@
[build]
target = "wasm32-unknown-unknown"

[target.wasm32-unknown-unknown]
rustflags = [
    "-C",
    "target-feature=+atomics,+bulk-memory,+mutable-globals",
    "-A",
    "unused_qualifications"
]

[unstable]
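This build configuration applies to any cargo invocation run beneath the directory that holds it; note that the `[unstable]` section's contents are truncated in the hunk above, so what follows is a hedged sketch only. Targets built with `+atomics,+bulk-memory` generally need the standard library rebuilt on nightly, which would look roughly like this (the `-Z build-std` flags are an assumption, not taken from this diff):

```sh
# Assumed workflow for a wasm target with atomics enabled; not from this diff.
rustup toolchain install nightly --component rust-src
cargo +nightly build --target wasm32-unknown-unknown -Z build-std=std,panic_abort
```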
3 .github/codecov.yml vendored Normal file

@@ -0,0 +1,3 @@
github_checks:
  annotations: false
comment: false
43 .github/scripts/gramine.sh vendored

@@ -1,43 +0,0 @@
#!/bin/sh
# This is to be run in a docker container via a github action that has gramine set up already, e.g.
# notaryserverbuilds.azurecr.io/builder/gramine
# with sgx hardware:
# ./gramine.sh sgx
#
# without:
# ./gramine.sh
##

if [ -z "$1" ]
then
    run='gramine-direct notary-server &'
else
    run='gramine-sgx notary-server &'
fi

curl https://sh.rustup.rs -sSf | sh -s -- -y
. "$HOME/.cargo/env"
apt install libssl-dev

gramine-sgx-gen-private-key
SGX=1 make
gramine-sgx-sign -m notary-server.manifest -o notary-server.sgx
mr_enclave=$(gramine-sgx-sigstruct-view --verbose --output-format=json notary-server.sig | jq .mr_enclave)
echo "mrenclave=$mr_enclave" >> "$GITHUB_OUTPUT"
echo "#### sgx mrenclave" | tee >> $GITHUB_STEP_SUMMARY
echo "\`\`\`${mr_enclave}\`\`\`" | tee >> $GITHUB_STEP_SUMMARY
eval "$run"
sleep 5

if [ "$1" ]; then
    curl 127.0.0.1:7047/info
else
    quote=$(curl 127.0.0.1:7047/info | jq .quote.rawQuote)
    echo $quote
    echo "quote=$quote" >> $GITHUB_OUTPUT
    echo "#### 🔒 signed quote ${quote}" | tee >> $GITHUB_STEP_SUMMARY
    echo "${quote}" | tee >> $GITHUB_STEP_SUMMARY
fi
96 .github/workflows/ci.yml vendored

@@ -23,6 +23,7 @@ env:
   # 32 seems to be big enough for the foreseeable future
   RAYON_NUM_THREADS: 32
   GIT_COMMIT_HASH: ${{ github.event.pull_request.head.sha || github.sha }}
+  RUST_VERSION: 1.87.0

 jobs:
   clippy:
@@ -32,17 +33,17 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v4

-      - name: Install stable rust toolchain
+      - name: Install rust toolchain
         uses: dtolnay/rust-toolchain@stable
         with:
-          toolchain: stable
+          toolchain: ${{ env.RUST_VERSION }}
           components: clippy

       - name: Use caching
         uses: Swatinem/rust-cache@v2.7.7

       - name: Clippy
-        run: cargo clippy --keep-going --all-features --all-targets -- -D warnings
+        run: cargo clippy --keep-going --all-features --all-targets --locked -- -D warnings

   fmt:
     name: Check formatting
@@ -71,19 +72,19 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v4

-      - name: Install stable rust toolchain
+      - name: Install rust toolchain
         uses: dtolnay/rust-toolchain@stable
         with:
-          toolchain: stable
+          toolchain: ${{ env.RUST_VERSION }}

       - name: Use caching
         uses: Swatinem/rust-cache@v2.7.7

       - name: Build
-        run: cargo build --all-targets
+        run: cargo build --all-targets --locked

       - name: Test
-        run: cargo test --no-fail-fast
+        run: cargo test --no-fail-fast --locked

   wasm:
     name: Build and Test wasm
@@ -92,11 +93,11 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v4

-      - name: Install stable rust toolchain
+      - name: Install rust toolchain
         uses: dtolnay/rust-toolchain@stable
         with:
           targets: wasm32-unknown-unknown
-          toolchain: stable
+          toolchain: ${{ env.RUST_VERSION }}

       - name: Install nightly rust toolchain
         uses: dtolnay/rust-toolchain@stable
@@ -116,20 +117,23 @@ jobs:
       - name: Use caching
         uses: Swatinem/rust-cache@v2.7.7

+      - name: Build harness
+        working-directory: crates/harness
+        run: ./build.sh
+
       - name: Run tests
+        working-directory: crates/harness
         run: |
-          cd crates/wasm-test-runner
-          ./run.sh
+          ./bin/runner setup
+          ./bin/runner --target browser test

       - name: Run build
-        run: |
-          cd crates/wasm
-          ./build.sh
+        working-directory: crates/wasm
+        run: ./build.sh

       - name: Dry Run NPM Publish
-        run: |
-          cd crates/wasm/pkg
-          npm publish --dry-run
+        working-directory: crates/wasm/pkg
+        run: npm publish --dry-run

       - name: Save tlsn-wasm package for tagged builds
         if: startsWith(github.ref, 'refs/tags/')
@@ -146,10 +150,10 @@ jobs:
       - name: Checkout repository
         uses: actions/checkout@v4

-      - name: Install stable rust toolchain
+      - name: Install rust toolchain
         uses: dtolnay/rust-toolchain@stable
         with:
-          toolchain: stable
+          toolchain: ${{ env.RUST_VERSION }}

       - name: Use caching
         uses: Swatinem/rust-cache@v2.7.7
@@ -158,7 +162,7 @@ jobs:
         run: echo "127.0.0.1 tlsnotaryserver.io" | sudo tee -a /etc/hosts

       - name: Run integration tests
-        run: cargo test --profile tests-integration --workspace --exclude tlsn-tls-client --exclude tlsn-tls-core --no-fail-fast -- --include-ignored
+        run: cargo test --locked --profile tests-integration --workspace --exclude tlsn-tls-client --exclude tlsn-tls-core --no-fail-fast -- --include-ignored

   coverage:
     runs-on: ubuntu-latest
@@ -166,14 +170,14 @@ jobs:
       CARGO_TERM_COLOR: always
     steps:
       - uses: actions/checkout@v4
-      - name: Install stable rust toolchain
+      - name: Install rust toolchain
         uses: dtolnay/rust-toolchain@stable
         with:
-          toolchain: stable
+          toolchain: ${{ env.RUST_VERSION }}
       - name: Install cargo-llvm-cov
         uses: taiki-e/install-action@cargo-llvm-cov
       - name: Generate code coverage
-        run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
+        run: cargo llvm-cov --all-features --workspace --locked --lcov --output-path lcov.info
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4
         with:
@@ -201,7 +205,7 @@ jobs:

       - name: Build Rust Binary
         run: |
-          cargo build --bin notary-server --release --features tee_quote
+          cargo build --locked --bin notary-server --release --features tee_quote
           cp --verbose target/release/notary-server $GITHUB_WORKSPACE

       - name: Upload Binary for use in the Gramine Job
@@ -214,9 +218,9 @@ jobs:
   gramine-sgx:
     runs-on: ubuntu-latest
     needs: build-sgx
     environment: tee
     container:
       image: gramineproject/gramine:latest
     if: github.ref == 'refs/heads/dev' || (startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '.'))

     steps:
       - name: Checkout repository
@@ -280,8 +284,7 @@ jobs:
           crates/notary/server/tee/notary-server.sig
           crates/notary/server/tee/notary-server.manifest
           crates/notary/server/tee/notary-server.manifest.sgx
-          crates/notary/server/tee/config
-          crates/notary/server/tee/notary-server-sgx.md
+          crates/notary/server/tee/README.md
         if-no-files-found: error

       - name: Attest Build Provenance
@@ -305,24 +308,17 @@ jobs:
       CONTAINER_REGISTRY: ghcr.io
     if: github.ref == 'refs/heads/dev' || (startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '.'))
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
+        with:
+          sparse-checkout: './crates/notary/server/tee/notary-server-sgx.Dockerfile'

       - name: Download notary-server-sgx.zip from gramine-sgx job
         uses: actions/download-artifact@v4
         with:
           name: notary-server-sgx.zip
           path: ./notary-server-sgx

-      - name: Create Dockerfile
-        run: |
-          cat <<EOF > ./Dockerfile
-          FROM gramineproject/gramine:latest
-          WORKDIR /work
-          COPY ./notary-server-sgx /work
-          RUN chmod +x /work/notary-server
-          LABEL org.opencontainers.image.source=https://github.com/tlsnotary/tlsn
-          LABEL org.opencontainers.image.description="TLSNotary notary server in SGX/Gramine."
-          ENTRYPOINT ["gramine-sgx", "notary-server"]
-          EOF
-
       - name: Log in to the Container registry
         uses: docker/login-action@v2
         with:
@@ -343,7 +339,8 @@ jobs:
         push: true
         tags: ${{ steps.meta-notary-server-sgx.outputs.tags }}
         labels: ${{ steps.meta-notary-server-sgx.outputs.labels }}
-        file: ./Dockerfile
+        file: ./crates/notary/server/tee/notary-server-sgx.Dockerfile

   build_and_publish_notary_server_image:
     name: Build and publish notary server's image
     runs-on: ubuntu-latest
@@ -379,3 +376,22 @@ jobs:
         tags: ${{ steps.meta-notary-server.outputs.tags }}
         labels: ${{ steps.meta-notary-server.outputs.labels }}
         file: ./crates/notary/server/notary-server.Dockerfile
+
+  create-release-draft:
+    name: Create Release Draft
+    needs: build_and_publish_notary_server_image
+    runs-on: ubuntu-latest
+    permissions:
+      contents: write
+    if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '.')
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Create GitHub Release Draft
+        uses: softprops/action-gh-release@v2
+        with:
+          draft: true
+          tag_name: ${{ github.ref_name }}
+          prerelease: true
+          generate_release_notes: true
59 .github/workflows/releng.yml vendored

@@ -6,22 +6,57 @@ on:
       tag:
         description: 'Tag to publish to NPM'
         required: true
-        default: '0.1.0-alpha.9'
+        default: 'v0.1.0-alpha.12'

 jobs:
   release:
     runs-on: ubuntu-latest
+    env:
+      GH_TOKEN: ${{ github.token }}

     steps:
-      - name: Download build artifacts
-        uses: actions/download-artifact@v4
-        with:
-          name: ${{ github.event.inputs.tag }}-tlsn-wasm-pkg
-          path: tlsn-wasm-pkg
-
-      - name: NPM Publish for tlsn-wasm
-        env:
-          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
-        run: |
-          cd tlsn-wasm-pkg
-          npm publish
+      - name: Find and download tlsn-wasm build from the tagged ci workflow
+        id: find_run
+        run: |
+          # Find the workflow run ID for the tag
+          RUN_ID=$(gh api \
+            -H "Accept: application/vnd.github+json" \
+            "/repos/tlsnotary/tlsn/actions/workflows/ci.yml/runs?per_page=100" \
+            --jq '.workflow_runs[] | select(.head_branch == "${{ github.event.inputs.tag }}") | .id')
+
+          if [ -z "$RUN_ID" ]; then
+            echo "No run found for tag ${{ github.event.inputs.tag }}"
+            exit 1
+          fi
+
+          echo "Found run: $RUN_ID"
+          echo "run_id=$RUN_ID" >> "$GITHUB_OUTPUT"
+
+          # Find the download URL for the build artifact
+          DOWNLOAD_URL=$(gh api \
+            -H "Accept: application/vnd.github+json" \
+            /repos/tlsnotary/tlsn/actions/runs/${RUN_ID}/artifacts \
+            --jq '.artifacts[] | select(.name == "${{ github.event.inputs.tag }}-tlsn-wasm-pkg") | .archive_download_url')
+
+          if [ -z "$DOWNLOAD_URL" ]; then
+            echo "No download url for build artifact ${{ github.event.inputs.tag }}-tlsn-wasm-pkg in run $RUN_ID"
+            exit 1
+          fi
+
+          # Download and unzip the build artifact
+          mkdir tlsn-wasm-pkg
+          curl -L -H "Authorization: Bearer ${GH_TOKEN}" \
+            -H "Accept: application/vnd.github+json" \
+            -o tlsn-wasm-pkg.zip \
+            ${DOWNLOAD_URL}
+          unzip -q tlsn-wasm-pkg.zip -d tlsn-wasm-pkg
+
+      - name: NPM Publish for tlsn-wasm
+        env:
+          NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
+        run: |
+          cd tlsn-wasm-pkg
+          echo "//registry.npmjs.org/:_authToken=${NODE_AUTH_TOKEN}" > .npmrc
+          npm publish
+          rm .npmrc
9 .github/workflows/rustdoc.yml vendored

@@ -21,18 +21,13 @@ jobs:
           toolchain: stable

       - name: "rustdoc"
-        run: cargo doc -p tlsn-core -p tlsn-prover -p tlsn-verifier --no-deps --all-features
-        # --target-dir ${GITHUB_WORKSPACE}/docs
+        run: crates/wasm/build-docs.sh

-      # https://dev.to/deciduously/prepare-your-rust-api-docs-for-github-pages-2n5i
-      - name: "Add index file -> tlsn_prover"
-        run: |
-          echo "<meta http-equiv=\"refresh\" content=\"0; url=tlsn_prover\">" > target/doc/index.html
-
       - name: Deploy
         uses: peaceiris/actions-gh-pages@v3
         if: ${{ github.ref == 'refs/heads/dev' }}
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
-          publish_dir: target/doc/
+          publish_dir: target/wasm32-unknown-unknown/doc/
           # cname: rustdocs.tlsnotary.org
156 .github/workflows/tee-cd.yml vendored

@@ -1,156 +0,0 @@
name: azure-tee-release

permissions:
  contents: read
  id-token: write
  attestations: write

on:
  workflow_dispatch:
    inputs:
      ref:
        description: 'git branch'
        required: false
        default: 'dev'
        type: string

#on:
#  release:
#    types: [published]
#    branches:
#      - 'releases/**'

env:
  GIT_COMMIT_HASH: ${{ github.event.pull_request.head.sha || github.sha }}
  GIT_COMMIT_TIMESTAMP: ${{ github.event.repository.updated_at}}
  REGISTRY: notaryserverbuilds.azurecr.io
  IMAGE_NAME: ${{ github.repository }}

jobs:
  update-reverse-proxy:
    permissions:
      contents: write
    environment: tee
    runs-on: [self-hosted, linux]
    outputs:
      teeport: ${{ steps.portbump.outputs.newport}}
      deploy: ${{ steps.portbump.outputs.deploy}}
    steps:
      - name: checkout repository
        uses: actions/checkout@v4
      - name: update caddyfile
        id: portbump
        env:
          RELEASE_TAG: ${{ github.event.release.tag_name || inputs.ref }}
        run: |
          echo "tag: $RELEASE_TAG"
          NEXT_PORT=$(bash cd-scripts/tee/azure/updateproxy.sh 'cd-scripts/tee/azure/Caddyfile' $RELEASE_TAG)
          echo "newport=$NEXT_PORT" >> $GITHUB_OUTPUT
          echo "new deploy port: $NEXT_PORT 🚀" >> $GITHUB_STEP_SUMMARY
          chmod +r -R cd-scripts/tee/azure/
      - name: Deploy updated Caddyfile to server
        if: ${{ steps.portbump.outputs.deploy == 'new' }}
        uses: appleboy/scp-action@v0.1.7
        with:
          host: ${{ secrets.AZURE_TEE_PROD_HOST }}
          username: ${{ secrets.AZURE_PROD_TEE_USERNAME }}
          key: ${{ secrets.AZURE_TEE_PROD_KEY }}
          source: "cd-scripts/tee/azure/Caddyfile"
          target: "~/"
      - name: Reload Caddy on server
        if: ${{ steps.portbump.outputs.deploy == 'new' }}
        uses: appleboy/ssh-action@v1.0.3
        with:
          host: ${{ secrets.AZURE_TEE_PROD_HOST }}
          username: ${{ secrets.AZURE_PROD_TEE_USERNAME }}
          key: ${{ secrets.AZURE_TEE_PROD_KEY }}
          script: |
            sudo cp ~/cd-scripts/tee/azure/Caddyfile /etc/caddy/Caddyfile
            sudo systemctl reload caddy
  build-measure:
    environment: tee
    runs-on: [self-hosted, linux]
    needs: [ update-reverse-proxy ]
    container:
      image: notaryserverbuilds.azurecr.io/prod/gramine
      credentials:
        username: notaryserverbuilds
        password: ${{ secrets.AZURE_CR_BUILDS_PW }}
      env:
        GIT_COMMIT_HASH: ${{ github.event.pull_request.head.sha || github.sha }}
      volumes:
        - /var/run/aesmd/aesm.socket:/var/run/aesmd/aesm.socket
      options: "--device /dev/sgx_enclave"
    steps:
      - name: get code
        uses: actions/checkout@v4
      - name: sccache
        if: github.event_name != 'release'
        # && github.event_name != 'workflow_dispatch'
        uses: mozilla-actions/sccache-action@v0.0.6
      - name: set rust env for scc
        if: github.event_name != 'release'
        # && github.event_name != 'workflow_dispatch'
        run: |
          echo "SCCACHE_GHA_ENABLED=true" >> $GITHUB_ENV
          echo "RUSTC_WRAPPER=sccache" >> $GITHUB_ENV
      - name: reverse proxy port
        run: echo "${{needs.update-reverse-proxy.outputs.teeport}}" | tee >> $GITHUB_STEP_SUMMARY
      - name: get hardware measurement
        working-directory: ${{ github.workspace }}/crates/notary/server/tee
        run: |
          chmod +x ../../../../.github/scripts/gramine.sh && ../../../../.github/scripts/gramine.sh sgx
  artifact-deploy:
    environment: tee
    runs-on: [self-hosted, linux]
    needs: [ build-measure, update-reverse-proxy ]
    steps:
      - name: auth to registry
        uses: docker/login-action@v3
        with:
          registry: notaryserverbuilds.azurecr.io
          username: notaryserverbuilds
          password: ${{ secrets.AZURE_CR_BUILDS_PW }}
      - name: get code
        uses: actions/checkout@v4
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Get Git commit timestamps
        run: echo "TIMESTAMP=$(git log -1 --pretty=%ct)" >> $GITHUB_ENV
      - name: Build and push
        id: deploypush
        uses: docker/build-push-action@v6
        with:
          provenance: mode=max
          no-cache: true
          context: ${{ github.workspace }}/crates/notary/server/tee
          push: true
          tags: notaryserverbuilds.azurecr.io/prod/notary-sgx:${{ env.GIT_COMMIT_HASH }}
          labels: ${{needs.update-reverse-proxy.outputs.teeport}}
        env:
          # reproducible builds: https://github.com/moby/buildkit/blob/master/docs/build-repro.md#source_date_epoch
          SOURCE_DATE_EPOCH: ${{ env.TIMESTAMP }}
      - name: Generate SBOM
        uses: anchore/sbom-action@v0
        with:
          image: notaryserverbuilds.azurecr.io/prod/notary-sgx:${{ env.GIT_COMMIT_HASH }}
          format: 'cyclonedx-json'
          output-file: 'sbom.cyclonedx.json'
      # attestation section ::
      # https://docs.docker.com/build/ci/github-actions/attestations/
      - name: Attest
        uses: actions/attest-build-provenance@v1
        with:
          subject-name: notaryserverbuilds.azurecr.io/prod/notary-sgx
          subject-digest: ${{ steps.deploypush.outputs.digest }}
          push-to-registry: true
      - name: run
        run: |
          if [[ ${{ needs.update-reverse-proxy.outputs.deploy }} == 'new' ]]; then
            docker run --device /dev/sgx_enclave --device /dev/sgx_provision --volume=/var/run/aesmd/aesm.socket:/var/run/aesmd/aesm.socket -p ${{needs.update-reverse-proxy.outputs.teeport}}:7047 notaryserverbuilds.azurecr.io/prod/notary-sgx:${{ env.GIT_COMMIT_HASH }} &
          else
            old=$(docker ps --filter "name=${{needs.update-reverse-proxy.outputs.teeport}}")
            docker rm -f $old
            docker run --name ${{needs.update-reverse-proxy.outputs.teeport}} --device /dev/sgx_enclave --device /dev/sgx_provision --volume=/var/run/aesmd/aesm.socket:/var/run/aesmd/aesm.socket -p ${{needs.update-reverse-proxy.outputs.teeport}}:7047 notaryserverbuilds.azurecr.io/prod/notary-sgx:${{ env.GIT_COMMIT_HASH }} &
          fi
24 .github/workflows/updatemain.yml vendored Normal file

@@ -0,0 +1,24 @@
name: Fast-forward main branch to published release tag

on:
  release:
    types: [published]

jobs:
  ff-main-to-release:
    runs-on: ubuntu-latest
    permissions:
      contents: write

    steps:
      - name: Checkout main
        uses: actions/checkout@v4
        with:
          ref: main

      - name: Fast-forward main to release tag
        run: |
          tag="${{ github.event.release.tag_name }}"
          git fetch origin "refs/tags/$tag:refs/tags/$tag"
          git merge --ff-only "refs/tags/$tag"
          git push origin main
6 .gitignore vendored

@@ -3,10 +3,6 @@
 debug/
 target/

-# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
-# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
-Cargo.lock
-
 # These are backup files generated by rustfmt
 **/*.rs.bk

@@ -32,4 +28,4 @@ Cargo.lock
 *.log

 # metrics
-*.csv
+*.csv
@@ -61,3 +61,21 @@ Comments for function arguments must adhere to this pattern:
 /// * `arg2` - The second argument.
 pub fn compute(...
 ```
+
+## Cargo.lock
+
+We check in `Cargo.lock` to ensure reproducible builds. It must be updated whenever `Cargo.toml` changes. The TLSNotary team typically updates `Cargo.lock` in a separate commit after dependency changes.
+
+If you want to hide `Cargo.lock` changes from your local `git diff`, run:
+
+```sh
+git update-index --assume-unchanged Cargo.lock
+```
+
+To start tracking changes again:
+
+```sh
+git update-index --no-assume-unchanged Cargo.lock
+```
+
+> ⚠️ Note: This only affects your local view. The file is still tracked in the repository and will be checked and used in CI.
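The `--locked` flags added throughout `ci.yml` above are what enforce this policy in CI: cargo fails instead of silently rewriting the lockfile. A minimal local sketch of the same loop after editing `Cargo.toml` (the crate name here is hypothetical):

```sh
# A locked build fails fast if Cargo.lock no longer matches Cargo.toml.
cargo build --locked
# Regenerate the lockfile entry for the changed dependency, then commit it
# separately, as the contributing notes above suggest.
cargo update -p some-crate
git add Cargo.lock && git commit -m "update Cargo.lock"
```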
9666 Cargo.lock generated Normal file
File diff suppressed because it is too large
85 Cargo.toml

@@ -1,27 +1,21 @@
 [workspace]
 members = [
-    "crates/benches/binary",
-    "crates/benches/browser/core",
-    "crates/benches/browser/native",
-    "crates/benches/browser/wasm",
-    "crates/benches/library",
     "crates/common",
     "crates/components/deap",
     "crates/components/cipher",
     "crates/components/hmac-sha256",
     "crates/components/hmac-sha256-circuits",
     "crates/components/key-exchange",
     "crates/core",
     "crates/data-fixtures",
     "crates/examples",
     "crates/formats",
     "crates/notary/client",
     "crates/notary/common",
     "crates/notary/server",
     "crates/notary/tests-integration",
     "crates/prover",
     "crates/server-fixture/certs",
     "crates/server-fixture/server",
     "crates/tests-integration",
     "crates/tls/backend",
     "crates/tls/client",
     "crates/tls/client-async",
@@ -30,29 +24,40 @@ members = [
     "crates/tls/server-fixture",
     "crates/verifier",
     "crates/wasm",
-    "crates/wasm-test-runner",
+    "crates/harness/core",
+    "crates/harness/executor",
+    "crates/harness/runner",
 ]
 resolver = "2"

 [workspace.lints.rust]
 # unsafe_code = "forbid"

 [workspace.lints.clippy]
 # enum_glob_use = "deny"

 [profile.tests-integration]
 inherits = "release"
 opt-level = 1

+[profile.release.package."tlsn-wasm"]
+opt-level = "z"
+
+[profile.dev.package."tlsn-wasm"]
+debug = false
+
 [workspace.dependencies]
 notary-client = { path = "crates/notary/client" }
 notary-common = { path = "crates/notary/common" }
 notary-server = { path = "crates/notary/server" }
 tls-server-fixture = { path = "crates/tls/server-fixture" }
 tlsn-cipher = { path = "crates/components/cipher" }
 tlsn-benches-browser-core = { path = "crates/benches/browser/core" }
 tlsn-benches-browser-native = { path = "crates/benches/browser/native" }
 tlsn-benches-library = { path = "crates/benches/library" }
 tlsn-common = { path = "crates/common" }
 tlsn-core = { path = "crates/core" }
 tlsn-data-fixtures = { path = "crates/data-fixtures" }
 tlsn-deap = { path = "crates/components/deap" }
 tlsn-formats = { path = "crates/formats" }
 tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" }
 tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" }
 tlsn-key-exchange = { path = "crates/components/key-exchange" }
 tlsn-mpc-tls = { path = "crates/mpc-tls" }
 tlsn-prover = { path = "crates/prover" }
@@ -62,28 +67,32 @@ tlsn-tls-backend = { path = "crates/tls/backend" }
 tlsn-tls-client = { path = "crates/tls/client" }
 tlsn-tls-client-async = { path = "crates/tls/client-async" }
 tlsn-tls-core = { path = "crates/tls/core" }
-tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6650a95" }
-tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6650a95" }
+tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
+tlsn-harness-core = { path = "crates/harness/core" }
+tlsn-harness-executor = { path = "crates/harness/executor" }
+tlsn-harness-runner = { path = "crates/harness/runner" }
 tlsn-wasm = { path = "crates/wasm" }
 tlsn-verifier = { path = "crates/verifier" }

-mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
-mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", tag = "v0.1.0-alpha.2" }
+mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }
+mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "ccc0057" }

-rangeset = { version = "0.1" }
+rangeset = { version = "0.2" }
 serio = { version = "0.2" }
-spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6650a95" }
+spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
 uid-mux = { version = "0.2" }
-websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6650a95" }
+websocket-relay = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }

 aes = { version = "0.8" }
 aes-gcm = { version = "0.9" }
@@ -94,9 +103,13 @@ axum = { version = "0.8" }
 bcs = { version = "0.1" }
 bincode = { version = "1.3" }
 blake3 = { version = "1.5" }
 bon = { version = "3.6" }
 bytes = { version = "1.4" }
 cfg-if = { version = "1" }
 chromiumoxide = { version = "0.7" }
 chrono = { version = "0.4" }
 cipher = { version = "0.4" }
 clap = { version = "4.5" }
 criterion = { version = "0.5" }
 ctr = { version = "0.9" }
 derive_builder = { version = "0.12" }
@@ -108,13 +121,17 @@ futures = { version = "0.3" }
 futures-rustls = { version = "0.26" }
 futures-util = { version = "0.3" }
 generic-array = { version = "0.14" }
 ghash = { version = "0.5" }
 hex = { version = "0.4" }
 hmac = { version = "0.12" }
 http = { version = "1.1" }
 http-body-util = { version = "0.1" }
 hyper = { version = "1.1" }
 hyper-util = { version = "0.1" }
 ipnet = { version = "2.11" }
 inventory = { version = "0.3" }
 itybity = { version = "0.2" }
 js-sys = { version = "0.3" }
 k256 = { version = "0.13" }
 log = { version = "0.4" }
 once_cell = { version = "1.19" }
@@ -122,6 +139,7 @@ opaque-debug = { version = "0.3" }
 p256 = { version = "0.13" }
 pkcs8 = { version = "0.10" }
 pin-project-lite = { version = "0.2" }
 pollster = { version = "0.4" }
 rand = { version = "0.9" }
 rand_chacha = { version = "0.9" }
 rand_core = { version = "0.9" }
@@ -142,12 +160,21 @@ thiserror = { version = "1.0" }
 tokio = { version = "1.38" }
 tokio-rustls = { version = "0.24" }
 tokio-util = { version = "0.7" }
 toml = { version = "0.8" }
 tower = { version = "0.5" }
 tower-http = { version = "0.5" }
 tower-service = { version = "0.3" }
 tower-util = { version = "0.3.1" }
 tracing = { version = "0.1" }
 tracing-subscriber = { version = "0.3" }
 uuid = { version = "1.4" }
 wasm-bindgen = { version = "0.2" }
 wasm-bindgen-futures = { version = "0.4" }
 web-spawn = { version = "0.2" }
 web-time = { version = "0.2" }
 webpki = { version = "0.22" }
 webpki-roots = { version = "0.26" }
 ws_stream_tungstenite = { version = "0.14" }
 # Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
 ws_stream_wasm = { git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51" }
 zeroize = { version = "1.8" }
31 appspec.yml

@@ -1,31 +0,0 @@
# AWS CodeDeploy application specification file
version: 0.0
os: linux
files:
  - source: /
    destination: /home/ubuntu/tlsn
permissions:
  - object: /home/ubuntu/tlsn
    owner: ubuntu
    group: ubuntu
hooks:
  BeforeInstall:
    - location: cd-scripts/appspec-scripts/before_install.sh
      timeout: 300
      runas: ubuntu
  AfterInstall:
    - location: cd-scripts/appspec-scripts/after_install.sh
      timeout: 300
      runas: ubuntu
  ApplicationStart:
    - location: cd-scripts/appspec-scripts/start_app.sh
      timeout: 300
      runas: ubuntu
  ApplicationStop:
    - location: cd-scripts/appspec-scripts/stop_app.sh
      timeout: 300
      runas: ubuntu
  ValidateService:
    - location: cd-scripts/appspec-scripts/validate_app.sh
      timeout: 300
      runas: ubuntu
@@ -1,35 +0,0 @@
#!/bin/bash
set -e

TAG=$(curl http://169.254.169.254/latest/meta-data/tags/instance/stable)
APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }')

if [ $APP_NAME = "stable" ]; then
    # Prepare directories for stable versions
    sudo mkdir ~/${APP_NAME}_${TAG}
    sudo mv ~/tlsn ~/${APP_NAME}_${TAG}
    sudo mkdir -p ~/${APP_NAME}_${TAG}/tlsn/notary/target/release
    sudo chown -R ubuntu.ubuntu ~/${APP_NAME}_${TAG}

    # Download .git directory
    aws s3 cp s3://tlsn-deploy/$APP_NAME/.git ~/${APP_NAME}_${TAG}/tlsn/.git --recursive

    # Download binary
    aws s3 cp s3://tlsn-deploy/$APP_NAME/notary-server ~/${APP_NAME}_${TAG}/tlsn/notary/target/release
    chmod +x ~/${APP_NAME}_${TAG}/tlsn/notary/target/release/notary-server
else
    # Prepare directory for dev
    sudo rm -rf ~/$APP_NAME/tlsn
    sudo mv ~/tlsn/ ~/$APP_NAME
    sudo mkdir -p ~/$APP_NAME/tlsn/notary/target/release
    sudo chown -R ubuntu.ubuntu ~/$APP_NAME

    # Download .git directory
    aws s3 cp s3://tlsn-deploy/$APP_NAME/.git ~/$APP_NAME/tlsn/.git --recursive

    # Download binary
    aws s3 cp s3://tlsn-deploy/$APP_NAME/notary-server ~/$APP_NAME/tlsn/notary/target/release
    chmod +x ~/$APP_NAME/tlsn/notary/target/release/notary-server
fi

exit 0
@@ -1,20 +0,0 @@
#!/bin/bash
set -e

APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }')

if [ $APP_NAME = "stable" ]; then
    VERSIONS_DEPLOYED=$(find ~/ -maxdepth 1 -type d -name 'stable_*')
    VERSIONS_DEPLOYED_COUNT=$(echo $VERSIONS_DEPLOYED | wc -w)

    if [ $VERSIONS_DEPLOYED_COUNT -gt 3 ]; then
        echo "More than 3 stable versions found"
        exit 1
    fi
else
    if [ ! -d ~/$APP_NAME ]; then
        mkdir ~/$APP_NAME
    fi
fi

exit 0
@@ -1,26 +0,0 @@
#!/bin/bash
# Port tagging will also be used to manipulate proxy server via modify_proxy.sh script
set -ex

TAG=$(curl http://169.254.169.254/latest/meta-data/tags/instance/stable)
APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }')

if [ $APP_NAME = "stable" ]; then
    # Check if all stable ports are in use. If true, terminate the deployment
    [[ $(netstat -lnt4 | egrep -c ':(7047|7057|7067)\s') -eq 3 ]] && { echo "All stable ports are in use"; exit 1; }
    STABLE_PORTS="7047 7057 7067"
    for PORT in $STABLE_PORTS; do
        PORT_LISTENING=$(netstat -lnt4 | egrep -cw $PORT || true)
        if [ $PORT_LISTENING -eq 0 ]; then
            ~/${APP_NAME}_${TAG}/tlsn/notary/target/release/notary-server --config-file ~/.notary/${APP_NAME}_${PORT}/config.yaml &> ~/${APP_NAME}_${TAG}/tlsn/notary.log &
            # Create a tag that will be used for service validation
            INSTANCE_ID=$(curl http://169.254.169.254/latest/meta-data/instance-id)
            aws ec2 create-tags --resources $INSTANCE_ID --tags "Key=port,Value=$PORT"
            break
        fi
    done
else
    ~/$APP_NAME/tlsn/notary/target/release/notary-server --config-file ~/.notary/$APP_NAME/config.yaml &> ~/$APP_NAME/tlsn/notary.log &
fi

exit 0
@@ -1,36 +0,0 @@
#!/bin/bash
# AWS CodeDeploy hook sequence: https://docs.aws.amazon.com/codedeploy/latest/userguide/reference-appspec-file-structure-hooks.html#appspec-hooks-server
set -ex

APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }')

if [ $APP_NAME = "stable" ]; then
    VERSIONS_DEPLOYED=$(find ~/ -maxdepth 1 -type d -name 'stable_*')
    VERSIONS_DEPLOYED_COUNT=$(echo $VERSIONS_DEPLOYED | wc -w)

    # Remove oldest version if exists
    if [ $VERSIONS_DEPLOYED_COUNT -eq 3 ]; then
        echo "Candidate versions to be removed:"
        OLDEST_DIR=""
        OLDEST_TIME=""

        for DIR in $VERSIONS_DEPLOYED; do
            TIME=$(stat -c %W $DIR)

            if [ -z $OLDEST_TIME ] || [ $TIME -lt $OLDEST_TIME ]; then
                OLDEST_DIR=$DIR
                OLDEST_TIME=$TIME
            fi
        done

        echo "The oldest version is running under: $OLDEST_DIR"
        PID=$(lsof $OLDEST_DIR/tlsn/notary/target/release/notary-server | awk '{ print $2 }' | tail -1)
        kill -15 $PID || true
        rm -rf $OLDEST_DIR
    fi
else
    PID=$(pgrep -f notary.*$APP_NAME)
    kill -15 $PID || true
fi

exit 0
@@ -1,21 +0,0 @@
#!/bin/bash
set -e

# Verify process is running
APP_NAME=$(echo $APPLICATION_NAME | awk -F- '{ print $2 }')

# Verify that listening sockets exist
if [ $APP_NAME = "stable" ]; then
    PORT=$(curl http://169.254.169.254/latest/meta-data/tags/instance/port)
    ps -ef | grep notary.*$APP_NAME.*$PORT | grep -v grep
    [ $? -eq 0 ] || exit 1
else
    PORT=7048
    pgrep -f notary.*$APP_NAME
    [ $? -eq 0 ] || exit 1
fi

EXPOSED_PORTS=$(netstat -lnt4 | egrep -cw $PORT)
[ $EXPOSED_PORTS -eq 1 ] || exit 1

exit 0
@@ -1,14 +0,0 @@
#!/bin/bash
# This script is executed on the proxy side, in order to assign the available port to the latest stable version
set -e

PORT=$1
VERSION=$2

sed -i "/# Port $PORT/{n;s/v[0-9].[0-9].[0-9]-[a-z]*.[0-9]*/$VERSION/g}" /etc/nginx/sites-available/tlsnotary-pse
sed -i "/# Port $PORT/{n;n;s/v[0-9].[0-9].[0-9]-[a-z]*.[0-9]*/$VERSION/g}" /etc/nginx/sites-available/tlsnotary-pse

nginx -t
nginx -s reload

exit 0
@@ -1,90 +0,0 @@
#
# global block =>
# email is for acme
# # # #
{
    key_type p256
    email mac@pse.dev # for acme
    servers {
        metrics
    }
    log {
        output stdout
        format console {
            time_format common_log
            time_local
        }
        level DEBUG
    }
}

#
# server block, acme turned on (default when using dns)
# reverse proxy with fail_duration + lb will try upstreams sequentially (fallback)
# e.g. => `reverse_proxy :4000 :5000 10.10.10.10:1000 tlsnotary.org:443`
# will always deliver to :4000 if it's up, but if :4000 is down for more than 4s it tries the next one
# # # #

notary.codes {
    handle_path /v0.1.0-alpha.8* {
        reverse_proxy :4003 :3333 {
            lb_try_duration 4s
            fail_duration 10s
            lb_policy header X-Upstream {
                fallback first
            }
        }
    }
    handle_path /v0.1.0-alpha.7* {
        reverse_proxy :4002 :3333 {
            lb_try_duration 4s
            fail_duration 10s
            lb_policy header X-Upstream {
                fallback first
            }
        }
    }
    handle_path /v0.1.0-alpha.6* {
        reverse_proxy :4001 :3333 {
            lb_try_duration 4s
            fail_duration 10s
            lb_policy header X-Upstream {
                fallback first
            }
        }
    }

    handle_path /nightly* {
        reverse_proxy :3333 {
            lb_try_duration 4s
            fail_duration 10s
            lb_policy header X-Upstream {
                fallback first
            }
        }
    }

    handle_path /proxy* {
        reverse_proxy :55688 proxy.notary.codes:443 {
            lb_try_duration 4s
            fail_duration 10s
            lb_policy header X-Upstream {
                fallback first
            }
        }
    }

    handle {
        root * /srv
        file_server
    }

    handle_errors {
        @404 {
            expression {http.error.status_code} == 404
        }
        rewrite @404 /index.html
        file_server
    }
}
@@ -1,7 +0,0 @@
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: caddy
    static_configs:
      - targets: ['localhost:2019']
@@ -1,84 +0,0 @@
#!/bin/sh

# Variables (update these as needed)
CADDYFILE=${1:-/etc/caddy/Caddyfile}  # Path to your Caddyfile
GIT_COMMIT_HASH=${2:-dev}
BASE_PORT=6061  # The starting port for your reverse_proxy directives

# Function to check if handle_path for the given commit hash exists
handle_path_exists() {
    local commit_hash=$1
    #echo "handle_path_exists $1 -- CADDYFILE: $CADDYFILE"
    grep -q "handle_path /${commit_hash}\*" "$CADDYFILE"
}

# Function to extract the port for a given commit hash
extract_port_for_commit() {
    local commit_hash=$1
    #echo "extract_port_for_commit $1 -- 2: $2"
    grep -Pzo "handle_path /${commit_hash}\* \{\n\s*reverse_proxy :(.*) " "$CADDYFILE" | grep -Poa "reverse_proxy :(.*) " | awk '{print $2}'
}

# Function to get the last port in the Caddyfile
get_last_port() {
    grep -Po "reverse_proxy :([0-9]+)" "$CADDYFILE" | awk -F: '{print $2}' | sort -n | tail -1
}

# Function to add a new handle_path block with incremented port inside the notary.codes block
add_new_handle_path() {
    local new_port=$1
    local commit_hash=$2

    # Use a temporary file for inserting the handle_path block
    tmp_file=$(mktemp)

    # Add the new handle_path in the notary.codes block
    awk -v port="$new_port" -v hash="$commit_hash" '
    /notary\.codes \{/ {
        print;
        print "    handle_path /" hash "* {";
        print "        reverse_proxy :" port " :3333 {";
        print "            lb_try_duration 4s";
        print "            fail_duration 10s";
        print "            lb_policy header X-Upstream {";
        print "                fallback first";
        print "            }";
        print "        }";
        print "    }";
        next;
    }
    { print }
    ' "$CADDYFILE" > "$tmp_file"

    # Overwrite the original Caddyfile with the updated content
    mv "$tmp_file" "$CADDYFILE"
}

# git action perms +r
chmod 664 cd-scripts/tee/azure/Caddyfile

# Check if the commit hash already exists in a handle_path
if handle_path_exists "$GIT_COMMIT_HASH"; then
    existing_port=$(extract_port_for_commit "$GIT_COMMIT_HASH")
    echo "${existing_port:1}"
    exit 0
else
    # Get the last port used and increment it
    last_port=$(get_last_port)
    if [[ -z "$last_port" ]]; then
        last_port=$BASE_PORT
    fi
    new_port=$((last_port + 1))

    # Add the new handle_path block inside the notary.codes block
    add_new_handle_path "$new_port" "$GIT_COMMIT_HASH"
    echo $new_port
    # commit the changes
    git config user.name github-actions
    git config user.email github-actions@github.com
    git add -A
    git commit --quiet --allow-empty -m "azure tee reverse proxy => port:$NEXT_PORT/${RELEASE_TAG}"
    git push --quiet
    echo "deploy=new" >> $GITHUB_OUTPUT
    exit 0
fi
@@ -1,70 +0,0 @@
[package]
edition = "2021"
name = "tlsn-benches"
publish = false
version = "0.0.0"

[features]
default = []
# Enables benchmarks in the browser.
browser-bench = ["tlsn-benches-browser-native"]

[dependencies]
mpz-common = { workspace = true }
mpz-core = { workspace = true }
mpz-garble = { workspace = true }
mpz-ot = { workspace = true, features = ["ideal"] }
tlsn-benches-library = { workspace = true }
tlsn-benches-browser-native = { workspace = true, optional = true }
tlsn-common = { workspace = true }
tlsn-core = { workspace = true }
tlsn-hmac-sha256 = { workspace = true }
tlsn-prover = { workspace = true }
tlsn-server-fixture = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-core = { workspace = true }
tlsn-verifier = { workspace = true }

anyhow = { workspace = true }
async-trait = { workspace = true }
charming = { version = "0.3.1", features = ["ssr"] }
csv = "1.3.0"
dhat = { version = "0.3.3" }
env_logger = { version = "0.6.0", default-features = false }
futures = { workspace = true }
serde = { workspace = true }
tokio = { workspace = true, features = [
    "rt",
    "rt-multi-thread",
    "macros",
    "net",
    "io-std",
    "fs",
] }
tokio-util = { workspace = true }
toml = "0.8.11"
tracing-subscriber = { workspace = true, features = ["env-filter"] }

[[bin]]
name = "bench"
path = "bin/bench.rs"

[[bin]]
name = "prover"
path = "bin/prover.rs"

[[bin]]
name = "prover-memory"
path = "bin/prover_memory.rs"

[[bin]]
name = "verifier"
path = "bin/verifier.rs"

[[bin]]
name = "verifier-memory"
path = "bin/verifier_memory.rs"

[[bin]]
name = "plot"
path = "bin/plot.rs"
@@ -1,53 +0,0 @@
# TLSNotary bench utilities

This crate provides utilities for benchmarking protocol performance under various network conditions and usage patterns.

As the protocol is mostly IO-bound, it's important to track how it performs in low-bandwidth and/or high-latency environments. To do this we set up temporary network namespaces and add virtual ethernet interfaces which we can control using the Linux `tc` (Traffic Control) utility.

## Configuration

See the `bench.toml` file for benchmark configurations.

## Preliminaries

To run the benchmarks you will need `iproute2` installed, e.g.:
```sh
sudo apt-get install iproute2 -y
```

## Running benches

Running the benches requires root privileges because they will set up virtual interfaces. The script is designed to fully clean up when the benches are done, but run them at your own risk.

#### Native benches

Make sure you're in the `crates/benches/` directory, build the binaries, and then run the script:

```sh
cd binary
cargo build --release
sudo ./bench.sh
```

#### Browser benches

(Note: we recommend running browser benches inside a docker container (see docker.md) to avoid incompatibility issues observed in the latest versions of Chrome.)

With a Chrome browser installed on your system, make sure you're in the `crates/benches/` directory, build the wasm module, build the binaries, and then run the script:
```sh
cd browser/wasm
rustup run nightly wasm-pack build --release --target web
cd ../../binary
cargo build --release --features browser-bench
sudo ./bench.sh
```

## Metrics

After you run the benches you will see a `metrics.csv` file in the working directory. It will be owned by `root`, so you probably want to run

```sh
sudo chown $USER metrics.csv
```
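For readers unfamiliar with the setup this README references, here is a minimal sketch of the namespace-plus-`tc` technique using stock `iproute2` commands. The namespace names match the ones `bin/bench.rs` below executes into; the addresses and shaping values are illustrative, not the ones the bench harness actually applies:

```sh
# Two namespaces joined by a virtual ethernet (veth) pair.
ip netns add prover-ns
ip netns add verifier-ns
ip link add veth-p type veth peer name veth-v
ip link set veth-p netns prover-ns
ip link set veth-v netns verifier-ns
ip -n prover-ns addr add 10.0.0.1/24 dev veth-p
ip -n verifier-ns addr add 10.0.0.2/24 dev veth-v
ip -n prover-ns link set veth-p up
ip -n verifier-ns link set veth-v up
# Shape the prover's egress with netem: 25 ms delay at 250 mbit.
ip netns exec prover-ns tc qdisc add dev veth-p root netem delay 25ms rate 250mbit
```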
@@ -1,13 +0,0 @@
#!/bin/bash

# Check if we are running as root.
if [ "$EUID" -ne 0 ]; then
    echo "This script must be run as root"
    exit
fi

# Run the benchmark binary.
../../../target/release/bench

# Plot the results.
../../../target/release/plot metrics.csv
@@ -1,45 +0,0 @@
[[benches]]
name = "latency"
upload = 250
upload-delay = [10, 25, 50]
download = 250
download-delay = [10, 25, 50]
upload-size = 1024
download-size = 4096
defer-decryption = true
memory-profile = false

[[benches]]
name = "download_bandwidth"
upload = 250
upload-delay = 25
download = [10, 25, 50, 100, 250]
download-delay = 25
upload-size = 1024
download-size = 4096
defer-decryption = true
memory-profile = false

[[benches]]
name = "upload_bandwidth"
upload = [10, 25, 50, 100, 250]
upload-delay = 25
download = 250
download-delay = 25
upload-size = 1024
download-size = 4096
defer-decryption = [false, true]
memory-profile = false

[[benches]]
name = "download_volume"
upload = 250
upload-delay = 25
download = 250
download-delay = 25
upload-size = 1024
# Setting download-size higher than 45000 will cause a `Maximum call stack size exceeded`
# error in the browser.
download-size = [1024, 4096, 16384, 45000]
defer-decryption = true
memory-profile = true
@@ -1,55 +0,0 @@
FROM rust AS builder
WORKDIR /usr/src/tlsn
COPY . .

ARG BENCH_TYPE=native

RUN \
    if [ "$BENCH_TYPE" = "browser" ]; then \
        # ring's build script needs clang.
        apt update && apt install -y clang; \
        rustup install nightly; \
        rustup component add rust-src --toolchain nightly; \
        cargo install wasm-pack; \
        cd crates/benches/browser/wasm; \
        rustup run nightly wasm-pack build --release --target web; \
        cd ../../binary; \
        cargo build --release --features browser-bench; \
    else \
        cd crates/benches/binary; \
        cargo build --release; \
    fi

FROM debian:latest

ARG BENCH_TYPE=native

RUN apt update && apt upgrade -y && apt install -y --no-install-recommends \
    iproute2 \
    sudo

RUN \
    if [ "$BENCH_TYPE" = "browser" ]; then \
        # Using Chromium since Chrome for Linux is not available on ARM.
        apt install -y chromium; \
    fi

RUN apt clean && rm -rf /var/lib/apt/lists/*

COPY --from=builder \
    ["/usr/src/tlsn/target/release/bench", \
    "/usr/src/tlsn/target/release/prover", \
    "/usr/src/tlsn/target/release/prover-memory", \
    "/usr/src/tlsn/target/release/verifier", \
    "/usr/src/tlsn/target/release/verifier-memory", \
    "/usr/src/tlsn/target/release/plot", \
    "/usr/local/bin/"]

ENV PROVER_PATH="/usr/local/bin/prover"
ENV VERIFIER_PATH="/usr/local/bin/verifier"
ENV PROVER_MEMORY_PATH="/usr/local/bin/prover-memory"
ENV VERIFIER_MEMORY_PATH="/usr/local/bin/verifier-memory"

VOLUME [ "/benches" ]
WORKDIR "/benches"
CMD ["/bin/bash", "-c", "bench && bench --memory-profiling && plot /benches/metrics.csv && cat /benches/metrics.csv"]
@@ -1,2 +0,0 @@
# exclude any /target folders
**/target*
@@ -1,62 +0,0 @@
use std::{env, process::Command, thread, time::Duration};

use tlsn_benches::{clean_up, set_up};

fn main() {
    let args: Vec<String> = env::args().collect();
    let is_memory_profiling = args.contains(&"--memory-profiling".to_string());

    let (prover_path, verifier_path) = if is_memory_profiling {
        (
            std::env::var("PROVER_MEMORY_PATH")
                .unwrap_or_else(|_| "../../../target/release/prover-memory".to_string()),
            std::env::var("VERIFIER_MEMORY_PATH")
                .unwrap_or_else(|_| "../../../target/release/verifier-memory".to_string()),
        )
    } else {
        (
            std::env::var("PROVER_PATH")
                .unwrap_or_else(|_| "../../../target/release/prover".to_string()),
            std::env::var("VERIFIER_PATH")
                .unwrap_or_else(|_| "../../../target/release/verifier".to_string()),
        )
    };

    if let Err(e) = set_up() {
        println!("Error setting up: {}", e);
        clean_up();
    }

    // Run prover and verifier binaries in parallel.
    let Ok(mut verifier) = Command::new("ip")
        .arg("netns")
        .arg("exec")
        .arg("verifier-ns")
        .arg(verifier_path)
        .spawn()
    else {
        println!("Failed to start verifier");
        return clean_up();
    };

    // Allow the verifier some time to start listening before the prover attempts to
    // connect.
    thread::sleep(Duration::from_secs(1));

    let Ok(mut prover) = Command::new("ip")
        .arg("netns")
        .arg("exec")
        .arg("prover-ns")
        .arg(prover_path)
        .spawn()
    else {
        println!("Failed to start prover");
        return clean_up();
    };

    // Wait for both to finish.
    _ = prover.wait();
    _ = verifier.wait();

    clean_up();
}
@@ -1,248 +0,0 @@
use tlsn_benches::metrics::Metrics;

use charming::{
    component::{
        Axis, DataView, Feature, Legend, Restore, SaveAsImage, Title, Toolbox, ToolboxDataZoom,
    },
    element::{NameLocation, Orient, Tooltip, Trigger},
    series::{Line, Scatter},
    theme::Theme,
    Chart, HtmlRenderer,
};
use csv::Reader;

const THEME: Theme = Theme::Default;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let csv_file = std::env::args()
        .nth(1)
        .expect("Usage: plot <path_to_csv_file>");

    let mut rdr = Reader::from_path(csv_file)?;

    // Prepare data for plotting.
    let all_data: Vec<Metrics> = rdr
        .deserialize::<Metrics>()
        .collect::<Result<Vec<_>, _>>()?; // Attempt to collect all results, return an error if any fail.

    let _chart = runtime_vs_latency(&all_data)?;
    let _chart = runtime_vs_bandwidth(&all_data)?;

    // Memory profiling is not compatible with browser benches.
    if cfg!(not(feature = "browser-bench")) {
        let _chart = download_size_vs_memory(&all_data)?;
    }

    Ok(())
}

fn download_size_vs_memory(all_data: &[Metrics]) -> Result<Chart, Box<dyn std::error::Error>> {
    const TITLE: &str = "Download Size vs Memory";

    let prover_kind: String = all_data
        .first()
        .map(|s| s.kind.clone().into())
        .unwrap_or_default();

    let data: Vec<Vec<f32>> = all_data
        .iter()
        .filter(|record| record.name == "download_volume" && record.heap_max_bytes.is_some())
        .map(|record| {
            vec![
                record.download_size as f32,
                record.heap_max_bytes.unwrap() as f32 / 1024.0 / 1024.0,
            ]
        })
        .collect();

    // https://github.com/yuankunzhang/charming
    let chart = Chart::new()
        .title(
            Title::new()
                .text(TITLE)
                .subtext(format!("{} Prover", prover_kind)),
        )
        .tooltip(Tooltip::new().trigger(Trigger::Axis))
        .legend(Legend::new().orient(Orient::Vertical))
        .toolbox(
            Toolbox::new().show(true).feature(
                Feature::new()
                    .save_as_image(SaveAsImage::new())
                    .restore(Restore::new())
                    .data_zoom(ToolboxDataZoom::new().y_axis_index("none"))
                    .data_view(DataView::new().read_only(false)),
            ),
        )
        .x_axis(
            Axis::new()
                .scale(true)
                .name("Download Size (bytes)")
                .name_gap(30)
                .name_location(NameLocation::Center),
        )
        .y_axis(
            Axis::new()
                .scale(true)
                .name("Heap Memory (Mbytes)")
                .name_gap(40)
                .name_location(NameLocation::Middle),
        )
        .series(
            Scatter::new()
                .name("Allocated Heap Memory")
                .symbol_size(10)
                .data(data),
        );

    // Save the chart as HTML file.
    HtmlRenderer::new(TITLE, 1000, 800)
        .theme(THEME)
        .save(&chart, "download_size_vs_memory.html")
        .unwrap();

    Ok(chart)
}

fn runtime_vs_latency(all_data: &[Metrics]) -> Result<Chart, Box<dyn std::error::Error>> {
    const TITLE: &str = "Runtime vs Latency";

    let prover_kind: String = all_data
        .first()
        .map(|s| s.kind.clone().into())
        .unwrap_or_default();

    let data: Vec<Vec<f32>> = all_data
        .iter()
        .filter(|record| record.name == "latency")
        .map(|record| {
            let total_delay = record.upload_delay + record.download_delay; // Calculate the sum of upload and download delays.
            vec![total_delay as f32, record.runtime as f32]
        })
        .collect();

    // https://github.com/yuankunzhang/charming
    let chart = Chart::new()
        .title(
            Title::new()
                .text(TITLE)
                .subtext(format!("{} Prover", prover_kind)),
        )
        .tooltip(Tooltip::new().trigger(Trigger::Axis))
        .legend(Legend::new().orient(Orient::Vertical))
        .toolbox(
            Toolbox::new().show(true).feature(
                Feature::new()
                    .save_as_image(SaveAsImage::new())
                    .restore(Restore::new())
                    .data_zoom(ToolboxDataZoom::new().y_axis_index("none"))
                    .data_view(DataView::new().read_only(false)),
            ),
        )
        .x_axis(
            Axis::new()
                .scale(true)
                .name("Upload + Download Latency (ms)")
                .name_location(NameLocation::Center),
        )
        .y_axis(
            Axis::new()
                .scale(true)
                .name("Runtime (s)")
                .name_location(NameLocation::Middle),
        )
        .series(
            Scatter::new()
                .name("Combined Latency")
                .symbol_size(10)
                .data(data),
        );

    // Save the chart as HTML file.
    HtmlRenderer::new(TITLE, 1000, 800)
        .theme(THEME)
        .save(&chart, "runtime_vs_latency.html")
        .unwrap();

    Ok(chart)
}

fn runtime_vs_bandwidth(all_data: &[Metrics]) -> Result<Chart, Box<dyn std::error::Error>> {
    const TITLE: &str = "Runtime vs Bandwidth";

    let prover_kind: String = all_data
        .first()
        .map(|s| s.kind.clone().into())
        .unwrap_or_default();

    let download_data: Vec<Vec<f32>> = all_data
        .iter()
        .filter(|record| record.name == "download_bandwidth")
        .map(|record| vec![record.download as f32, record.runtime as f32])
        .collect();
    let upload_deferred_data: Vec<Vec<f32>> = all_data
        .iter()
        .filter(|record| record.name == "upload_bandwidth" && record.defer_decryption)
        .map(|record| vec![record.upload as f32, record.runtime as f32])
        .collect();
    let upload_non_deferred_data: Vec<Vec<f32>> = all_data
        .iter()
        .filter(|record| record.name == "upload_bandwidth" && !record.defer_decryption)
        .map(|record| vec![record.upload as f32, record.runtime as f32])
        .collect();

    // https://github.com/yuankunzhang/charming
    let chart = Chart::new()
        .title(
            Title::new()
                .text(TITLE)
                .subtext(format!("{} Prover", prover_kind)),
        )
        .tooltip(Tooltip::new().trigger(Trigger::Axis))
        .legend(Legend::new().orient(Orient::Vertical))
        .toolbox(
            Toolbox::new().show(true).feature(
                Feature::new()
                    .save_as_image(SaveAsImage::new())
                    .restore(Restore::new())
                    .data_zoom(ToolboxDataZoom::new().y_axis_index("none"))
                    .data_view(DataView::new().read_only(false)),
            ),
        )
        .x_axis(
            Axis::new()
                .scale(true)
                .name("Bandwidth (Mbps)")
                .name_location(NameLocation::Center),
        )
        .y_axis(
            Axis::new()
                .scale(true)
                .name("Runtime (s)")
                .name_location(NameLocation::Middle),
        )
        .series(
            Line::new()
                .name("Download bandwidth")
                .symbol_size(10)
                .data(download_data),
        )
        .series(
            Line::new()
                .name("Upload bandwidth (deferred decryption)")
                .symbol_size(10)
                .data(upload_deferred_data),
        )
        .series(
            Line::new()
                .name("Upload bandwidth")
                .symbol_size(10)
                .data(upload_non_deferred_data),
        );
    // Save the chart as HTML file.
    HtmlRenderer::new(TITLE, 1000, 800)
        .theme(THEME)
        .save(&chart, "runtime_vs_bandwidth.html")
        .unwrap();

    Ok(chart)
}
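The three plotting functions above share one recipe. As a minimal standalone sketch of that charming pattern (hypothetical data; only calls used in the code above), a scatter series over named axes rendered to a self-contained HTML file:

```
use charming::{component::Axis, series::Scatter, Chart, HtmlRenderer};

fn main() {
    // Hypothetical (latency ms, runtime s) points.
    let data: Vec<Vec<f32>> = vec![vec![25.0, 60.0], vec![50.0, 95.0]];

    let chart = Chart::new()
        .x_axis(Axis::new().scale(true).name("Latency (ms)"))
        .y_axis(Axis::new().scale(true).name("Runtime (s)"))
        .series(Scatter::new().name("example").symbol_size(10).data(data));

    // Render the chart into a standalone HTML file.
    HtmlRenderer::new("example", 600, 400)
        .save(&chart, "example.html")
        .unwrap();
}
```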
@@ -1,8 +0,0 @@
//! A Prover without memory profiling.

use tlsn_benches::prover_main::prover_main;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    prover_main(false).await
}
@@ -1,15 +0,0 @@
//! A Prover with memory profiling.

use tlsn_benches::prover_main::prover_main;

#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    if cfg!(feature = "browser-bench") {
        // Memory profiling is not compatible with browser benches.
        return Ok(());
    }
    prover_main(true).await
}
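The `*-memory` binaries rely on the dhat pattern shown above: install dhat as the global allocator, build a silent "testing" profiler, and read the peak heap afterwards. A minimal standalone sketch (assuming only the `dhat` crate, with the same calls used in this diff):

```
// dhat replaces the global allocator; the testing profiler keeps stderr
// quiet, and HeapStats reports the peak heap after the workload.
#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;

fn main() {
    let _profiler = dhat::Profiler::builder().testing().build();

    let buf = vec![0u8; 1 << 20]; // allocate 1 MiB on the heap
    drop(buf);

    let peak = dhat::HeapStats::get().max_bytes;
    assert!(peak >= 1 << 20);
}
```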
@@ -1,8 +0,0 @@
//! A Verifier without memory profiling.

use tlsn_benches::verifier_main::verifier_main;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    verifier_main(false).await
}
@@ -1,15 +0,0 @@
//! A Verifier with memory profiling.

use tlsn_benches::verifier_main::verifier_main;

#[global_allocator]
static ALLOC: dhat::Alloc = dhat::Alloc;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    if cfg!(feature = "browser-bench") {
        // Memory profiling is not compatible with browser benches.
        return Ok(());
    }
    verifier_main(true).await
}
@@ -1,13 +0,0 @@
# Run the TLSN benches with Docker

In the root folder of this repository, run:

```
# Change to BENCH_TYPE=browser if you want benchmarks to run in the browser.
docker build -t tlsn-bench . -f ./crates/benches/binary/benches.Dockerfile --build-arg BENCH_TYPE=native
```

Next, run the benches with:

```
docker run -it --privileged -v ./crates/benches/binary:/benches tlsn-bench
```

The `--privileged` flag is required because this test bench needs permission to create network namespaces and traffic-shaped virtual interfaces.
@@ -1,123 +0,0 @@
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
#[serde(untagged)]
pub enum Field<T> {
    Single(T),
    Multiple(Vec<T>),
}

#[derive(Deserialize)]
pub struct Config {
    pub benches: Vec<Bench>,
}

#[derive(Deserialize)]
pub struct Bench {
    pub name: String,
    pub upload: Field<usize>,
    #[serde(rename = "upload-delay")]
    pub upload_delay: Field<usize>,
    pub download: Field<usize>,
    #[serde(rename = "download-delay")]
    pub download_delay: Field<usize>,
    #[serde(rename = "upload-size")]
    pub upload_size: Field<usize>,
    #[serde(rename = "download-size")]
    pub download_size: Field<usize>,
    #[serde(rename = "defer-decryption")]
    pub defer_decryption: Field<bool>,
    #[serde(rename = "memory-profile")]
    pub memory_profile: Field<bool>,
}

impl Bench {
    /// Flattens the config into a list of instances.
    pub fn flatten(self) -> Vec<BenchInstance> {
        let mut instances = vec![];

        let upload = match self.upload {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let upload_delay = match self.upload_delay {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let download = match self.download {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let download_delay = match self.download_delay {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let upload_size = match self.upload_size {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let download_size = match self.download_size {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let defer_decryption = match self.defer_decryption {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        let memory_profile = match self.memory_profile {
            Field::Single(u) => vec![u],
            Field::Multiple(u) => u,
        };

        for u in upload {
            for ul in &upload_delay {
                for d in &download {
                    for dl in &download_delay {
                        for us in &upload_size {
                            for ds in &download_size {
                                for dd in &defer_decryption {
                                    for mp in &memory_profile {
                                        instances.push(BenchInstance {
                                            name: self.name.clone(),
                                            upload: u,
                                            upload_delay: *ul,
                                            download: *d,
                                            download_delay: *dl,
                                            upload_size: *us,
                                            download_size: *ds,
                                            defer_decryption: *dd,
                                            memory_profile: *mp,
                                        });
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        instances
    }
}

#[derive(Debug, Clone, Serialize)]
pub struct BenchInstance {
    pub name: String,
    pub upload: usize,
    pub upload_delay: usize,
    pub download: usize,
    pub download_delay: usize,
    pub upload_size: usize,
    pub download_size: usize,
    pub defer_decryption: bool,
    /// Whether this instance should be used for memory profiling.
    pub memory_profile: bool,
}
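To illustrate the flattening, a minimal sketch (hypothetical values; assumes the `toml` crate and the `Config`/`BenchInstance` types above): one `[[benches]]` entry whose `upload` and `download-size` fields are lists expands into the 2 x 2 Cartesian product.

```
use tlsn_benches::config::{BenchInstance, Config};

fn main() {
    // A hypothetical bench.toml entry; field names match the serde
    // renames above.
    let raw = r#"
        [[benches]]
        name = "example"
        upload = [50, 250]
        upload-delay = 25
        download = 50
        download-delay = 25
        upload-size = 1024
        download-size = [4096, 65536]
        defer-decryption = true
        memory-profile = false
    "#;

    let config: Config = toml::from_str(raw).unwrap();
    let instances: Vec<BenchInstance> = config
        .benches
        .into_iter()
        .flat_map(|bench| bench.flatten())
        .collect();

    // Two upload rates times two download sizes.
    assert_eq!(instances.len(), 4);
}
```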
@@ -1,273 +0,0 @@
pub mod config;
pub mod metrics;
mod preprocess;
pub mod prover;
pub mod prover_main;
pub mod verifier_main;

use std::{
    io,
    process::{Command, Stdio},
};

pub const PROVER_NAMESPACE: &str = "prover-ns";
pub const PROVER_INTERFACE: &str = "prover-veth";
pub const PROVER_SUBNET: &str = "10.10.1.0/24";
pub const VERIFIER_NAMESPACE: &str = "verifier-ns";
pub const VERIFIER_INTERFACE: &str = "verifier-veth";
pub const VERIFIER_SUBNET: &str = "10.10.1.1/24";

pub fn set_up() -> io::Result<()> {
    // Create network namespaces
    create_network_namespace(PROVER_NAMESPACE)?;
    create_network_namespace(VERIFIER_NAMESPACE)?;

    // Create veth pair and attach to namespaces
    create_veth_pair(
        PROVER_NAMESPACE,
        PROVER_INTERFACE,
        VERIFIER_NAMESPACE,
        VERIFIER_INTERFACE,
    )?;

    // Set devices up
    set_device_up(PROVER_NAMESPACE, PROVER_INTERFACE)?;
    set_device_up(VERIFIER_NAMESPACE, VERIFIER_INTERFACE)?;

    // Bring up the loopback interface.
    set_device_up(PROVER_NAMESPACE, "lo")?;
    set_device_up(VERIFIER_NAMESPACE, "lo")?;

    // Assign IPs
    assign_ip_to_interface(PROVER_NAMESPACE, PROVER_INTERFACE, PROVER_SUBNET)?;
    assign_ip_to_interface(VERIFIER_NAMESPACE, VERIFIER_INTERFACE, VERIFIER_SUBNET)?;

    // Set default routes
    set_default_route(
        PROVER_NAMESPACE,
        PROVER_INTERFACE,
        PROVER_SUBNET.split('/').next().unwrap(),
    )?;
    set_default_route(
        VERIFIER_NAMESPACE,
        VERIFIER_INTERFACE,
        VERIFIER_SUBNET.split('/').next().unwrap(),
    )?;

    Ok(())
}

pub fn clean_up() {
    // Delete interface pair
    if let Err(e) = Command::new("ip")
        .args([
            "netns",
            "exec",
            PROVER_NAMESPACE,
            "ip",
            "link",
            "delete",
            PROVER_INTERFACE,
        ])
        .status()
    {
        println!("Error deleting interface {}: {}", PROVER_INTERFACE, e);
    }

    // Delete namespaces
    if let Err(e) = Command::new("ip")
        .args(["netns", "del", PROVER_NAMESPACE])
        .status()
    {
        println!("Error deleting namespace {}: {}", PROVER_NAMESPACE, e);
    }

    if let Err(e) = Command::new("ip")
        .args(["netns", "del", VERIFIER_NAMESPACE])
        .status()
    {
        println!("Error deleting namespace {}: {}", VERIFIER_NAMESPACE, e);
    }
}

/// Sets the interface parameters.
///
/// Must be run in the correct namespace.
///
/// # Arguments
///
/// * `egress` - The egress bandwidth in mbps.
/// * `burst` - The burst in mbps.
/// * `delay` - The delay in ms.
pub fn set_interface(interface: &str, egress: usize, burst: usize, delay: usize) -> io::Result<()> {
    // Clear rules
    let output = Command::new("tc")
        .arg("qdisc")
        .arg("del")
        .arg("dev")
        .arg(interface)
        .arg("root")
        .stdout(Stdio::piped())
        .output()?;

    if output.stderr == "Error: Cannot delete qdisc with handle of zero.\n".as_bytes() {
        // This error is informative, do not log it to stderr.
    } else if !output.status.success() {
        return Err(io::Error::other("Failed to clear rules"));
    }

    // Egress
    Command::new("tc")
        .arg("qdisc")
        .arg("add")
        .arg("dev")
        .arg(interface)
        .arg("root")
        .arg("handle")
        .arg("1:")
        .arg("tbf")
        .arg("rate")
        .arg(format!("{}mbit", egress))
        .arg("burst")
        .arg(format!("{}mbit", burst))
        .arg("latency")
        .arg("60s")
        .status()?;

    // Delay
    Command::new("tc")
        .arg("qdisc")
        .arg("add")
        .arg("dev")
        .arg(interface)
        .arg("parent")
        .arg("1:1")
        .arg("handle")
        .arg("10:")
        .arg("netem")
        .arg("delay")
        .arg(format!("{}ms", delay))
        .status()?;

    Ok(())
}

/// Create a network namespace with the given name if it does not already exist.
fn create_network_namespace(name: &str) -> io::Result<()> {
    // Check if namespace already exists
    if Command::new("ip")
        .args(["netns", "list"])
        .output()?
        .stdout
        .windows(name.len())
        .any(|ns| ns == name.as_bytes())
    {
        println!("Namespace {} already exists", name);
        return Ok(());
    } else {
        println!("Creating namespace {}", name);
        Command::new("ip").args(["netns", "add", name]).status()?;
    }

    Ok(())
}

fn create_veth_pair(
    left_namespace: &str,
    left_interface: &str,
    right_namespace: &str,
    right_interface: &str,
) -> io::Result<()> {
    // Check if interfaces are already present in namespaces
    if is_interface_present_in_namespace(left_namespace, left_interface)?
        || is_interface_present_in_namespace(right_namespace, right_interface)?
    {
        println!("Virtual interface already exists.");
        return Ok(());
    }

    // Create veth pair
    Command::new("ip")
        .args([
            "link",
            "add",
            left_interface,
            "type",
            "veth",
            "peer",
            "name",
            right_interface,
        ])
        .status()?;

    println!(
        "Created veth pair {} and {}",
        left_interface, right_interface
    );

    // Attach veth pair to namespaces
    attach_interface_to_namespace(left_namespace, left_interface)?;
    attach_interface_to_namespace(right_namespace, right_interface)?;

    Ok(())
}

fn attach_interface_to_namespace(namespace: &str, interface: &str) -> io::Result<()> {
    Command::new("ip")
        .args(["link", "set", interface, "netns", namespace])
        .status()?;

    println!("Attached {} to namespace {}", interface, namespace);

    Ok(())
}

fn set_default_route(namespace: &str, interface: &str, ip: &str) -> io::Result<()> {
    Command::new("ip")
        .args([
            "netns", "exec", namespace, "ip", "route", "add", "default", "via", ip, "dev",
            interface,
        ])
        .status()?;

    println!(
        "Set default route for namespace {} ip {} to {}",
        namespace, ip, interface
    );

    Ok(())
}

fn is_interface_present_in_namespace(
    namespace: &str,
    interface: &str,
) -> Result<bool, std::io::Error> {
    Ok(Command::new("ip")
        .args([
            "netns", "exec", namespace, "ip", "link", "list", "dev", interface,
        ])
        .output()?
        .stdout
        .windows(interface.len())
        .any(|ns| ns == interface.as_bytes()))
}

fn set_device_up(namespace: &str, interface: &str) -> io::Result<()> {
    Command::new("ip")
        .args([
            "netns", "exec", namespace, "ip", "link", "set", interface, "up",
        ])
        .status()?;

    Ok(())
}

fn assign_ip_to_interface(namespace: &str, interface: &str, ip: &str) -> io::Result<()> {
    Command::new("ip")
        .args([
            "netns", "exec", namespace, "ip", "addr", "add", ip, "dev", interface,
        ])
        .status()?;

    Ok(())
}
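A minimal usage sketch of these helpers (hypothetical values; requires root). Note that `set_interface` only shapes the device visible in the namespace it runs in, which is why the bench binaries are launched via `ip netns exec` and call it themselves:

```
use tlsn_benches::{clean_up, set_interface, set_up, PROVER_INTERFACE};

fn main() -> std::io::Result<()> {
    // Create the namespaces, veth pair, addresses, and routes.
    set_up()?;

    // Inside prover-ns this would shape prover-veth to 50 Mbps egress with
    // a 1 Mbit burst and 25 ms of added delay.
    set_interface(PROVER_INTERFACE, 50, 1, 25)?;

    // Tear the namespaces and interfaces down again.
    clean_up();
    Ok(())
}
```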
@@ -1,31 +0,0 @@
use serde::{Deserialize, Serialize};
use tlsn_benches_library::ProverKind;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Metrics {
    pub name: String,
    /// The kind of the prover, either native or browser.
    pub kind: ProverKind,
    /// Upload bandwidth in Mbps.
    pub upload: usize,
    /// Upload latency in ms.
    pub upload_delay: usize,
    /// Download bandwidth in Mbps.
    pub download: usize,
    /// Download latency in ms.
    pub download_delay: usize,
    /// Total bytes sent to the server.
    pub upload_size: usize,
    /// Total bytes received from the server.
    pub download_size: usize,
    /// Whether deferred decryption was used.
    pub defer_decryption: bool,
    /// The total runtime of the benchmark in seconds.
    pub runtime: u64,
    /// The total amount of data uploaded to the verifier in bytes.
    pub uploaded: u64,
    /// The total amount of data downloaded from the verifier in bytes.
    pub downloaded: u64,
    /// The peak heap memory usage in bytes.
    pub heap_max_bytes: Option<usize>,
}
@@ -1,5 +0,0 @@
use hmac_sha256::build_circuits;

pub async fn preprocess_prf_circuits() {
    build_circuits().await;
}
@@ -1,57 +0,0 @@
use std::time::Instant;

use tlsn_benches_library::{run_prover, AsyncIo, ProverKind, ProverTrait};

use async_trait::async_trait;

pub struct NativeProver {
    upload_size: usize,
    download_size: usize,
    defer_decryption: bool,
    io: Option<Box<dyn AsyncIo>>,
    client_conn: Option<Box<dyn AsyncIo>>,
}

#[async_trait]
impl ProverTrait for NativeProver {
    async fn setup(
        upload_size: usize,
        download_size: usize,
        defer_decryption: bool,
        io: Box<dyn AsyncIo>,
        client_conn: Box<dyn AsyncIo>,
    ) -> anyhow::Result<Self>
    where
        Self: Sized,
    {
        Ok(Self {
            upload_size,
            download_size,
            defer_decryption,
            io: Some(io),
            client_conn: Some(client_conn),
        })
    }

    async fn run(&mut self) -> anyhow::Result<u64> {
        let io = std::mem::take(&mut self.io).unwrap();
        let client_conn = std::mem::take(&mut self.client_conn).unwrap();

        let start_time = Instant::now();

        run_prover(
            self.upload_size,
            self.download_size,
            self.defer_decryption,
            io,
            client_conn,
        )
        .await?;

        Ok(Instant::now().duration_since(start_time).as_secs())
    }

    fn kind(&self) -> ProverKind {
        ProverKind::Native
    }
}
@@ -1,176 +0,0 @@
//! Contains the actual main() function of the prover binary. It is moved here
//! in order to enable cargo to build two prover binaries - with and without
//! memory profiling.

use std::{
    fs::metadata,
    io::Write,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};

use crate::{
    config::{BenchInstance, Config},
    metrics::Metrics,
    preprocess::preprocess_prf_circuits,
    set_interface, PROVER_INTERFACE,
};
use anyhow::Context;
use tlsn_benches_library::{AsyncIo, ProverTrait};
use tlsn_server_fixture::bind;

use csv::WriterBuilder;

use tokio_util::{
    compat::TokioAsyncReadCompatExt,
    io::{InspectReader, InspectWriter},
};
use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter};

#[cfg(not(feature = "browser-bench"))]
use crate::prover::NativeProver as BenchProver;
#[cfg(feature = "browser-bench")]
use tlsn_benches_browser_native::BrowserProver as BenchProver;

pub async fn prover_main(is_memory_profiling: bool) -> anyhow::Result<()> {
    let config_path = std::env::var("CFG").unwrap_or_else(|_| "bench.toml".to_string());
    let config: Config = toml::from_str(
        &std::fs::read_to_string(config_path).context("failed to read config file")?,
    )
    .context("failed to parse config")?;

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
        .init();

    let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string());
    let port: u16 = std::env::var("VERIFIER_PORT")
        .map(|port| port.parse().expect("port is valid u16"))
        .unwrap_or(8000);
    let verifier_host = (ip.as_str(), port);

    let mut file = std::fs::OpenOptions::new()
        .create(true)
        .append(true)
        .open("metrics.csv")
        .context("failed to open metrics file")?;

    // Preprocess the PRF circuits, since they allocate a lot of memory which
    // shouldn't be accounted for in the benchmarks.
    preprocess_prf_circuits().await;

    {
        let mut metric_wrt = WriterBuilder::new()
            // If the file is not empty, assume that the CSV header is already present in it.
            .has_headers(metadata("metrics.csv")?.len() == 0)
            .from_writer(&mut file);
        for bench in config.benches {
            let instances = bench.flatten();
            for instance in instances {
                if is_memory_profiling && !instance.memory_profile {
                    continue;
                }

                println!("{:?}", &instance);

                let io = tokio::net::TcpStream::connect(verifier_host)
                    .await
                    .context("failed to open tcp connection")?;
                metric_wrt.serialize(
                    run_instance(instance, io, is_memory_profiling)
                        .await
                        .context("failed to run instance")?,
                )?;
                metric_wrt.flush()?;
            }
        }
    }

    file.flush()?;

    Ok(())
}

async fn run_instance(
    instance: BenchInstance,
    io: impl AsyncIo,
    is_memory_profiling: bool,
) -> anyhow::Result<Metrics> {
    let uploaded = Arc::new(AtomicU64::new(0));
    let downloaded = Arc::new(AtomicU64::new(0));
    let io = InspectWriter::new(
        InspectReader::new(io, {
            let downloaded = downloaded.clone();
            move |data| {
                downloaded.fetch_add(data.len() as u64, Ordering::Relaxed);
            }
        }),
        {
            let uploaded = uploaded.clone();
            move |data| {
                uploaded.fetch_add(data.len() as u64, Ordering::Relaxed);
            }
        },
    );

    let BenchInstance {
        name,
        upload,
        upload_delay,
        download,
        download_delay,
        upload_size,
        download_size,
        defer_decryption,
        memory_profile,
    } = instance.clone();

    set_interface(PROVER_INTERFACE, upload, 1, upload_delay)?;

    let _profiler = if is_memory_profiling {
        assert!(memory_profile, "Instance doesn't have `memory_profile` set");
        // Build a testing profiler as it won't output to stderr.
        Some(dhat::Profiler::builder().testing().build())
    } else {
        None
    };

    let (client_conn, server_conn) = tokio::io::duplex(1 << 16);
    tokio::spawn(bind(server_conn.compat()));

    let mut prover = BenchProver::setup(
        upload_size,
        download_size,
        defer_decryption,
        Box::new(io),
        Box::new(client_conn),
    )
    .await?;

    let runtime = prover.run().await?;

    let heap_max_bytes = if is_memory_profiling {
        Some(dhat::HeapStats::get().max_bytes)
    } else {
        None
    };

    Ok(Metrics {
        name,
        kind: prover.kind(),
        upload,
        upload_delay,
        download,
        download_delay,
        upload_size,
        download_size,
        defer_decryption,
        runtime,
        uploaded: uploaded.load(Ordering::SeqCst),
        downloaded: downloaded.load(Ordering::SeqCst),
        heap_max_bytes,
    })
}
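The byte-counting wrapper at the top of `run_instance` is worth isolating. A minimal sketch of the same pattern (assuming tokio and tokio-util with the "io-util" feature, the calls used above): every read and write on the wrapped stream bumps a shared counter without otherwise touching the data.

```
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::io::{InspectReader, InspectWriter};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let (a, mut b) = tokio::io::duplex(64);
    let uploaded = Arc::new(AtomicU64::new(0));
    let downloaded = Arc::new(AtomicU64::new(0));

    // Reads pass through InspectReader, writes through InspectWriter.
    let mut io = InspectWriter::new(
        InspectReader::new(a, {
            let downloaded = downloaded.clone();
            move |data| {
                downloaded.fetch_add(data.len() as u64, Ordering::Relaxed);
            }
        }),
        {
            let uploaded = uploaded.clone();
            move |data| {
                uploaded.fetch_add(data.len() as u64, Ordering::Relaxed);
            }
        },
    );

    io.write_all(b"ping").await?;
    let mut buf = [0u8; 4];
    b.read_exact(&mut buf).await?;
    b.write_all(b"pong").await?;
    io.read_exact(&mut buf).await?;

    assert_eq!(uploaded.load(Ordering::SeqCst), 4);
    assert_eq!(downloaded.load(Ordering::SeqCst), 4);
    Ok(())
}
```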
@@ -1,131 +0,0 @@
//! Contains the actual main() function of the verifier binary. It is moved here
//! in order to enable cargo to build two verifier binaries - with and without
//! memory profiling.

use crate::{
    config::{BenchInstance, Config},
    preprocess::preprocess_prf_circuits,
    set_interface, VERIFIER_INTERFACE,
};
use tls_core::verify::WebPkiVerifier;
use tlsn_common::config::ProtocolConfigValidator;
use tlsn_core::CryptoProvider;
use tlsn_server_fixture_certs::CA_CERT_DER;
use tlsn_verifier::{Verifier, VerifierConfig};

use anyhow::Context;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter};

pub async fn verifier_main(is_memory_profiling: bool) -> anyhow::Result<()> {
    let config_path = std::env::var("CFG").unwrap_or_else(|_| "bench.toml".to_string());
    let config: Config = toml::from_str(
        &std::fs::read_to_string(config_path).context("failed to read config file")?,
    )
    .context("failed to parse config")?;

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
        .init();

    let ip = std::env::var("VERIFIER_IP").unwrap_or_else(|_| "10.10.1.1".to_string());
    let port: u16 = std::env::var("VERIFIER_PORT")
        .map(|port| port.parse().expect("port is valid u16"))
        .unwrap_or(8000);
    let host = (ip.as_str(), port);

    let listener = tokio::net::TcpListener::bind(host)
        .await
        .context("failed to bind to port")?;

    // Preprocess the PRF circuits, since they allocate a lot of memory which
    // shouldn't be accounted for in the benchmarks.
    preprocess_prf_circuits().await;

    for bench in config.benches {
        for instance in bench.flatten() {
            if is_memory_profiling && !instance.memory_profile {
                continue;
            }

            let (io, _) = listener
                .accept()
                .await
                .context("failed to accept connection")?;
            run_instance(instance, io, is_memory_profiling)
                .await
                .context("failed to run instance")?;
        }
    }

    Ok(())
}

async fn run_instance<S: AsyncWrite + AsyncRead + Send + Sync + Unpin + 'static>(
    instance: BenchInstance,
    io: S,
    is_memory_profiling: bool,
) -> anyhow::Result<()> {
    let BenchInstance {
        download,
        download_delay,
        upload_size,
        download_size,
        memory_profile,
        ..
    } = instance;

    set_interface(VERIFIER_INTERFACE, download, 1, download_delay)?;

    let _profiler = if is_memory_profiling {
        assert!(memory_profile, "Instance doesn't have `memory_profile` set");
        // Build a testing profiler as it won't output to stderr.
        Some(dhat::Profiler::builder().testing().build())
    } else {
        None
    };

    let provider = CryptoProvider {
        cert: cert_verifier(),
        ..Default::default()
    };

    let config_validator = ProtocolConfigValidator::builder()
        .max_sent_data(upload_size + 256)
        .max_recv_data(download_size + 256)
        .build()
        .unwrap();

    let verifier = Verifier::new(
        VerifierConfig::builder()
            .protocol_config_validator(config_validator)
            .crypto_provider(provider)
            .build()?,
    );

    verifier.verify(io.compat()).await?;

    println!("verifier done");

    if is_memory_profiling {
        // XXX: we may want to profile the Verifier's memory usage at a future
        // point.
        // println!(
        //     "verifier peak heap memory usage: {}",
        //     dhat::HeapStats::get().max_bytes
        // );
    }

    Ok(())
}

fn cert_verifier() -> WebPkiVerifier {
    let mut root_store = tls_core::anchors::RootCertStore::empty();
    root_store
        .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
        .unwrap();

    WebPkiVerifier::new(root_store, None)
}
@@ -1,13 +0,0 @@
[package]
edition = "2021"
name = "tlsn-benches-browser-core"
publish = false
version = "0.0.0"

[dependencies]
tlsn-benches-library = { workspace = true }

serio = { workspace = true }

serde = { workspace = true }
tokio-util = { workspace = true, features = ["compat", "io-util"] }
@@ -1,68 +0,0 @@
//! Contains core types shared by the native and the wasm components.

use std::{
    io::Error,
    pin::Pin,
    task::{Context, Poll},
};

use tlsn_benches_library::AsyncIo;

use serio::{
    codec::{Bincode, Framed},
    Sink, Stream,
};
use tokio_util::codec::LengthDelimitedCodec;

pub mod msg;

/// A sink/stream for serializable types with a framed transport.
pub struct FramedIo {
    inner:
        serio::Framed<tokio_util::codec::Framed<Box<dyn AsyncIo>, LengthDelimitedCodec>, Bincode>,
}

impl FramedIo {
    /// Creates a new `FramedIo` from the given async `io`.
    #[allow(clippy::default_constructed_unit_structs)]
    pub fn new(io: Box<dyn AsyncIo>) -> Self {
        let io = LengthDelimitedCodec::builder().new_framed(io);
        Self {
            inner: Framed::new(io, Bincode::default()),
        }
    }
}

impl Sink for FramedIo {
    type Error = Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_ready(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn start_send<Item: serio::Serialize>(
        mut self: Pin<&mut Self>,
        item: Item,
    ) -> std::result::Result<(), Self::Error> {
        Pin::new(&mut self.inner).start_send(item)
    }
}

impl Stream for FramedIo {
    type Error = Error;

    fn poll_next<Item: serio::Deserialize>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Item, Error>>> {
        Pin::new(&mut self.inner).poll_next(cx)
    }
}
@@ -1,17 +0,0 @@
//! Messages exchanged by the native and the wasm components of the browser
//! prover.

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, PartialEq)]
/// The config sent to the wasm component.
pub struct Config {
    pub upload_size: usize,
    pub download_size: usize,
    pub defer_decryption: bool,
}

#[derive(Serialize, Deserialize, PartialEq)]
/// Sent by the wasm component when the proving process is finished. Contains
/// the total runtime in seconds.
pub struct Runtime(pub u64);
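A minimal sketch of how `FramedIo` and these message types compose (assuming the types above plus serio's `SinkExt`/`IoStreamExt`, the same extension traits used elsewhere in this diff): two endpoints over an in-memory duplex pipe exchange a typed message.

```
use serio::{stream::IoStreamExt, SinkExt as _};
use tlsn_benches_browser_core::{msg::Config, FramedIo};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Either end of the duplex pipe satisfies the AsyncIo bounds.
    let (left, right) = tokio::io::duplex(1 << 12);
    let mut tx = FramedIo::new(Box::new(left));
    let mut rx = FramedIo::new(Box::new(right));

    tx.send(Config {
        upload_size: 1024,
        download_size: 4096,
        defer_decryption: true,
    })
    .await?;

    // The receiver deserializes the next frame as a Config.
    let cfg: Config = rx.expect_next().await?;
    assert!(cfg.defer_decryption);
    Ok(())
}
```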
@@ -1,22 +0,0 @@
[package]
edition = "2021"
name = "tlsn-benches-browser-native"
publish = false
version = "0.0.0"

[dependencies]
tlsn-benches-browser-core = { workspace = true }
tlsn-benches-library = { workspace = true }

serio = { workspace = true }
websocket-relay = { workspace = true }

anyhow = { workspace = true }
async-trait = { workspace = true }
chromiumoxide = { version = "0.6.0", features = ["tokio-runtime"] }
futures = { workspace = true }
rust-embed = "8.5.0"
tokio = { workspace = true, features = ["rt", "io-std"] }
tracing = { workspace = true }
warp = "0.3.7"
warp-embed = "0.5.0"
@@ -1,331 +0,0 @@
//! Contains the native component of the browser prover.
//!
//! Conceptually the browser prover consists of the native and the wasm
//! components. The native component is responsible for starting the browser,
//! loading the wasm component and driving it.

use std::{env, net::IpAddr};

use serio::{stream::IoStreamExt, SinkExt as _};
use tlsn_benches_browser_core::{
    msg::{Config, Runtime},
    FramedIo,
};
use tlsn_benches_library::{AsyncIo, ProverKind, ProverTrait};

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use chromiumoxide::{
    cdp::{
        browser_protocol::log::{EventEntryAdded, LogEntryLevel},
        js_protocol::runtime::EventExceptionThrown,
    },
    Browser, BrowserConfig, Page,
};
use futures::{Future, FutureExt, StreamExt};
use rust_embed::RustEmbed;
use tokio::{io, io::AsyncWriteExt, net::TcpListener, task::JoinHandle};
use tracing::{debug, error, info};
use warp::Filter;

/// The IP on which the wasm component is served.
pub static DEFAULT_WASM_IP: &str = "127.0.0.1";
/// The IP of the websocket relay.
pub static DEFAULT_WS_IP: &str = "127.0.0.1";

/// The port on which the wasm component is served.
pub static DEFAULT_WASM_PORT: u16 = 9001;
/// The port of the websocket relay.
pub static DEFAULT_WS_PORT: u16 = 9002;
/// The port for the wasm component to communicate with the TLS server.
pub static DEFAULT_WASM_TO_SERVER_PORT: u16 = 9003;
/// The port for the wasm component to communicate with the verifier.
pub static DEFAULT_WASM_TO_VERIFIER_PORT: u16 = 9004;
/// The port for the wasm component to communicate with the native component.
pub static DEFAULT_WASM_TO_NATIVE_PORT: u16 = 9005;

// The `pkg` dir will be embedded into the binary at compile-time.
#[derive(RustEmbed)]
#[folder = "../wasm/pkg"]
struct Data;

/// The native component of the prover which runs in the browser.
pub struct BrowserProver {
    /// Io for communication with the wasm component.
    wasm_io: FramedIo,
    /// The browser spawned by the prover.
    browser: Browser,
    /// A handle to the http server.
    http_server: JoinHandle<()>,
    /// Handles to the relays.
    relays: Vec<JoinHandle<Result<(), anyhow::Error>>>,
}

#[async_trait]
impl ProverTrait for BrowserProver {
    async fn setup(
        upload_size: usize,
        download_size: usize,
        defer_decryption: bool,
        verifier_io: Box<dyn AsyncIo>,
        server_io: Box<dyn AsyncIo>,
    ) -> anyhow::Result<Self>
    where
        Self: Sized,
    {
        let wasm_port: u16 = env::var("WASM_PORT")
            .map(|port| port.parse().expect("port should be valid integer"))
            .unwrap_or(DEFAULT_WASM_PORT);
        let ws_port: u16 = env::var("WS_PORT")
            .map(|port| port.parse().expect("port should be valid integer"))
            .unwrap_or(DEFAULT_WS_PORT);
        let wasm_to_server_port: u16 = env::var("WASM_TO_SERVER_PORT")
            .map(|port| port.parse().expect("port should be valid integer"))
            .unwrap_or(DEFAULT_WASM_TO_SERVER_PORT);
        let wasm_to_verifier_port: u16 = env::var("WASM_TO_VERIFIER_PORT")
            .map(|port| port.parse().expect("port should be valid integer"))
            .unwrap_or(DEFAULT_WASM_TO_VERIFIER_PORT);
        let wasm_to_native_port: u16 = env::var("WASM_TO_NATIVE_PORT")
            .map(|port| port.parse().expect("port should be valid integer"))
            .unwrap_or(DEFAULT_WASM_TO_NATIVE_PORT);

        let wasm_ip: IpAddr = env::var("WASM_IP")
            .map(|addr| addr.parse().expect("should be valid IP address"))
            .unwrap_or(IpAddr::V4(DEFAULT_WASM_IP.parse().unwrap()));
        let ws_ip: IpAddr = env::var("WS_IP")
            .map(|addr| addr.parse().expect("should be valid IP address"))
            .unwrap_or(IpAddr::V4(DEFAULT_WS_IP.parse().unwrap()));

        let mut relays = Vec::with_capacity(4);

        relays.push(spawn_websocket_relay(ws_ip, ws_port).await?);

        let http_server = spawn_http_server(wasm_ip, wasm_port)?;

        // Relay data from the wasm component to the server.
        relays.push(spawn_port_relay(wasm_to_server_port, server_io).await?);

        // Relay data from the wasm component to the verifier.
        relays.push(spawn_port_relay(wasm_to_verifier_port, verifier_io).await?);

        // Create a framed connection to the wasm component.
        let (wasm_left, wasm_right) = tokio::io::duplex(1 << 16);
        relays.push(spawn_port_relay(wasm_to_native_port, Box::new(wasm_right)).await?);
        let mut wasm_io = FramedIo::new(Box::new(wasm_left));

        info!("spawning browser");

        // Note that the browser must be spawned only when the WebSocket relay is
        // running.
        let browser = spawn_browser(
            wasm_ip,
            ws_ip,
            wasm_port,
            ws_port,
            wasm_to_server_port,
            wasm_to_verifier_port,
            wasm_to_native_port,
        )
        .await?;

        info!("sending config to the browser component");

        wasm_io
            .send(Config {
                upload_size,
                download_size,
                defer_decryption,
            })
            .await?;

        Ok(Self {
            wasm_io,
            browser,
            http_server,
            relays,
        })
    }

    async fn run(&mut self) -> anyhow::Result<u64> {
        let runtime: Runtime = self.wasm_io.expect_next().await?;

        self.clean_up().await?;

        Ok(runtime.0)
    }

    fn kind(&self) -> ProverKind {
        ProverKind::Browser
    }
}

impl BrowserProver {
    async fn clean_up(&mut self) -> anyhow::Result<()> {
        // Kill the http server.
        self.http_server.abort();

        // Kill all relays.
        let _ = self
            .relays
            .iter_mut()
            .map(|task| task.abort())
            .collect::<Vec<_>>();

        // Close the browser.
        self.browser.close().await?;
        self.browser.wait().await?;

        Ok(())
    }
}

pub async fn spawn_websocket_relay(
    ip: IpAddr,
    port: u16,
) -> anyhow::Result<JoinHandle<Result<(), anyhow::Error>>> {
    let listener = TcpListener::bind((ip, port)).await?;
    Ok(tokio::spawn(websocket_relay::run(listener)))
}

/// Binds to the given localhost `port`, accepts a connection and relays data
/// between the connection and the `channel`.
pub async fn spawn_port_relay(
    port: u16,
    channel: Box<dyn AsyncIo>,
) -> anyhow::Result<JoinHandle<Result<(), anyhow::Error>>> {
    let listener = tokio::net::TcpListener::bind(("127.0.0.1", port))
        .await
        .context("failed to bind to port")?;

    let handle = tokio::spawn(async move {
        let (tcp, _) = listener
            .accept()
            .await
            .context("failed to accept a connection")
            .unwrap();

        relay_data(Box::new(tcp), channel).await
    });

    Ok(handle)
}

/// Relays data between two sources.
pub async fn relay_data(left: Box<dyn AsyncIo>, right: Box<dyn AsyncIo>) -> Result<()> {
    let (mut left_read, mut left_write) = io::split(left);
    let (mut right_read, mut right_write) = io::split(right);

    let left_to_right = async {
        io::copy(&mut left_read, &mut right_write).await?;
        right_write.shutdown().await
    };

    let right_to_left = async {
        io::copy(&mut right_read, &mut left_write).await?;
        left_write.shutdown().await
    };

    tokio::try_join!(left_to_right, right_to_left)?;

    Ok(())
}

/// Spawns the browser and starts the wasm component.
async fn spawn_browser(
    wasm_ip: IpAddr,
    ws_ip: IpAddr,
    wasm_port: u16,
    ws_port: u16,
    wasm_to_server_port: u16,
    wasm_to_verifier_port: u16,
    wasm_to_native_port: u16,
) -> anyhow::Result<Browser> {
    // Chrome requires --no-sandbox when running as root.
    let config = BrowserConfig::builder()
        .no_sandbox()
        .incognito()
        .build()
        .map_err(|s| anyhow!(s))?;

    debug!("launching browser");

    let (browser, mut handler) = Browser::launch(config).await?;

    debug!("browser started");

    tokio::spawn(async move {
        while let Some(res) = handler.next().await {
            res.unwrap();
        }
    });

    let page = browser
        .new_page(&format!("http://{}:{}/index.html", wasm_ip, wasm_port))
        .await?;

    tokio::spawn(register_listeners(&page).await?);

    page.wait_for_navigation().await?;

    // Note that `format!` needs double {{ }} in order to escape them.
    let _ = page
        .evaluate_function(&format!(
            r#"
            async function() {{
                await window.worker.init();
                // Do not `await` run() or else it will block the browser.
                window.worker.run("{}", {}, {}, {}, {});
            }}
            "#,
            ws_ip, ws_port, wasm_to_server_port, wasm_to_verifier_port, wasm_to_native_port
        ))
        .await?;

    Ok(browser)
}

pub fn spawn_http_server(ip: IpAddr, port: u16) -> anyhow::Result<JoinHandle<()>> {
    let handle = tokio::spawn(async move {
        // Serve embedded files with additional headers.
        let data_serve = warp_embed::embed(&Data);

        let data_serve_with_headers = data_serve
            .map(|reply| {
                warp::reply::with_header(reply, "Cross-Origin-Opener-Policy", "same-origin")
            })
            .map(|reply| {
                warp::reply::with_header(reply, "Cross-Origin-Embedder-Policy", "require-corp")
            });

        warp::serve(data_serve_with_headers).run((ip, port)).await;
    });

    Ok(handle)
}

async fn register_listeners(page: &Page) -> Result<impl Future<Output = ()>> {
    let mut logs = page.event_listener::<EventEntryAdded>().await?.fuse();
    let mut exceptions = page.event_listener::<EventExceptionThrown>().await?.fuse();

    Ok(futures::future::join(
        async move {
            while let Some(event) = logs.next().await {
                let entry = &event.entry;
                match entry.level {
                    LogEntryLevel::Error => {
                        error!("{:?}", entry);
                    }
                    _ => {
                        debug!("{:?}: {}", entry.timestamp, entry.text);
                    }
                }
            }
        },
        async move {
            while let Some(event) = exceptions.next().await {
                error!("{:?}", event);
            }
        },
    )
    .map(|_| ()))
}
@@ -1,30 +0,0 @@
[package]
edition = "2021"
name = "tlsn-benches-browser-wasm"
publish = false
version = "0.0.0"

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies]
tlsn-benches-browser-core = { workspace = true }
tlsn-benches-library = { workspace = true }
tlsn-wasm = { path = "../../../wasm" }

serio = { workspace = true }

anyhow = { workspace = true }
tracing = { workspace = true }
wasm-bindgen = { version = "0.2.87" }
wasm-bindgen-futures = { version = "0.4.37" }
web-time = { workspace = true }
# Use the patched ws_stream_wasm to fix the issue https://github.com/najamelan/ws_stream_wasm/issues/12#issuecomment-1711902958
ws_stream_wasm = { version = "0.7.4", git = "https://github.com/tlsnotary/ws_stream_wasm", rev = "2ed12aad9f0236e5321f577672f309920b2aef51", features = [
    "tokio_io",
] }

[package.metadata.wasm-pack.profile.release]
# Note: these wasm-pack options should match those in crates/wasm/Cargo.toml
opt-level = "z"
wasm-opt = true
@@ -1,7 +0,0 @@
<!DOCTYPE html>
<head>
</head>
<body>
    <script src="index.js" type="module"></script>
</body>
</html>
@@ -1,7 +0,0 @@
import * as Comlink from "./comlink.mjs";

async function init() {
    const worker = Comlink.wrap(new Worker("worker.js", { type: "module" }));
    window.worker = worker;
}
init();
@@ -1,45 +0,0 @@
import * as Comlink from "./comlink.mjs";

import init, { wasm_main, initialize } from './tlsn_benches_browser_wasm.js';

class Worker {
    async init() {
        try {
            await init();
            // Tracing may interfere with the benchmark results. We should
            // enable it only for debugging.
            // init_logging({
            //     level: 'Debug',
            //     crate_filters: undefined,
            //     span_events: undefined,
            // });
            await initialize({ thread_count: navigator.hardwareConcurrency });
        } catch (e) {
            console.error(e);
            throw e;
        }
    }

    async run(
        ws_ip,
        ws_port,
        wasm_to_server_port,
        wasm_to_verifier_port,
        wasm_to_native_port
    ) {
        try {
            await wasm_main(
                ws_ip,
                ws_port,
                wasm_to_server_port,
                wasm_to_verifier_port,
                wasm_to_native_port);
        } catch (e) {
            console.error(e);
            throw e;
        }
    }
}

const worker = new Worker();

Comlink.expose(worker);
@@ -1,2 +0,0 @@
[toolchain]
channel = "nightly"
@@ -1,102 +0,0 @@
#![cfg(target_arch = "wasm32")]

//! Contains the wasm component of the browser prover.
//!
//! Conceptually the browser prover consists of the native and the wasm
//! components.

use serio::{stream::IoStreamExt, SinkExt as _};
use tlsn_benches_browser_core::{
    msg::{Config, Runtime},
    FramedIo,
};
use tlsn_benches_library::run_prover;

use anyhow::Result;
use tracing::info;
use wasm_bindgen::prelude::*;
use web_time::Instant;
use ws_stream_wasm::WsMeta;

#[wasm_bindgen]
pub async fn wasm_main(
    ws_ip: String,
    ws_port: u16,
    wasm_to_server_port: u16,
    wasm_to_verifier_port: u16,
    wasm_to_native_port: u16,
) -> Result<(), JsError> {
    // Wrapping main() since wasm_bindgen doesn't support anyhow.
    main(
        ws_ip,
        ws_port,
        wasm_to_server_port,
        wasm_to_verifier_port,
        wasm_to_native_port,
    )
    .await
    .map_err(|err| JsError::new(&err.to_string()))
}

pub async fn main(
    ws_ip: String,
    ws_port: u16,
    wasm_to_server_port: u16,
    wasm_to_verifier_port: u16,
    wasm_to_native_port: u16,
) -> Result<()> {
    info!("starting main");

    // Connect to the server.
    let (_, server_io_ws) = WsMeta::connect(
        &format!(
            "ws://{}:{}/tcp?addr=localhost%3A{}",
            ws_ip, ws_port, wasm_to_server_port
        ),
        None,
    )
    .await?;
    let server_io = server_io_ws.into_io();

    // Connect to the verifier.
    let (_, verifier_io_ws) = WsMeta::connect(
        &format!(
            "ws://{}:{}/tcp?addr=localhost%3A{}",
            ws_ip, ws_port, wasm_to_verifier_port
        ),
        None,
    )
    .await?;
    let verifier_io = verifier_io_ws.into_io();

    // Connect to the native component of the browser prover.
    let (_, native_io_ws) = WsMeta::connect(
        &format!(
            "ws://{}:{}/tcp?addr=localhost%3A{}",
            ws_ip, ws_port, wasm_to_native_port
        ),
        None,
    )
    .await?;
    let mut native_io = FramedIo::new(Box::new(native_io_ws.into_io()));

    info!("expecting config from the native component");

    let cfg: Config = native_io.expect_next().await?;

    let start_time = Instant::now();
    run_prover(
        cfg.upload_size,
        cfg.download_size,
        cfg.defer_decryption,
        Box::new(verifier_io),
        Box::new(server_io),
    )
    .await?;

    native_io
        .send(Runtime(start_time.elapsed().as_secs()))
        .await?;

    Ok(())
}
@@ -1,19 +0,0 @@
[package]
edition = "2021"
name = "tlsn-benches-library"
publish = false
version = "0.0.0"

[dependencies]
tlsn-common = { workspace = true }
tlsn-core = { workspace = true }
tlsn-prover = { workspace = true }
tlsn-server-fixture-certs = { workspace = true }
tlsn-tls-core = { workspace = true }

anyhow = "1.0"
async-trait = "0.1.81"
futures = { version = "0.3", features = ["compat"] }
serde = { workspace = true }
tokio = { version = "1", default-features = false, features = ["rt", "macros"] }
tokio-util = { version = "0.7", features = ["compat", "io"] }
@@ -1,133 +0,0 @@
|
||||
use tls_core::{anchors::RootCertStore, verify::WebPkiVerifier};
|
||||
use tlsn_common::config::ProtocolConfig;
|
||||
use tlsn_core::{transcript::Idx, CryptoProvider};
|
||||
use tlsn_prover::{Prover, ProverConfig};
|
||||
use tlsn_server_fixture_certs::{CA_CERT_DER, SERVER_DOMAIN};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use futures::{future::try_join, AsyncReadExt as _, AsyncWriteExt as _, TryFutureExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
use tokio_util::compat::TokioAsyncReadCompatExt;

pub trait AsyncIo: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

impl<T> AsyncIo for T where T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static {}

#[async_trait]
pub trait ProverTrait {
    /// Sets up the prover preparing it to be run. Returns a prover ready to be
    /// run.
    async fn setup(
        upload_size: usize,
        download_size: usize,
        defer_decryption: bool,
        verifier_io: Box<dyn AsyncIo>,
        server_io: Box<dyn AsyncIo>,
    ) -> anyhow::Result<Self>
    where
        Self: Sized;

    /// Runs the prover. Returns the total run time in seconds.
    async fn run(&mut self) -> anyhow::Result<u64>;

    /// Returns the kind of the prover.
    fn kind(&self) -> ProverKind;
}

#[derive(Debug, Clone, Serialize, Deserialize)]
/// The kind of a prover.
pub enum ProverKind {
    /// The prover compiled into a native binary.
    Native,
    /// The prover compiled into a wasm binary.
    Browser,
}

impl From<ProverKind> for String {
    fn from(value: ProverKind) -> Self {
        match value {
            ProverKind::Native => "Native".to_string(),
            ProverKind::Browser => "Browser".to_string(),
        }
    }
}

pub async fn run_prover(
    upload_size: usize,
    download_size: usize,
    defer_decryption: bool,
    io: Box<dyn AsyncIo>,
    client_conn: Box<dyn AsyncIo>,
) -> anyhow::Result<()> {
    let provider = CryptoProvider {
        cert: WebPkiVerifier::new(root_store(), None),
        ..Default::default()
    };

    let protocol_config = if defer_decryption {
        ProtocolConfig::builder()
            .max_sent_data(upload_size + 256)
            .max_recv_data(download_size + 256)
            .build()
            .unwrap()
    } else {
        ProtocolConfig::builder()
            .max_sent_data(upload_size + 256)
            .max_recv_data(download_size + 256)
            .max_recv_data_online(download_size + 256)
            .build()
            .unwrap()
    };

    let prover = Prover::new(
        ProverConfig::builder()
            .server_name(SERVER_DOMAIN)
            .protocol_config(protocol_config)
            .defer_decryption_from_start(defer_decryption)
            .crypto_provider(provider)
            .build()
            .context("invalid prover config")?,
    )
    .setup(io.compat())
    .await?;

    let (mut mpc_tls_connection, prover_fut) = prover.connect(client_conn.compat()).await?;
    let tls_fut = async move {
        let request = format!(
            "GET /bytes?size={} HTTP/1.1\r\nConnection: close\r\nData: {}\r\n\r\n",
            download_size,
            String::from_utf8(vec![0x42u8; upload_size]).unwrap(),
        );

        mpc_tls_connection.write_all(request.as_bytes()).await?;
        mpc_tls_connection.close().await?;

        let mut response = vec![];
        mpc_tls_connection.read_to_end(&mut response).await?;

        dbg!(response.len());

        Ok::<(), anyhow::Error>(())
    };

    let (prover_task, _) = try_join(prover_fut.map_err(anyhow::Error::from), tls_fut).await?;

    let mut prover = prover_task.start_prove();

    let (sent_len, recv_len) = prover.transcript().len();
    prover
        .prove_transcript(Idx::new(0..sent_len), Idx::new(0..recv_len))
        .await?;
    prover.finalize().await?;

    Ok(())
}

fn root_store() -> RootCertStore {
    let mut root_store = RootCertStore::empty();
    root_store
        .add(&tls_core::key::Certificate(CA_CERT_DER.to_vec()))
        .unwrap();
    root_store
}
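A minimal sketch of driving this harness entry point, assuming in-memory tokio::io::duplex pipes whose far ends (`_verifier_far`, `_server_far`) would be serviced by a verifier task and a TLS test server, both omitted here:

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The far ends are placeholders; a real run hands them to a verifier
    // task and a TLS test server.
    let (verifier_io, _verifier_far) = tokio::io::duplex(1 << 16);
    let (client_conn, _server_far) = tokio::io::duplex(1 << 16);

    // 1 KiB upload, 4 KiB download, deferred decryption enabled.
    run_prover(1024, 4096, true, Box::new(verifier_io), Box::new(client_conn)).await
}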
@@ -1,9 +1,12 @@
[package]
name = "tlsn-common"
description = "Common code shared between tlsn-prover and tlsn-verifier"
version = "0.1.0-alpha.9"
version = "0.1.0-alpha.12"
edition = "2021"

[lints]
workspace = true

[features]
default = []

@@ -14,14 +17,17 @@ tlsn-cipher = { workspace = true }
mpz-core = { workspace = true }
mpz-common = { workspace = true }
mpz-memory-core = { workspace = true }
mpz-hash = { workspace = true }
mpz-vm-core = { workspace = true }
mpz-zk = { workspace = true }

async-trait = { workspace = true }
derive_builder = { workspace = true }
futures = { workspace = true }
ghash = { workspace = true }
once_cell = { workspace = true }
opaque-debug = { workspace = true }
rand = { workspace = true }
rangeset = { workspace = true }
serio = { workspace = true, features = ["codec", "bincode"] }
thiserror = { workspace = true }
@@ -1,12 +1,14 @@
//! Plaintext commitment and proof of encryption.

pub mod hash;

use mpz_core::bitvec::BitVec;
use mpz_memory_core::{binary::Binary, DecodeFutureTyped};
use mpz_vm_core::{prelude::*, Vm};

use crate::{
    transcript::Record,
    zk_aes::{ZkAesCtr, ZkAesCtrError},
    zk_aes_ctr::{ZkAesCtr, ZkAesCtrError},
    Role,
};

crates/common/src/commit/hash.rs (new file, +197)
@@ -0,0 +1,197 @@
//! Plaintext hash commitments.

use std::collections::HashMap;

use mpz_core::bitvec::BitVec;
use mpz_hash::sha256::Sha256;
use mpz_memory_core::{
    binary::{Binary, U8},
    DecodeFutureTyped, MemoryExt, Vector,
};
use mpz_vm_core::{prelude::*, Vm, VmError};
use tlsn_core::{
    hash::{Blinder, Hash, HashAlgId, TypedHash},
    transcript::{
        hash::{PlaintextHash, PlaintextHashSecret},
        Direction, Idx,
    },
};

use crate::{transcript::TranscriptRefs, Role};

/// Future which will resolve to the committed hash values.
#[derive(Debug)]
pub struct HashCommitFuture {
    #[allow(clippy::type_complexity)]
    futs: Vec<(
        Direction,
        Idx,
        HashAlgId,
        DecodeFutureTyped<BitVec, Vec<u8>>,
    )>,
}

impl HashCommitFuture {
    /// Tries to receive the value, returning an error if the value is not
    /// ready.
    pub fn try_recv(self) -> Result<Vec<PlaintextHash>, HashCommitError> {
        let mut output = Vec::new();
        for (direction, idx, alg, mut fut) in self.futs {
            let hash = fut
                .try_recv()
                .map_err(|_| HashCommitError::decode())?
                .ok_or_else(HashCommitError::decode)?;
            output.push(PlaintextHash {
                direction,
                idx,
                hash: TypedHash {
                    alg,
                    value: Hash::try_from(hash).map_err(HashCommitError::convert)?,
                },
            });
        }

        Ok(output)
    }
}

/// Prove plaintext hash commitments.
pub fn prove_hash(
    vm: &mut dyn Vm<Binary>,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<(HashCommitFuture, Vec<PlaintextHashSecret>), HashCommitError> {
    let mut futs = Vec::new();
    let mut secrets = Vec::new();
    for (direction, idx, alg, hash_ref, blinder_ref) in
        hash_commit_inner(vm, Role::Prover, refs, idxs)?
    {
        let blinder: Blinder = rand::random();

        vm.assign(blinder_ref, blinder.as_bytes().to_vec())?;
        vm.commit(blinder_ref)?;

        let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;

        futs.push((direction, idx.clone(), alg, hash_fut));
        secrets.push(PlaintextHashSecret {
            direction,
            idx,
            blinder,
            alg,
        });
    }

    Ok((HashCommitFuture { futs }, secrets))
}

/// Verify plaintext hash commitments.
pub fn verify_hash(
    vm: &mut dyn Vm<Binary>,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<HashCommitFuture, HashCommitError> {
    let mut futs = Vec::new();
    for (direction, idx, alg, hash_ref, blinder_ref) in
        hash_commit_inner(vm, Role::Verifier, refs, idxs)?
    {
        vm.commit(blinder_ref)?;

        let hash_fut = vm.decode(Vector::<U8>::from(hash_ref))?;

        futs.push((direction, idx, alg, hash_fut));
    }

    Ok(HashCommitFuture { futs })
}

/// Commit plaintext hashes of the transcript.
#[allow(clippy::type_complexity)]
fn hash_commit_inner(
    vm: &mut dyn Vm<Binary>,
    role: Role,
    refs: &TranscriptRefs,
    idxs: impl IntoIterator<Item = (Direction, Idx, HashAlgId)>,
) -> Result<Vec<(Direction, Idx, HashAlgId, Array<U8, 32>, Vector<U8>)>, HashCommitError> {
    let mut output = Vec::new();
    let mut hashers = HashMap::new();
    for (direction, idx, alg) in idxs {
        let blinder = vm.alloc_vec::<U8>(16)?;
        match role {
            Role::Prover => vm.mark_private(blinder)?,
            Role::Verifier => vm.mark_blind(blinder)?,
        }

        let hash = match alg {
            HashAlgId::SHA256 => {
                let mut hasher = if let Some(hasher) = hashers.get(&alg).cloned() {
                    hasher
                } else {
                    let hasher = Sha256::new_with_init(vm).map_err(HashCommitError::hasher)?;
                    hashers.insert(alg, hasher.clone());
                    hasher
                };

                for plaintext in refs.get(direction, &idx).expect("plaintext refs are valid") {
                    hasher.update(&plaintext);
                }
                hasher.update(&blinder);
                hasher.finalize(vm).map_err(HashCommitError::hasher)?
            }
            alg => {
                return Err(HashCommitError::unsupported_alg(alg));
            }
        };

        output.push((direction, idx, alg, hash, blinder));
    }

    Ok(output)
}

/// Error type for hash commitments.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct HashCommitError(#[from] ErrorRepr);

impl HashCommitError {
    fn decode() -> Self {
        Self(ErrorRepr::Decode)
    }

    fn convert(e: &'static str) -> Self {
        Self(ErrorRepr::Convert(e))
    }

    fn hasher<E>(e: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        Self(ErrorRepr::Hasher(e.into()))
    }

    fn unsupported_alg(alg: HashAlgId) -> Self {
        Self(ErrorRepr::UnsupportedAlg { alg })
    }
}

#[derive(Debug, thiserror::Error)]
#[error("hash commit error: {0}")]
enum ErrorRepr {
    #[error("VM error: {0}")]
    Vm(VmError),
    #[error("failed to decode hash")]
    Decode,
    #[error("failed to convert hash: {0}")]
    Convert(&'static str),
    #[error("unsupported hash algorithm: {alg}")]
    UnsupportedAlg { alg: HashAlgId },
    #[error("hasher error: {0}")]
    Hasher(Box<dyn std::error::Error + Send + Sync>),
}

impl From<VmError> for HashCommitError {
    fn from(value: VmError) -> Self {
        Self(ErrorRepr::Vm(value))
    }
}
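Outside the VM, opening one of these commitments reduces to recomputing the hash over the plaintext followed by the 16-byte blinder, in the same update order as `hash_commit_inner` above. A standalone sketch assuming the `sha2` crate (the in-VM hasher here is `mpz_hash`, not `sha2`):

use sha2::{Digest, Sha256};

/// Recomputes a plaintext hash commitment: SHA-256(plaintext || blinder).
fn open_commitment(plaintext: &[u8], blinder: &[u8; 16]) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(plaintext);
    hasher.update(blinder);
    hasher.finalize().into()
}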
@@ -7,6 +7,10 @@ use std::error::Error;

// Default is 32 bytes to decrypt the TLS protocol messages.
const DEFAULT_MAX_RECV_ONLINE: usize = 32;
// Default maximum number of TLS records to allow.
//
// This allows for roughly 4MB of upload from prover to verifier
// (256 records at the 16KB TLS record size limit).
const DEFAULT_RECORDS_LIMIT: usize = 256;

// Current version that is running.
static VERSION: Lazy<Version> = Lazy::new(|| {
@@ -21,12 +25,26 @@ static VERSION: Lazy<Version> = Lazy::new(|| {
pub struct ProtocolConfig {
    /// Maximum number of bytes that can be sent.
    max_sent_data: usize,
    /// Maximum number of application data records that can be sent.
    #[builder(setter(strip_option), default)]
    max_sent_records: Option<usize>,
    /// Maximum number of bytes that can be decrypted online, i.e. while the
    /// MPC-TLS connection is active.
    #[builder(default = "DEFAULT_MAX_RECV_ONLINE")]
    max_recv_data_online: usize,
    /// Maximum number of bytes that can be received.
    max_recv_data: usize,
    /// Maximum number of received application data records that can be
    /// decrypted online, i.e. while the MPC-TLS connection is active.
    #[builder(setter(strip_option), default)]
    max_recv_records_online: Option<usize>,
    /// Whether the `deferred decryption` feature is toggled on from the start
    /// of the MPC-TLS connection.
    #[builder(default = "true")]
    defer_decryption_from_start: bool,
    /// Network settings.
    #[builder(default)]
    network: NetworkSetting,
    /// Version that is being run by prover/verifier.
    #[builder(setter(skip), default = "VERSION.clone()")]
    version: Version,
@@ -54,6 +72,12 @@ impl ProtocolConfig {
        self.max_sent_data
    }

    /// Returns the maximum number of application data records that can
    /// be sent.
    pub fn max_sent_records(&self) -> Option<usize> {
        self.max_sent_records
    }

    /// Returns the maximum number of bytes that can be decrypted online.
    pub fn max_recv_data_online(&self) -> usize {
        self.max_recv_data_online
@@ -63,6 +87,23 @@ impl ProtocolConfig {
    pub fn max_recv_data(&self) -> usize {
        self.max_recv_data
    }

    /// Returns the maximum number of received application data records that
    /// can be decrypted online.
    pub fn max_recv_records_online(&self) -> Option<usize> {
        self.max_recv_records_online
    }

    /// Returns whether the `deferred decryption` feature is toggled on from the
    /// start of the MPC-TLS connection.
    pub fn defer_decryption_from_start(&self) -> bool {
        self.defer_decryption_from_start
    }

    /// Returns the network settings.
    pub fn network(&self) -> NetworkSetting {
        self.network
    }
}

/// Protocol configuration validator used by checker (i.e. verifier) to perform
@@ -71,8 +112,14 @@ impl ProtocolConfig {
pub struct ProtocolConfigValidator {
    /// Maximum number of bytes that can be sent.
    max_sent_data: usize,
    /// Maximum number of application data records that can be sent.
    #[builder(default = "DEFAULT_RECORDS_LIMIT")]
    max_sent_records: usize,
    /// Maximum number of bytes that can be received.
    max_recv_data: usize,
    /// Maximum number of application data records that can be received online.
    #[builder(default = "DEFAULT_RECORDS_LIMIT")]
    max_recv_records_online: usize,
    /// Version that is being run by checker.
    #[builder(setter(skip), default = "VERSION.clone()")]
    version: Version,
@@ -89,15 +136,28 @@ impl ProtocolConfigValidator {
        self.max_sent_data
    }

    /// Returns the maximum number of application data records that can
    /// be sent.
    pub fn max_sent_records(&self) -> usize {
        self.max_sent_records
    }

    /// Returns the maximum number of bytes that can be received.
    pub fn max_recv_data(&self) -> usize {
        self.max_recv_data
    }

    /// Returns the maximum number of application data records that can
    /// be received online.
    pub fn max_recv_records_online(&self) -> usize {
        self.max_recv_records_online
    }

    /// Performs compatibility check of the protocol configuration between
    /// prover and verifier.
    pub fn validate(&self, config: &ProtocolConfig) -> Result<(), ProtocolConfigError> {
        self.check_max_transcript_size(config.max_sent_data, config.max_recv_data)?;
        self.check_max_records(config.max_sent_records, config.max_recv_records_online)?;
        self.check_version(&config.version)?;
        Ok(())
    }
@@ -125,6 +185,32 @@ impl ProtocolConfigValidator {
        Ok(())
    }

    fn check_max_records(
        &self,
        max_sent_records: Option<usize>,
        max_recv_records_online: Option<usize>,
    ) -> Result<(), ProtocolConfigError> {
        if let Some(max_sent_records) = max_sent_records {
            if max_sent_records > self.max_sent_records {
                return Err(ProtocolConfigError::max_record_count(format!(
                    "max_sent_records {} is greater than the configured limit {}",
                    max_sent_records, self.max_sent_records,
                )));
            }
        }

        if let Some(max_recv_records_online) = max_recv_records_online {
            if max_recv_records_online > self.max_recv_records_online {
                return Err(ProtocolConfigError::max_record_count(format!(
                    "max_recv_records_online {} is greater than the configured limit {}",
                    max_recv_records_online, self.max_recv_records_online,
                )));
            }
        }

        Ok(())
    }

    // Checks if both versions are the same (might support check for different but
    // compatible versions in the future).
    fn check_version(&self, peer_version: &Version) -> Result<(), ProtocolConfigError> {
@@ -139,6 +225,24 @@ impl ProtocolConfigValidator {
    }
}

/// Settings for the network environment.
///
/// Provides optimization options to adapt the protocol to different network
/// situations.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum NetworkSetting {
    /// Prefers a bandwidth-heavy protocol.
    Bandwidth,
    /// Prefers a latency-heavy protocol.
    Latency,
}

impl Default for NetworkSetting {
    fn default() -> Self {
        Self::Bandwidth
    }
}

/// A ProtocolConfig error.
#[derive(thiserror::Error, Debug)]
pub struct ProtocolConfigError {
@@ -165,6 +269,13 @@ impl ProtocolConfigError {
        }
    }

    fn max_record_count(msg: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::MaxRecordCount,
            source: Some(msg.into().into()),
        }
    }

    fn version(msg: impl Into<String>) -> Self {
        Self {
            kind: ErrorKind::Version,
@@ -176,7 +287,8 @@ impl ProtocolConfigError {
impl fmt::Display for ProtocolConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.kind {
            ErrorKind::MaxTranscriptSize => write!(f, "max transcript size error")?,
            ErrorKind::MaxTranscriptSize => write!(f, "max transcript size exceeded")?,
            ErrorKind::MaxRecordCount => write!(f, "max record count exceeded")?,
            ErrorKind::Version => write!(f, "version error")?,
        }

@@ -191,6 +303,7 @@ impl fmt::Display for ProtocolConfigError {
#[derive(Debug)]
enum ErrorKind {
    MaxTranscriptSize,
    MaxRecordCount,
    Version,
}
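With the builder defaults above, a prover only has to set the transcript limits; record limits, online decryption and the version are filled in automatically. A sketch of the two sides agreeing on a configuration, using only the setters visible in this diff:

fn negotiate() -> Result<(), Box<dyn std::error::Error>> {
    // Prover side: declare the limits for this session.
    let config = ProtocolConfig::builder()
        .max_sent_data(1 << 12)
        .max_recv_data(1 << 14)
        .build()?;

    // Verifier side: check the peer's declared limits against its own.
    let validator = ProtocolConfigValidator::builder()
        .max_sent_data(1 << 12)
        .max_recv_data(1 << 14)
        .build()?;

    validator.validate(&config)?;
    Ok(())
}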
@@ -1,14 +1,29 @@
//! Encoding commitment protocol.

use std::ops::Range;

use mpz_common::Context;
use mpz_core::Block;
use mpz_memory_core::{
    binary::U8,
    correlated::{Delta, Key, Mac},
    Vector,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serio::{stream::IoStreamExt, SinkExt};
use tlsn_core::transcript::{
    encoding::{new_encoder, Encoder, EncoderSecret, EncodingProvider},
    Direction, Idx,
use tlsn_core::{
    hash::HashAlgorithm,
    transcript::{
        encoding::{
            new_encoder, Encoder, EncoderSecret, EncodingCommitment, EncodingProvider,
            EncodingProviderError, EncodingTree, EncodingTreeError,
        },
        Direction, Idx,
    },
};

use crate::transcript::TranscriptRefs;

/// Bytes of encoding, per byte.
const ENCODING_SIZE: usize = 128;

@@ -19,35 +34,49 @@ struct Encodings {
}

/// Transfers the encodings using the provided seed and keys.
pub async fn transfer(
///
/// The keys must be consistent with the global delta used in the encodings.
pub async fn transfer<'a>(
    ctx: &mut Context,
    secret: &EncoderSecret,
    sent_keys: impl IntoIterator<Item = &'_ Block>,
    recv_keys: impl IntoIterator<Item = &'_ Block>,
) -> Result<(), EncodingError> {
    let encoder = new_encoder(secret);
    refs: &TranscriptRefs,
    delta: &Delta,
    f: impl Fn(Vector<U8>) -> &'a [Key],
) -> Result<EncodingCommitment, EncodingError> {
    let secret = EncoderSecret::new(rand::rng().random(), delta.as_block().to_bytes());
    let encoder = new_encoder(&secret);

    let sent_keys: Vec<u8> = sent_keys
        .into_iter()
        .flat_map(|key| key.as_bytes())
    let sent_keys: Vec<u8> = refs
        .sent()
        .iter()
        .copied()
        .flat_map(&f)
        .flat_map(|key| key.as_block().as_bytes())
        .copied()
        .collect();
    let recv_keys: Vec<u8> = recv_keys
        .into_iter()
        .flat_map(|key| key.as_bytes())
    let recv_keys: Vec<u8> = refs
        .recv()
        .iter()
        .copied()
        .flat_map(&f)
        .flat_map(|key| key.as_block().as_bytes())
        .copied()
        .collect();

    assert_eq!(sent_keys.len() % ENCODING_SIZE, 0);
    assert_eq!(recv_keys.len() % ENCODING_SIZE, 0);

    let mut sent_encoding = encoder.encode_idx(
    let mut sent_encoding = Vec::with_capacity(sent_keys.len());
    let mut recv_encoding = Vec::with_capacity(recv_keys.len());

    encoder.encode_range(
        Direction::Sent,
        &Idx::new(0..sent_keys.len() / ENCODING_SIZE),
        0..sent_keys.len() / ENCODING_SIZE,
        &mut sent_encoding,
    );
    let mut recv_encoding = encoder.encode_idx(
    encoder.encode_range(
        Direction::Received,
        &Idx::new(0..recv_keys.len() / ENCODING_SIZE),
        0..recv_keys.len() / ENCODING_SIZE,
        &mut recv_encoding,
    );

    sent_encoding
@@ -66,24 +95,40 @@ pub async fn transfer(
    })
    .await?;

    Ok(())
    let root = ctx.io_mut().expect_next().await?;
    ctx.io_mut().send(secret.clone()).await?;

    Ok(EncodingCommitment {
        root,
        secret: secret.clone(),
    })
}

/// Receives the encodings using the provided MACs.
pub async fn receive(
///
/// The MACs must be consistent with the global delta used in the encodings.
pub async fn receive<'a>(
    ctx: &mut Context,
    sent_macs: impl IntoIterator<Item = &'_ Block>,
    recv_macs: impl IntoIterator<Item = &'_ Block>,
) -> Result<impl EncodingProvider, EncodingError> {
    hasher: &(dyn HashAlgorithm + Send + Sync),
    refs: &TranscriptRefs,
    f: impl Fn(Vector<U8>) -> &'a [Mac],
    idxs: impl IntoIterator<Item = &(Direction, Idx)>,
) -> Result<(EncodingCommitment, EncodingTree), EncodingError> {
    let Encodings { mut sent, mut recv } = ctx.io_mut().expect_next().await?;

    let sent_macs: Vec<u8> = sent_macs
        .into_iter()
    let sent_macs: Vec<u8> = refs
        .sent()
        .iter()
        .copied()
        .flat_map(&f)
        .flat_map(|mac| mac.as_bytes())
        .copied()
        .collect();
    let recv_macs: Vec<u8> = recv_macs
        .into_iter()
    let recv_macs: Vec<u8> = refs
        .recv()
        .iter()
        .copied()
        .flat_map(&f)
        .flat_map(|mac| mac.as_bytes())
        .copied()
        .collect();
@@ -116,7 +161,17 @@ pub async fn receive(
        .zip(recv_macs)
        .for_each(|(enc, mac)| *enc ^= mac);

    Ok(Provider { sent, recv })
    let provider = Provider { sent, recv };

    let tree = EncodingTree::new(hasher, idxs, &provider)?;
    let root = tree.root();

    ctx.io_mut().send(root.clone()).await?;
    let secret = ctx.io_mut().expect_next().await?;

    let commitment = EncodingCommitment { root, secret };

    Ok((commitment, tree))
}

#[derive(Debug)]
@@ -126,25 +181,27 @@ struct Provider {
}

impl EncodingProvider for Provider {
    fn provide_encoding(&self, direction: Direction, idx: &Idx) -> Option<Vec<u8>> {
    fn provide_encoding(
        &self,
        direction: Direction,
        range: Range<usize>,
        dest: &mut Vec<u8>,
    ) -> Result<(), EncodingProviderError> {
        let encodings = match direction {
            Direction::Sent => &self.sent,
            Direction::Received => &self.recv,
        };

        let mut encoding = Vec::with_capacity(idx.len());
        for range in idx.iter_ranges() {
            let start = range.start * ENCODING_SIZE;
            let end = range.end * ENCODING_SIZE;
        let start = range.start * ENCODING_SIZE;
        let end = range.end * ENCODING_SIZE;

            if end > encodings.len() {
                return None;
            }

            encoding.extend_from_slice(&encodings[start..end]);
        if end > encodings.len() {
            return Err(EncodingProviderError);
        }

        Some(encoding)
        dest.extend_from_slice(&encodings[start..end]);

        Ok(())
    }
}

@@ -164,6 +221,8 @@ enum ErrorRepr {
        expected: usize,
        got: usize,
    },
    #[error("encoding tree error: {0}")]
    EncodingTree(EncodingTreeError),
}

impl From<std::io::Error> for EncodingError {
@@ -171,3 +230,9 @@ impl From<std::io::Error> for EncodingError {
        Self(ErrorRepr::Io(value))
    }
}

impl From<EncodingTreeError> for EncodingError {
    fn from(value: EncodingTreeError) -> Self {
        Self(ErrorRepr::EncodingTree(value))
    }
}
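The keys and MACs selected through `f` are correlated-OT values: for every bit, mac = key XOR (bit * delta). XORing the transferred data with the MACs, as `receive` does, therefore recovers exactly the active encodings that `transfer` derived from its keys. A toy illustration of that invariant over bare u128 blocks (plain Rust, no mpz types):

fn main() {
    let delta: u128 = 0x1d2c3b4a59687786a5b4c3d2e1f00f1e;
    let key: u128 = 0xdeadbeefcafef00d0123456789abcdef;

    for bit in [false, true] {
        // Correlated-OT relation: the receiver holds the MAC for its bit.
        let mac = key ^ if bit { delta } else { 0 };
        // Combining MAC and key yields bit * delta, never the bit alone.
        assert_eq!(mac ^ key, if bit { delta } else { 0 });
    }
}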
crates/common/src/ghash.rs (new file, +39)
@@ -0,0 +1,39 @@
//! GHASH methods.

// This module belongs in tls/core. It was moved out here temporarily.

use ghash::{
    universal_hash::{KeyInit, UniversalHash as UniversalHashReference},
    GHash,
};

/// Computes a GHASH tag.
pub fn ghash(aad: &[u8], ciphertext: &[u8], key: &[u8; 16]) -> [u8; 16] {
    let mut ghash = GHash::new(key.into());
    ghash.update_padded(&build_ghash_data(aad.to_vec(), ciphertext.to_owned()));
    let out = ghash.finalize();
    out.into()
}

/// Builds padded data for GHASH.
pub fn build_ghash_data(mut aad: Vec<u8>, mut ciphertext: Vec<u8>) -> Vec<u8> {
    let associated_data_bitlen = (aad.len() as u64) * 8;
    let text_bitlen = (ciphertext.len() as u64) * 8;

    let len_block = ((associated_data_bitlen as u128) << 64) + (text_bitlen as u128);

    // Pad data to be a multiple of 16 bytes.
    let aad_padded_block_count = (aad.len() / 16) + (aad.len() % 16 != 0) as usize;
    aad.resize(aad_padded_block_count * 16, 0);

    let ciphertext_padded_block_count =
        (ciphertext.len() / 16) + (ciphertext.len() % 16 != 0) as usize;
    ciphertext.resize(ciphertext_padded_block_count * 16, 0);

    let mut data: Vec<u8> = Vec::with_capacity(aad.len() + ciphertext.len() + 16);
    data.extend(aad);
    data.extend(ciphertext);
    data.extend_from_slice(&len_block.to_be_bytes());

    data
}
|
||||
pub mod config;
|
||||
pub mod context;
|
||||
pub mod encoding;
|
||||
pub mod ghash;
|
||||
pub mod msg;
|
||||
pub mod mux;
|
||||
pub mod tag;
|
||||
pub mod transcript;
|
||||
pub mod zk_aes;
|
||||
pub mod zk_aes_ctr;
|
||||
|
||||
/// The party's role in the TLSN protocol.
|
||||
///
|
||||
|
||||
@@ -72,7 +72,7 @@ pub fn attach_mux<T: AsyncWrite + AsyncRead + Send + Unpin + 'static>(
|
||||
role: Role,
|
||||
) -> (MuxFuture, MuxControl) {
|
||||
let mut mux_config = yamux::Config::default();
|
||||
mux_config.set_max_num_streams(32);
|
||||
mux_config.set_max_num_streams(36);
|
||||
|
||||
let mux_role = match role {
|
||||
Role::Prover => yamux::Mode::Client,
|
||||
|
||||
crates/common/src/tag.rs (new file, +157)
@@ -0,0 +1,157 @@
//! TLS record tag verification.

use crate::{ghash::ghash, transcript::Record};
use cipher::{aes::Aes128, Cipher};
use mpz_core::bitvec::BitVec;
use mpz_memory_core::{
    binary::{Binary, U8},
    DecodeFutureTyped,
};
use mpz_vm_core::{prelude::*, Vm};
use tls_core::cipher::make_tls12_aad;

/// Proves the verification of tags of the given `records`,
/// returning a proof.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key_iv` - Cipher key and IV.
/// * `mac_key` - MAC key.
/// * `records` - Records for which the verification is to be proven.
pub fn verify_tags(
    vm: &mut dyn Vm<Binary>,
    key_iv: (Array<U8, 16>, Array<U8, 4>),
    mac_key: Array<U8, 16>,
    records: Vec<Record>,
) -> Result<TagProof, TagProofError> {
    let mut aes = Aes128::default();
    aes.set_key(key_iv.0);
    aes.set_iv(key_iv.1);

    // Compute j0 blocks.
    let j0s = records
        .iter()
        .map(|rec| {
            let block = aes.alloc_ctr_block(vm).map_err(TagProofError::vm)?;

            let explicit_nonce: [u8; 8] =
                rec.explicit_nonce
                    .clone()
                    .try_into()
                    .map_err(|explicit_nonce: Vec<_>| ErrorRepr::ExplicitNonceLength {
                        expected: 8,
                        actual: explicit_nonce.len(),
                    })?;

            vm.assign(block.explicit_nonce, explicit_nonce)
                .map_err(TagProofError::vm)?;
            vm.commit(block.explicit_nonce).map_err(TagProofError::vm)?;

            // j0's counter is set to 1.
            vm.assign(block.counter, 1u32.to_be_bytes())
                .map_err(TagProofError::vm)?;
            vm.commit(block.counter).map_err(TagProofError::vm)?;

            let j0 = vm.decode(block.output).map_err(TagProofError::vm)?;

            Ok(j0)
        })
        .collect::<Result<Vec<_>, TagProofError>>()?;

    let mac_key = vm.decode(mac_key).map_err(TagProofError::vm)?;

    Ok(TagProof {
        j0s,
        records,
        mac_key,
    })
}

/// Proof of tag verification.
#[derive(Debug)]
#[must_use]
pub struct TagProof {
    /// The j0 block for each record.
    j0s: Vec<DecodeFutureTyped<BitVec, [u8; 16]>>,
    records: Vec<Record>,
    /// The MAC key for tag computation.
    mac_key: DecodeFutureTyped<BitVec, [u8; 16]>,
}

impl TagProof {
    /// Verifies the proof.
    pub fn verify(self) -> Result<(), TagProofError> {
        let Self {
            j0s,
            mut mac_key,
            records,
        } = self;

        let mac_key = mac_key
            .try_recv()
            .map_err(TagProofError::vm)?
            .ok_or_else(|| ErrorRepr::NotDecoded)?;

        for (mut j0, rec) in j0s.into_iter().zip(records) {
            let j0 = j0
                .try_recv()
                .map_err(TagProofError::vm)?
                .ok_or_else(|| ErrorRepr::NotDecoded)?;

            let aad = make_tls12_aad(rec.seq, rec.typ, rec.version, rec.ciphertext.len());

            let ghash_tag = ghash(aad.as_ref(), &rec.ciphertext, &mac_key);

            let record_tag = match rec.tag.as_ref() {
                Some(tag) => tag,
                None => {
                    // This will never happen, since we only call this method
                    // for proofs where the records' tags are known.
                    return Err(ErrorRepr::UnknownTag.into());
                }
            };

            if *record_tag
                != ghash_tag
                    .into_iter()
                    .zip(j0.into_iter())
                    .map(|(a, b)| a ^ b)
                    .collect::<Vec<_>>()
            {
                return Err(ErrorRepr::InvalidTag.into());
            }
        }

        Ok(())
    }
}

/// Error for [`TagProof`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TagProofError(#[from] ErrorRepr);

impl TagProofError {
    fn vm<E>(err: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
    {
        Self(ErrorRepr::Vm(err.into()))
    }
}

#[derive(Debug, thiserror::Error)]
#[error("j0 proof error: {0}")]
enum ErrorRepr {
    #[error("value was not decoded")]
    NotDecoded,
    #[error("VM error: {0}")]
    Vm(Box<dyn std::error::Error + Send + Sync + 'static>),
    #[error("tag does not match expected")]
    InvalidTag,
    #[error("tag is not known")]
    UnknownTag,
    #[error("invalid explicit nonce length: expected {expected}, got {actual}")]
    ExplicitNonceLength { expected: usize, actual: usize },
}
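The comparison in `verify` is the AES-GCM tag equation: record tag = GHASH(aad, ciphertext) XOR E_K(J0), where E_K(J0) is decoded out of the VM. Reduced to plain Rust, the expected tag is a byte-wise XOR (a sketch mirroring the loop above, not part of the module API):

/// Recombines a decoded j0 block with a GHASH output into the expected
/// record tag, as `TagProof::verify` does.
fn expected_tag(j0: [u8; 16], ghash_tag: [u8; 16]) -> [u8; 16] {
    let mut tag = [0u8; 16];
    for (t, (a, b)) in tag.iter_mut().zip(ghash_tag.iter().zip(j0.iter())) {
        *t = a ^ b;
    }
    tag
}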
@@ -1,22 +1,26 @@
//! TLS transcript.

use mpz_memory_core::{binary::U8, Vector};
use mpz_memory_core::{
    binary::{Binary, U8},
    MemoryExt, Vector,
};
use mpz_vm_core::{Vm, VmError};
use rangeset::Intersection;
use tls_core::msgs::enums::ContentType;
use tlsn_core::transcript::{Direction, Idx, Transcript};
use tls_core::msgs::enums::{ContentType, ProtocolVersion};
use tlsn_core::transcript::{Direction, Idx, PartialTranscript, Transcript};

/// A transcript of sent and received TLS records.
/// A transcript of TLS records sent and received by the prover.
#[derive(Debug, Default, Clone)]
pub struct TlsTranscript {
    /// Records sent by the prover.
    /// Sent records.
    pub sent: Vec<Record>,
    /// Records received by the prover.
    /// Received records.
    pub recv: Vec<Record>,
}

impl TlsTranscript {
    /// Returns the application data transcript.
    pub fn to_transcript(&self) -> Result<Transcript, IncompleteTranscript> {
    pub fn to_transcript(&self) -> Result<Transcript, TlsTranscriptError> {
        let mut sent = Vec::new();
        let mut recv = Vec::new();

@@ -28,7 +32,7 @@ impl TlsTranscript {
            let plaintext = record
                .plaintext
                .as_ref()
                .ok_or(IncompleteTranscript {})?
                .ok_or(ErrorRepr::IncompleteTranscript {})?
                .clone();
            sent.extend_from_slice(&plaintext);
        }
@@ -41,7 +45,7 @@ impl TlsTranscript {
            let plaintext = record
                .plaintext
                .as_ref()
                .ok_or(IncompleteTranscript {})?
                .ok_or(ErrorRepr::IncompleteTranscript {})?
                .clone();
            recv.extend_from_slice(&plaintext);
        }
@@ -50,7 +54,7 @@ impl TlsTranscript {
    }

    /// Returns the application data transcript references.
    pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, IncompleteTranscript> {
    pub fn to_transcript_refs(&self) -> Result<TranscriptRefs, TlsTranscriptError> {
        let mut sent = Vec::new();
        let mut recv = Vec::new();

@@ -62,7 +66,7 @@ impl TlsTranscript {
            let plaintext_ref = record
                .plaintext_ref
                .as_ref()
                .ok_or(IncompleteTranscript {})?;
                .ok_or(ErrorRepr::IncompleteTranscript {})?;
            sent.push(*plaintext_ref);
        }

@@ -74,7 +78,7 @@ impl TlsTranscript {
            let plaintext_ref = record
                .plaintext_ref
                .as_ref()
                .ok_or(IncompleteTranscript {})?;
                .ok_or(ErrorRepr::IncompleteTranscript {})?;
            recv.push(*plaintext_ref);
        }

@@ -97,6 +101,10 @@ pub struct Record {
    pub explicit_nonce: Vec<u8>,
    /// Ciphertext.
    pub ciphertext: Vec<u8>,
    /// Tag.
    pub tag: Option<Vec<u8>>,
    /// Version.
    pub version: ProtocolVersion,
}

opaque_debug::implement!(Record);
@@ -163,10 +171,80 @@ impl TranscriptRefs {
    }
}

/// Error for [`TranscriptRefs::from_transcript`].
/// Error for [`TlsTranscript`].
#[derive(Debug, thiserror::Error)]
#[error("not all application plaintext was committed to in the TLS transcript")]
pub struct IncompleteTranscript {}
#[error(transparent)]
pub struct TlsTranscriptError(#[from] ErrorRepr);

#[derive(Debug, thiserror::Error)]
#[error("TLS transcript error")]
enum ErrorRepr {
    #[error("not all application plaintext was committed to in the TLS transcript")]
    IncompleteTranscript {},
}

/// Decodes the transcript.
pub fn decode_transcript(
    vm: &mut dyn Vm<Binary>,
    sent: &Idx,
    recv: &Idx,
    refs: &TranscriptRefs,
) -> Result<(), VmError> {
    let sent_refs = refs.get(Direction::Sent, sent).expect("index is in bounds");
    let recv_refs = refs
        .get(Direction::Received, recv)
        .expect("index is in bounds");

    for slice in sent_refs.into_iter().chain(recv_refs) {
        // Drop the future, we don't need it.
        drop(vm.decode(slice)?);
    }

    Ok(())
}

/// Verifies a partial transcript.
pub fn verify_transcript(
    vm: &mut dyn Vm<Binary>,
    transcript: &PartialTranscript,
    refs: &TranscriptRefs,
) -> Result<(), InconsistentTranscript> {
    let sent_refs = refs
        .get(Direction::Sent, transcript.sent_authed())
        .expect("index is in bounds");
    let recv_refs = refs
        .get(Direction::Received, transcript.received_authed())
        .expect("index is in bounds");

    let mut authenticated_data = Vec::new();
    for data in sent_refs.into_iter().chain(recv_refs) {
        let plaintext = vm
            .get(data)
            .expect("reference is valid")
            .expect("plaintext is decoded");
        authenticated_data.extend_from_slice(&plaintext);
    }

    let mut purported_data = Vec::with_capacity(authenticated_data.len());
    for range in transcript.sent_authed().iter_ranges() {
        purported_data.extend_from_slice(&transcript.sent_unsafe()[range]);
    }

    for range in transcript.received_authed().iter_ranges() {
        purported_data.extend_from_slice(&transcript.received_unsafe()[range]);
    }

    if purported_data != authenticated_data {
        return Err(InconsistentTranscript {});
    }

    Ok(())
}

/// Error for [`verify_transcript`].
#[derive(Debug, thiserror::Error)]
#[error("inconsistent transcript")]
pub struct InconsistentTranscript {}
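Stripped of the VM plumbing, `verify_transcript` compares the purported bytes against the authenticated bytes range by range. A plain-Rust sketch of that comparison (a hypothetical helper, not the crate's API), where `authenticated` is the concatenation of the authenticated ranges in order:

use std::ops::Range;

fn ranges_consistent(purported: &[u8], authenticated: &[u8], authed: &[Range<usize>]) -> bool {
    let mut offset = 0;
    authed.iter().all(|r| {
        let ok = purported[r.clone()] == authenticated[offset..offset + r.len()];
        offset += r.len();
        ok
    })
}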
#[cfg(test)]
mod tests {
@@ -54,18 +54,13 @@ impl ZkAesCtr {

        let input = vm.alloc_vec::<U8>(len).map_err(ZkAesCtrError::vm)?;
        let keystream = self.aes.alloc_keystream(vm, len)?;
        let output = keystream.apply(vm, input)?;

        match self.role {
            Role::Prover => vm.mark_private(input).map_err(ZkAesCtrError::vm)?,
            Role::Verifier => vm.mark_blind(input).map_err(ZkAesCtrError::vm)?,
        }

        self.state = State::Ready {
            input,
            keystream,
            output,
        };
        self.state = State::Ready { input, keystream };

        Ok(())
    }
@@ -78,23 +73,25 @@ impl ZkAesCtr {

    /// Proves the encryption of `len` bytes.
    ///
    /// Here we only assign certain values in the VM but the actual proving
    /// happens later when the plaintext is assigned and the VM is executed.
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `explicit_nonce` - Explicit nonce.
    /// * `len` - Length of the plaintext in bytes.
    ///
    /// # Returns
    ///
    /// A VM reference to the plaintext and the ciphertext.
    pub fn encrypt(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        explicit_nonce: Vec<u8>,
        len: usize,
    ) -> Result<(Vector<U8>, Vector<U8>), ZkAesCtrError> {
        let State::Ready {
            input,
            keystream,
            output,
        } = &mut self.state
        else {
        let State::Ready { input, keystream } = &mut self.state else {
            Err(ErrorRepr::State {
                reason: "must be in ready state to encrypt",
            })?
@@ -121,7 +118,7 @@ impl ZkAesCtr {

        let mut input = input.split_off(input.len() - padded_len);
        let keystream = keystream.consume(padded_len)?;
        let mut output = output.split_off(output.len() - padded_len);
        let mut output = keystream.apply(vm, input)?;

        // Assign counter block inputs.
        let mut ctr = START_CTR..;
@@ -132,6 +129,8 @@ impl ZkAesCtr {
        // Assign zeroes to the padding.
        if padding_len > 0 {
            let padding = input.split_off(input.len() - padding_len);
            // To simplify the impl, we don't mark the padding as public, that's why only
            // the prover assigns it.
            if let Role::Prover = self.role {
                vm.assign(padding, vec![0; padding_len])
                    .map_err(ZkAesCtrError::vm)?;
@@ -149,7 +148,6 @@ enum State {
    Ready {
        input: Vector<U8>,
        keystream: Keystream<Nonce, Ctr, Block>,
        output: Vector<U8>,
    },
    Error,
}
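For reference, the relation being proven is plain AES-CTR under the TLS 1.2 AES-GCM block layout iv(4) || explicit_nonce(8) || counter(4), with the payload counter starting after the reserved J0 block (counter value 2; the code's `START_CTR`). A cleartext sketch assuming the `aes` and `ctr` crates:

use aes::Aes128;
use ctr::cipher::{KeyIvInit, StreamCipher};

type Aes128Ctr32 = ctr::Ctr32BE<Aes128>;

/// Encrypts a record payload in the clear, e.g. for cross-checking.
fn encrypt_record(
    key: &[u8; 16],
    iv: &[u8; 4],
    explicit_nonce: &[u8; 8],
    plaintext: &[u8],
) -> Vec<u8> {
    // Initial counter block: iv || explicit_nonce || ctr, with ctr = 2.
    let mut block = [0u8; 16];
    block[..4].copy_from_slice(iv);
    block[4..12].copy_from_slice(explicit_nonce);
    block[12..].copy_from_slice(&2u32.to_be_bytes());

    let mut ciphertext = plaintext.to_vec();
    Aes128Ctr32::new(key.into(), &block.into()).apply_keystream(&mut ciphertext);
    ciphertext
}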
@@ -5,9 +5,12 @@ description = "This crate provides implementations of ciphers for two parties"
keywords = ["tls", "mpc", "2pc", "aes"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.9"
version = "0.1.0-alpha.12"
edition = "2021"

[lints]
workspace = true

[lib]
name = "cipher"

@@ -27,6 +30,5 @@ mpz-ot = { workspace = true }

tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }
ctr = { workspace = true }
cipher = { workspace = true }

@@ -1,52 +0,0 @@
use mpz_circuits::{circuits::aes128_trace, once_cell::sync::Lazy, trace, Circuit, CircuitBuilder};
use std::sync::Arc;

/// `fn(key: [u8; 16], iv: [u8; 4], nonce: [u8; 8], ctr: [u8; 4]) -> [u8; 16]`
pub(crate) static AES128_CTR: Lazy<Arc<Circuit>> = Lazy::new(|| {
    let builder = CircuitBuilder::new();

    let key = builder.add_array_input::<u8, 16>();
    let iv = builder.add_array_input::<u8, 4>();
    let nonce = builder.add_array_input::<u8, 8>();
    let ctr = builder.add_array_input::<u8, 4>();

    let keystream = aes_ctr_trace(builder.state(), key, iv, nonce, ctr);

    builder.add_output(keystream);

    Arc::new(builder.build().unwrap())
});

/// `fn(key: [u8; 16], msg: [u8; 16]) -> [u8; 16]`
pub(crate) static AES128_ECB: Lazy<Arc<Circuit>> = Lazy::new(|| {
    let builder = CircuitBuilder::new();

    let key = builder.add_array_input::<u8, 16>();
    let message = builder.add_array_input::<u8, 16>();
    let block = aes128_trace(builder.state(), key, message);

    builder.add_output(block);

    Arc::new(builder.build().unwrap())
});

#[trace]
#[dep(aes_128, aes128_trace)]
#[allow(dead_code)]
fn aes_ctr(key: [u8; 16], iv: [u8; 4], explicit_nonce: [u8; 8], ctr: [u8; 4]) -> [u8; 16] {
    let block: Vec<_> = iv.into_iter().chain(explicit_nonce).chain(ctr).collect();
    aes_128(key, block.try_into().unwrap())
}

#[allow(dead_code)]
fn aes_128(key: [u8; 16], msg: [u8; 16]) -> [u8; 16] {
    use aes::{
        cipher::{BlockEncrypt, KeyInit},
        Aes128,
    };

    let aes = Aes128::new_from_slice(&key).unwrap();
    let mut ciphertext = msg.into();
    aes.encrypt_block(&mut ciphertext);
    ciphertext.into()
}
@@ -2,11 +2,11 @@

use crate::{Cipher, CtrBlock, Keystream};
use async_trait::async_trait;
use mpz_circuits::circuits::AES128;
use mpz_memory_core::binary::{Binary, U8};
use mpz_vm_core::{prelude::*, Call, Vm};
use std::fmt::Debug;

mod circuit;
mod error;

pub use error::AesError;
@@ -55,7 +55,7 @@ impl Cipher for Aes128 {

        let output = vm
            .call(
                Call::builder(circuit::AES128_ECB.clone())
                Call::builder(AES128.clone())
                    .arg(key)
                    .arg(input)
                    .build()
@@ -91,7 +91,7 @@ impl Cipher for Aes128 {

        let output = vm
            .call(
                Call::builder(circuit::AES128_CTR.clone())
                Call::builder(AES128.clone())
                    .arg(key)
                    .arg(iv)
                    .arg(explicit_nonce)
@@ -145,7 +145,7 @@ impl Cipher for Aes128 {
            .map(|(explicit_nonce, counter)| {
                let output = vm
                    .call(
                        Call::builder(circuit::AES128_CTR.clone())
                        Call::builder(AES128.clone())
                            .arg(key)
                            .arg(iv)
                            .arg(explicit_nonce)
@@ -172,7 +172,7 @@ mod tests {
    use super::*;
    use crate::Cipher;
    use mpz_common::context::test_st_context;
    use mpz_garble::protocol::semihonest::{Evaluator, Generator};
    use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
    use mpz_memory_core::{
        binary::{Binary, U8},
        correlated::Delta,
@@ -181,7 +181,6 @@ mod tests {
    use mpz_ot::ideal::cot::ideal_cot;
    use mpz_vm_core::{Execute, Vm};
    use rand::{rngs::StdRng, SeedableRng};
    use rand06_compat::Rand0_6CompatExt;

    #[tokio::test]
    async fn test_aes_ctr() {
@@ -297,11 +296,11 @@ mod tests {

    fn mock_vm() -> (impl Vm<Binary>, impl Vm<Binary>) {
        let mut rng = StdRng::seed_from_u64(0);
        let delta = Delta::random(&mut rng.compat_by_ref());
        let delta = Delta::random(&mut rng);

        let (cot_send, cot_recv) = ideal_cot(delta.into_inner());

        let gen = Generator::new(cot_send, [0u8; 16], delta);
        let gen = Garbler::new(cot_send, [0u8; 16], delta);
        let ev = Evaluator::new(cot_recv);

        (gen, ev)

@@ -1,23 +0,0 @@
//! Ciphers and circuits.

use mpz_circuits::{types::ValueType, Circuit, CircuitBuilder, Tracer};
use std::sync::Arc;

/// Builds a circuit which XORs the provided values.
pub(crate) fn build_xor_circuit(inputs: &[ValueType]) -> Arc<Circuit> {
    let builder = CircuitBuilder::new();

    for input_ty in inputs {
        let input_0 = builder.add_input_by_type(input_ty.clone());
        let input_1 = builder.add_input_by_type(input_ty.clone());

        let input_0 = Tracer::new(builder.state(), input_0);
        let input_1 = Tracer::new(builder.state(), input_1);
        let output = input_0 ^ input_1;
        builder.add_output(output);
    }

    let circ = builder.build().expect("circuit should be valid");

    Arc::new(circ)
}
@@ -10,17 +10,15 @@
#![forbid(unsafe_code)]

pub mod aes;
mod circuit;

use async_trait::async_trait;
use circuit::build_xor_circuit;
use mpz_circuits::types::ValueType;
use mpz_circuits::circuits::xor;
use mpz_memory_core::{
    binary::{Binary, U8},
    FromRaw, MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
    MemoryExt, Repr, Slice, StaticSize, ToRaw, Vector,
};
use mpz_vm_core::{prelude::*, CallBuilder, CallError, Vm};
use std::collections::VecDeque;
use mpz_vm_core::{prelude::*, Call, CallBuilder, CallError, Vm};
use std::{collections::VecDeque, sync::Arc};

/// Provides computation of 2PC ciphers in counter and ECB mode.
///
@@ -99,6 +97,7 @@ pub struct CtrBlock<N, C, O> {
/// Can be used to XOR with the cipher input to operate the cipher in counter
/// mode.
pub struct Keystream<N, C, O> {
    /// Sequential keystream blocks. Outputs are stored in contiguous memory.
    blocks: VecDeque<CtrBlock<N, C, O>>,
}

@@ -117,25 +116,7 @@ where
    O: Repr<Binary> + StaticSize<Binary> + Copy,
{
    /// Creates a new keystream from the provided blocks.
    ///
    /// # Panics
    ///
    /// * If the output of the keystream is not ordered and contiguous in
    ///   memory.
    pub fn new(blocks: &[CtrBlock<N, C, O>]) -> Self {
        let mut pos = blocks
            .first()
            .map(|block| block.output.to_raw().ptr().as_usize())
            .unwrap_or(0);

        for block in blocks {
            if block.output.to_raw().ptr().as_usize() != pos {
                panic!("output of keystream blocks must be ordered and contiguous in memory");
            }

            pos += O::SIZE;
        }

        Self {
            blocks: VecDeque::from_iter(blocks.iter().copied()),
        }
@@ -178,7 +159,7 @@ where
            return Err(CipherError::new("no keystream material available"));
        }

        let xor = build_xor_circuit(&[ValueType::new_array::<u8>(self.block_size())]);
        let xor = Arc::new(xor(self.block_size() * 8));
        let mut pos = 0;
        let mut outputs = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
@@ -195,20 +176,17 @@ where
            pos += self.block_size();
        }

        // Calls were performed contiguously, so the output data is contiguous.
        let ptr = outputs
            .first()
            .map(|output| output.to_raw().ptr())
            .expect("keystream is not empty");
        let size = self.blocks.len() * O::SIZE;

        let output = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
        let output = flatten_blocks(vm, outputs.iter().map(|block| block.to_raw()))?;

        Ok(output)
    }

    /// Returns `len` bytes of the keystream as a vector.
    pub fn to_vector(&self, len: usize) -> Result<Vector<U8>, CipherError> {
    pub fn to_vector(
        &self,
        vm: &mut dyn Vm<Binary>,
        len: usize,
    ) -> Result<Vector<U8>, CipherError> {
        if len == 0 {
            return Err(CipherError::new("length must be greater than 0"));
        } else if self.blocks.is_empty() {
@@ -220,14 +198,8 @@ where
            return Err(CipherError::new("length does not match keystream length"));
        }

        let ptr = self
            .blocks
            .front()
            .map(|block| block.output.to_raw().ptr())
            .expect("block count should be greater than 0");
        let size = block_count * O::SIZE;

        let mut keystream = Vector::<U8>::from_raw(Slice::new_unchecked(ptr, size));
        let mut keystream =
            flatten_blocks(vm, self.blocks.iter().map(|block| block.output.to_raw()))?;
        keystream.truncate(len);

        Ok(keystream)
@@ -273,6 +245,34 @@ where
    }
}

fn flatten_blocks(
    vm: &mut dyn Vm<Binary>,
    blocks: impl IntoIterator<Item = Slice>,
) -> Result<Vector<U8>, CipherError> {
    use mpz_circuits::CircuitBuilder;

    let blocks = blocks.into_iter().collect::<Vec<_>>();
    let len: usize = blocks.iter().map(|block| block.len()).sum();

    let mut builder = CircuitBuilder::new();
    for _ in 0..len {
        let i = builder.add_input();
        let o = builder.add_id_gate(i);
        builder.add_output(o);
    }

    let circuit = builder.build().expect("flatten circuit should be valid");

    let mut builder = Call::builder(Arc::new(circuit));
    for block in blocks {
        builder = builder.arg(block);
    }

    let call = builder.build().map_err(CipherError::new)?;

    vm.call(call).map_err(CipherError::new)
}

/// A cipher error.
#[derive(Debug, thiserror::Error)]
#[error("{source}")]
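The identity circuit built in `flatten_blocks` above exists only so the VM allocates one fresh, contiguous output region for keystream blocks that may now live in scattered memory; it replaces the removed pointer-arithmetic path that panicked on non-contiguous blocks. A minimal standalone sketch of such a pass-through circuit, assuming the `CircuitBuilder` API exactly as used in `flatten_blocks`:

use mpz_circuits::{Circuit, CircuitBuilder};

/// Builds an n-bit identity circuit: every input bit is routed through an
/// id gate to an output bit, so calling it concatenates its arguments.
fn identity_circuit(num_bits: usize) -> Circuit {
    let mut builder = CircuitBuilder::new();
    for _ in 0..num_bits {
        let input = builder.add_input();
        let output = builder.add_id_gate(input);
        builder.add_output(output);
    }
    builder.build().expect("identity circuit should be valid")
}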
@@ -1,8 +1,11 @@
[package]
name = "tlsn-deap"
version = "0.1.0-alpha.9"
version = "0.1.0-alpha.12"
edition = "2021"

[lints]
workspace = true

[dependencies]
mpz-core = { workspace = true }
mpz-common = { workspace = true }
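The lib.rs changes below route every DEAP operation through a new `map::MemoryMap`, translating addresses allocated in the MPC VM into the ZK VM's address space. The module itself is not part of this diff; a hypothetical sketch of the shape such a map could take (names and representation are assumptions, not the crate's code):

use std::collections::BTreeMap;

/// Hypothetical address translation: the start of each MPC-VM range is
/// mapped to its length and the matching ZK-VM range start.
#[derive(Default)]
struct MemoryMapSketch {
    ranges: BTreeMap<usize, (usize, usize)>, // mpc_start -> (len, zk_start)
}

impl MemoryMapSketch {
    fn insert(&mut self, mpc_start: usize, len: usize, zk_start: usize) {
        self.ranges.insert(mpc_start, (len, zk_start));
    }

    /// Translates an MPC-VM address that falls inside a mapped range.
    fn try_get(&self, addr: usize) -> Option<usize> {
        let (&start, &(len, zk_start)) = self.ranges.range(..=addr).next_back()?;
        (addr < start + len).then(|| zk_start + (addr - start))
    }
}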
@@ -4,19 +4,15 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]

use std::{
    mem,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};
mod map;

use std::{mem, sync::Arc};

use async_trait::async_trait;
use mpz_common::{scoped_futures::ScopedFutureExt as _, Context};
use mpz_common::Context;
use mpz_core::bitvec::BitVec;
use mpz_vm_core::{
    memory::{binary::Binary, DecodeFuture, Memory, Slice, View},
    memory::{binary::Binary, DecodeFuture, Memory, Repr, Slice, View},
    Call, Callable, Execute, Vm, VmError,
};
use rangeset::{Difference, RangeSet, UnionMut};
@@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
    role: Role,
    mpc: Arc<Mutex<Mpc>>,
    zk: Arc<Mutex<Zk>>,
    /// Private inputs of the follower.
    follower_inputs: RangeSet<usize>,
    /// Mapping between the memories of the MPC and ZK VMs.
    memory_map: map::MemoryMap,
    /// Ranges of the follower's private inputs in the MPC VM.
    follower_input_ranges: RangeSet<usize>,
    /// Private inputs of the follower in the MPC VM.
    follower_inputs: Vec<Slice>,
    /// Outputs of the follower from the ZK VM. The references
    /// correspond to the MPC VM.
    outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
    /// Whether the memories of the two VMs are potentially desynchronized.
    desync: AtomicBool,
}

impl<Mpc, Zk> Deap<Mpc, Zk> {
@@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
            role,
            mpc: Arc::new(Mutex::new(mpc)),
            zk: Arc::new(Mutex::new(zk)),
            follower_inputs: RangeSet::default(),
            memory_map: map::MemoryMap::default(),
            follower_input_ranges: RangeSet::default(),
            follower_inputs: Vec::default(),
            outputs: Vec::default(),
            desync: AtomicBool::new(false),
        }
    }

@@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {

    /// Returns a mutable reference to the ZK VM.
    ///
    /// # Note
    ///
    /// After calling this method, allocations will no longer be allowed in the
    /// DEAP VM as the memory will potentially be desynchronized.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is locked by another thread.
    pub fn zk(&self) -> MutexGuard<'_, Zk> {
        self.desync.store(true, Ordering::Relaxed);
        self.zk.try_lock().unwrap()
    }

    /// Returns an owned mutex guard to the ZK VM.
    ///
    /// # Note
    ///
    /// After calling this method, allocations will no longer be allowed in the
    /// DEAP VM as the memory will potentially be desynchronized.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is locked by another thread.
    pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
        self.desync.store(true, Ordering::Relaxed);
        self.zk.clone().try_lock_owned().unwrap()
    }

    /// Translates a value from the MPC VM address space to the ZK VM address
    /// space.
    pub fn translate<T: Repr<Binary>>(&self, value: T) -> Result<T, VmError> {
        self.memory_map.try_get(value.to_raw()).map(T::from_raw)
    }

    #[cfg(test)]
    fn mpc(&self) -> MutexGuard<'_, Mpc> {
        self.mpc.try_lock().unwrap()
@@ -124,18 +119,15 @@ where
        // MACs.
        let input_futs = self
            .follower_inputs
            .iter_ranges()
            .map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
            .iter()
            .map(|&input| mpc.decode_raw(input))
            .collect::<Result<Vec<_>, _>>()?;

        mpc.execute_all(ctx).await?;

        // Assign inputs to the ZK VM.
        for (mut decode, input) in input_futs
            .into_iter()
            .zip(self.follower_inputs.iter_ranges())
        {
            let input = Slice::from_range_unchecked(input);
        for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
            let input = self.memory_map.try_get(input)?;

            // Follower has already assigned the inputs.
            if let Role::Leader = self.role {
@@ -184,31 +176,48 @@ where
{
    type Error = VmError;

    fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
        if self.desync.load(Ordering::Relaxed) {
            return Err(VmError::memory(
                "DEAP VM memories are potentially desynchronized",
            ));
        }
    fn is_alloc_raw(&self, slice: Slice) -> bool {
        self.mpc.try_lock().unwrap().is_alloc_raw(slice)
    }

        self.zk.try_lock().unwrap().alloc_raw(size)?;
        self.mpc.try_lock().unwrap().alloc_raw(size)
    fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
        let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
        let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;

        self.memory_map.insert(mpc_slice, zk_slice);

        Ok(mpc_slice)
    }

    fn is_assigned_raw(&self, slice: Slice) -> bool {
        self.mpc.try_lock().unwrap().is_assigned_raw(slice)
    }

    fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
        self.zk
        self.mpc
            .try_lock()
            .unwrap()
            .assign_raw(slice, data.clone())?;
        self.mpc.try_lock().unwrap().assign_raw(slice, data)

        self.zk
            .try_lock()
            .unwrap()
            .assign_raw(self.memory_map.try_get(slice)?, data)
    }

    fn is_committed_raw(&self, slice: Slice) -> bool {
        self.mpc.try_lock().unwrap().is_committed_raw(slice)
    }

    fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
        // Follower's private inputs are not committed in the ZK VM until finalization.
        let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
        let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
        let mut zk = self.zk.try_lock().unwrap();
        for input in input_minus_follower.iter_ranges() {
            zk.commit_raw(Slice::from_range_unchecked(input))?;
            zk.commit_raw(
                self.memory_map
                    .try_get(Slice::from_range_unchecked(input))?,
            )?;
        }

        self.mpc.try_lock().unwrap().commit_raw(slice)
@@ -219,7 +228,11 @@ where
    }

    fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
        let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
        let fut = self
            .zk
            .try_lock()
            .unwrap()
            .decode_raw(self.memory_map.try_get(slice)?)?;
        self.outputs.push((slice, fut));

        self.mpc.try_lock().unwrap().decode_raw(slice)
@@ -234,8 +247,11 @@ where
    type Error = VmError;

    fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
        self.zk.try_lock().unwrap().mark_public_raw(slice)?;
        self.mpc.try_lock().unwrap().mark_public_raw(slice)
        self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
        self.zk
            .try_lock()
            .unwrap()
            .mark_public_raw(self.memory_map.try_get(slice)?)
    }

    fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
@@ -243,14 +259,15 @@ where
        let mut mpc = self.mpc.try_lock().unwrap();
        match self.role {
            Role::Leader => {
                zk.mark_private_raw(slice)?;
                mpc.mark_private_raw(slice)?;
                zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
            }
            Role::Follower => {
                // Follower's private inputs will become public during finalization.
                zk.mark_public_raw(slice)?;
                mpc.mark_private_raw(slice)?;
                self.follower_inputs.union_mut(&slice.to_range());
                // Follower's private inputs will become public during finalization.
                zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
                self.follower_input_ranges.union_mut(&slice.to_range());
                self.follower_inputs.push(slice);
            }
        }

@@ -262,14 +279,15 @@ where
        let mut mpc = self.mpc.try_lock().unwrap();
        match self.role {
            Role::Leader => {
                // Follower's private inputs will become public during finalization.
                zk.mark_public_raw(slice)?;
                mpc.mark_blind_raw(slice)?;
                self.follower_inputs.union_mut(&slice.to_range());
                // Follower's private inputs will become public during finalization.
                zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
                self.follower_input_ranges.union_mut(&slice.to_range());
                self.follower_inputs.push(slice);
            }
            Role::Follower => {
                zk.mark_blind_raw(slice)?;
                mpc.mark_blind_raw(slice)?;
                zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
            }
        }

@@ -283,14 +301,21 @@ where
    Zk: Vm<Binary>,
{
    fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
        if self.desync.load(Ordering::Relaxed) {
            return Err(VmError::memory(
                "DEAP VM memories are potentially desynchronized",
            ));
        let (circ, inputs) = call.clone().into_parts();
        let mut builder = Call::builder(circ);

        for input in inputs {
            builder = builder.arg(self.memory_map.try_get(input)?);
        }

        self.zk.try_lock().unwrap().call_raw(call.clone())?;
        self.mpc.try_lock().unwrap().call_raw(call)
        let zk_call = builder.build().expect("call should be valid");

        let output = self.mpc.try_lock().unwrap().call_raw(call)?;
        let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;

        self.memory_map.insert(output, zk_output);

        Ok(output)
    }
}

@@ -308,8 +333,8 @@ where
        let mut zk = self.zk.clone().try_lock_owned().unwrap();
        let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
        ctx.try_join(
            |ctx| async move { zk.flush(ctx).await }.scope_boxed(),
            |ctx| async move { mpc.flush(ctx).await }.scope_boxed(),
            async move |ctx| zk.flush(ctx).await,
            async move |ctx| mpc.flush(ctx).await,
        )
        .await
        .map_err(VmError::execute)??;
@@ -326,8 +351,8 @@ where
        let mut zk = self.zk.clone().try_lock_owned().unwrap();
        let mut mpc = self.mpc.clone().try_lock_owned().unwrap();
        ctx.try_join(
            |ctx| async move { zk.preprocess(ctx).await }.scope_boxed(),
            |ctx| async move { mpc.preprocess(ctx).await }.scope_boxed(),
            async move |ctx| zk.preprocess(ctx).await,
            async move |ctx| mpc.preprocess(ctx).await,
        )
        .await
        .map_err(VmError::execute)??;
@@ -360,7 +385,7 @@ mod tests {
    use mpz_circuits::circuits::AES128;
    use mpz_common::context::test_st_context;
    use mpz_core::Block;
    use mpz_garble::protocol::semihonest::{Evaluator, Generator};
    use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
    use mpz_ot::ideal::{cot::ideal_cot, rcot::ideal_rcot};
    use mpz_vm_core::{
        memory::{binary::U8, correlated::Delta, Array},
|
||||
@@ -368,21 +393,20 @@ mod tests {
|
||||
};
|
||||
use mpz_zk::{Prover, Verifier};
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
use rand06_compat::Rand0_6CompatExt;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deap() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng.compat_by_ref());
|
||||
let delta_zk = Delta::random(&mut rng.compat_by_ref());
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let gb = Generator::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta_zk, rcot_send);
|
||||
@@ -452,19 +476,103 @@ mod tests {
|
||||
assert_eq!(ct_leader, ct_follower);
|
||||
}
|
||||
|
||||
// Tests that the leader can not use different inputs in each VM without
|
||||
// detection by the follower.
|
||||
#[tokio::test]
|
||||
async fn test_malicious() {
|
||||
async fn test_deap_desync_memory() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng.compat_by_ref());
|
||||
let delta_zk = Delta::random(&mut rng.compat_by_ref());
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let gb = Generator::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta_zk, rcot_send);
|
||||
|
||||
let mut leader = Deap::new(Role::Leader, gb, prover);
|
||||
let mut follower = Deap::new(Role::Follower, ev, verifier);
|
||||
|
||||
// Desynchronize the memories.
|
||||
let _ = leader.zk().alloc_raw(1).unwrap();
|
||||
let _ = follower.zk().alloc_raw(1).unwrap();
|
||||
|
||||
let (ct_leader, ct_follower) = futures::join!(
|
||||
async {
|
||||
let key: Array<U8, 16> = leader.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = leader.alloc().unwrap();
|
||||
|
||||
leader.mark_private(key).unwrap();
|
||||
leader.mark_blind(msg).unwrap();
|
||||
leader.assign(key, [42u8; 16]).unwrap();
|
||||
leader.commit(key).unwrap();
|
||||
leader.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = leader
|
||||
.call(
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(msg)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let ct = leader.decode(ct).unwrap();
|
||||
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.execute(&mut ctx_a).await.unwrap();
|
||||
leader.flush(&mut ctx_a).await.unwrap();
|
||||
leader.finalize(&mut ctx_a).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
},
|
||||
async {
|
||||
let key: Array<U8, 16> = follower.alloc().unwrap();
|
||||
let msg: Array<U8, 16> = follower.alloc().unwrap();
|
||||
|
||||
follower.mark_blind(key).unwrap();
|
||||
follower.mark_private(msg).unwrap();
|
||||
follower.assign(msg, [69u8; 16]).unwrap();
|
||||
follower.commit(key).unwrap();
|
||||
follower.commit(msg).unwrap();
|
||||
|
||||
let ct: Array<U8, 16> = follower
|
||||
.call(
|
||||
Call::builder(AES128.clone())
|
||||
.arg(key)
|
||||
.arg(msg)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.unwrap();
|
||||
let ct = follower.decode(ct).unwrap();
|
||||
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.execute(&mut ctx_b).await.unwrap();
|
||||
follower.flush(&mut ctx_b).await.unwrap();
|
||||
follower.finalize(&mut ctx_b).await.unwrap();
|
||||
|
||||
ct.await.unwrap()
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(ct_leader, ct_follower);
|
||||
}
|
||||
|
||||
// Tests that the leader can not use different inputs in each VM without
|
||||
// detection by the follower.
|
||||
#[tokio::test]
|
||||
async fn test_malicious() {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta_mpc = Delta::random(&mut rng);
|
||||
let delta_zk = Delta::random(&mut rng);
|
||||
|
||||
let (mut ctx_a, mut ctx_b) = test_st_context(8);
|
||||
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
|
||||
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
|
||||
|
||||
let gb = Garbler::new(cot_send, [1u8; 16], delta_mpc);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
let prover = Prover::new(rcot_recv);
|
||||
let verifier = Verifier::new(delta_zk, rcot_send);
|
||||
|
||||
crates/components/deap/src/map.rs (new file, 111 lines)
@@ -0,0 +1,111 @@
use std::ops::Range;

use mpz_vm_core::{memory::Slice, VmError};
use rangeset::Subset;

/// A mapping between the memories of the MPC and ZK VMs.
#[derive(Debug, Default)]
pub(crate) struct MemoryMap {
mpc: Vec<Range<usize>>,
zk: Vec<Range<usize>>,
}

impl MemoryMap {
/// Inserts a new allocation into the map.
///
/// # Panics
///
/// - If the slices are not inserted in the order they are allocated.
/// - If the slices are not the same length.
pub(crate) fn insert(&mut self, mpc: Slice, zk: Slice) {
let mpc = mpc.to_range();
let zk = zk.to_range();

assert_eq!(mpc.len(), zk.len(), "slices must be the same length");

if let Some(last) = self.mpc.last() {
if last.end > mpc.start {
panic!("slices must be provided in ascending order");
}
}

self.mpc.push(mpc);
self.zk.push(zk);
}

/// Returns the corresponding allocation in the ZK VM.
pub(crate) fn try_get(&self, mpc: Slice) -> Result<Slice, VmError> {
let mpc_range = mpc.to_range();
let pos = match self
.mpc
.binary_search_by_key(&mpc_range.start, |range| range.start)
{
Ok(pos) => pos,
Err(0) => return Err(VmError::memory(format!("invalid memory slice: {mpc}"))),
Err(pos) => pos - 1,
};

let candidate = &self.mpc[pos];
if mpc_range.is_subset(candidate) {
let offset = mpc_range.start - candidate.start;
let start = self.zk[pos].start + offset;
let slice = Slice::from_range_unchecked(start..start + mpc_range.len());

Ok(slice)
} else {
Err(VmError::memory(format!("invalid memory slice: {mpc}")))
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_map() {
let mut map = MemoryMap::default();
map.insert(
Slice::from_range_unchecked(0..10),
Slice::from_range_unchecked(10..20),
);

// Range is fully contained.
assert_eq!(
map.try_get(Slice::from_range_unchecked(0..10)).unwrap(),
Slice::from_range_unchecked(10..20)
);
// Range is subset.
assert_eq!(
map.try_get(Slice::from_range_unchecked(1..9)).unwrap(),
Slice::from_range_unchecked(11..19)
);
// Range is not subset.
assert!(map.try_get(Slice::from_range_unchecked(0..11)).is_err());

// Insert another range.
map.insert(
Slice::from_range_unchecked(20..30),
Slice::from_range_unchecked(30..40),
);
assert_eq!(
map.try_get(Slice::from_range_unchecked(20..30)).unwrap(),
Slice::from_range_unchecked(30..40)
);
assert_eq!(
map.try_get(Slice::from_range_unchecked(21..29)).unwrap(),
Slice::from_range_unchecked(31..39)
);
assert!(map.try_get(Slice::from_range_unchecked(19..21)).is_err());
}

#[test]
#[should_panic]
fn test_map_length_mismatch() {
let mut map = MemoryMap::default();
map.insert(
Slice::from_range_unchecked(5..10),
Slice::from_range_unchecked(20..30),
);
}
}
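Aside: the heart of `MemoryMap::try_get` is a binary search on the MPC range starts followed by an offset shift into the paired ZK range. A minimal standalone sketch of that translation, using plain ranges instead of `Slice` (the names here are illustrative, not part of the diff):

use std::ops::Range;

// Translate a sub-range of an MPC allocation into the paired ZK allocation.
// `pairs` must be sorted by the start of the MPC range, mirroring MemoryMap.
fn translate(pairs: &[(Range<usize>, Range<usize>)], query: Range<usize>) -> Option<Range<usize>> {
    let pos = match pairs.binary_search_by_key(&query.start, |(mpc, _)| mpc.start) {
        Ok(pos) => pos,
        Err(0) => return None,   // query starts before every allocation
        Err(pos) => pos - 1,     // candidate allocation starting at or before query
    };
    let (mpc, zk) = &pairs[pos];
    if query.start >= mpc.start && query.end <= mpc.end {
        let offset = query.start - mpc.start;
        let start = zk.start + offset;
        Some(start..start + query.len())
    } else {
        None
    }
}

fn main() {
    let pairs = vec![(0..10, 10..20), (20..30, 30..40)];
    assert_eq!(translate(&pairs, 1..9), Some(11..19));
    assert_eq!(translate(&pairs, 19..21), None); // straddles two allocations
}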
@@ -1,19 +0,0 @@
[package]
name = "tlsn-hmac-sha256-circuits"
authors = ["TLSNotary Team"]
description = "The 2PC circuits for TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.9"
edition = "2021"

[lib]
name = "hmac_sha256_circuits"

[dependencies]
mpz-circuits = { workspace = true }
tracing = { workspace = true }

[dev-dependencies]
ring = { workspace = true }
@@ -1,159 +0,0 @@
use std::cell::RefCell;

use mpz_circuits::{
circuits::{sha256, sha256_compress, sha256_compress_trace, sha256_trace},
types::{U32, U8},
BuilderState, Tracer,
};

static SHA256_INITIAL_STATE: [u32; 8] = [
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];

/// Returns the outer and inner states of HMAC-SHA256 with the provided key.
///
/// Outer state is H(key ⊕ opad)
///
/// Inner state is H(key ⊕ ipad)
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `key` - N-byte key (must be <= 64 bytes).
pub fn hmac_sha256_partial_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
key: &[Tracer<'a, U8>],
) -> ([Tracer<'a, U32>; 8], [Tracer<'a, U32>; 8]) {
assert!(key.len() <= 64);

let mut opad = [Tracer::new(
builder_state,
builder_state.borrow_mut().get_constant(0x5cu8),
); 64];

let mut ipad = [Tracer::new(
builder_state,
builder_state.borrow_mut().get_constant(0x36u8),
); 64];

key.iter().enumerate().for_each(|(i, k)| {
opad[i] = opad[i] ^ *k;
ipad[i] = ipad[i] ^ *k;
});

let sha256_initial_state: [_; 8] = SHA256_INITIAL_STATE
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));

let outer_state = sha256_compress_trace(builder_state, sha256_initial_state, opad);
let inner_state = sha256_compress_trace(builder_state, sha256_initial_state, ipad);

(outer_state, inner_state)
}

/// Reference implementation of HMAC-SHA256 partial function.
///
/// Returns the outer and inner states of HMAC-SHA256 with the provided key.
///
/// Outer state is H(key ⊕ opad)
///
/// Inner state is H(key ⊕ ipad)
///
/// # Arguments
///
/// * `key` - N-byte key (must be <= 64 bytes).
pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) {
assert!(key.len() <= 64);

let mut opad = [0x5cu8; 64];
let mut ipad = [0x36u8; 64];

key.iter().enumerate().for_each(|(i, k)| {
opad[i] ^= k;
ipad[i] ^= k;
});

let outer_state = sha256_compress(SHA256_INITIAL_STATE, opad);
let inner_state = sha256_compress(SHA256_INITIAL_STATE, ipad);

(outer_state, inner_state)
}

/// HMAC-SHA256 finalization function.
///
/// Returns the HMAC-SHA256 digest of the provided message using existing outer
/// and inner states.
///
/// # Arguments
///
/// * `outer_state` - 256-bit outer state.
/// * `inner_state` - 256-bit inner state.
/// * `msg` - N-byte message.
pub fn hmac_sha256_finalize_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
msg: &[Tracer<'a, U8>],
) -> [Tracer<'a, U8>; 32] {
sha256_trace(
builder_state,
outer_state,
64,
&sha256_trace(builder_state, inner_state, 64, msg),
)
}

/// Reference implementation of the HMAC-SHA256 finalization function.
///
/// Returns the HMAC-SHA256 digest of the provided message using existing outer
/// and inner states.
///
/// # Arguments
///
/// * `outer_state` - 256-bit outer state.
/// * `inner_state` - 256-bit inner state.
/// * `msg` - N-byte message.
pub fn hmac_sha256_finalize(outer_state: [u32; 8], inner_state: [u32; 8], msg: &[u8]) -> [u8; 32] {
sha256(outer_state, 64, &sha256(inner_state, 64, msg))
}

#[cfg(test)]
mod tests {
use mpz_circuits::{test_circ, CircuitBuilder};

use super::*;

#[test]
fn test_hmac_sha256_partial() {
let builder = CircuitBuilder::new();
let key = builder.add_array_input::<u8, 48>();
let (outer_state, inner_state) = hmac_sha256_partial_trace(builder.state(), &key);
builder.add_output(outer_state);
builder.add_output(inner_state);
let circ = builder.build().unwrap();

let key = [69u8; 48];

test_circ!(circ, hmac_sha256_partial, fn(&key) -> ([u32; 8], [u32; 8]));
}

#[test]
fn test_hmac_sha256_finalize() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let msg = builder.add_array_input::<u8, 47>();
let hash = hmac_sha256_finalize_trace(builder.state(), outer_state, inner_state, &msg);
builder.add_output(hash);
let circ = builder.build().unwrap();

let key = [69u8; 32];
let (outer_state, inner_state) = hmac_sha256_partial(&key);
let msg = [42u8; 47];

test_circ!(
circ,
hmac_sha256_finalize,
fn(outer_state, inner_state, &msg) -> [u8; 32]
);
}
}
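Aside: the partial/finalize split above composes to ordinary HMAC-SHA256, which is easy to cross-check against the off-the-shelf `hmac` and `sha2` crates. A hedged sketch (assumes those crates as dev-dependencies and the `hmac_sha256_partial`/`hmac_sha256_finalize` reference functions above in scope; none of this is part of the diff):

use hmac::{Hmac, Mac};
use sha2::Sha256;

type HmacSha256 = Hmac<Sha256>;

fn main() {
    let key = [69u8; 48];
    let msg = [42u8; 47];

    // Standard one-shot HMAC-SHA256.
    let mut mac = HmacSha256::new_from_slice(&key).expect("HMAC accepts any key length");
    mac.update(&msg);
    let expected = mac.finalize().into_bytes();

    // Split form used by the circuits: precompute the padded-key states,
    // then finalize with the message.
    let (outer, inner) = hmac_sha256_partial(&key);
    let actual = hmac_sha256_finalize(outer, inner, &msg);

    assert_eq!(actual[..], expected[..]);
}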
@@ -1,61 +0,0 @@
//! HMAC-SHA256 circuits.

#![deny(missing_docs, unreachable_pub, unused_must_use)]
#![deny(clippy::all)]
#![forbid(unsafe_code)]

mod hmac_sha256;
mod prf;
mod session_keys;
mod verify_data;

pub use hmac_sha256::{
hmac_sha256_finalize, hmac_sha256_finalize_trace, hmac_sha256_partial,
hmac_sha256_partial_trace,
};

pub use prf::{prf, prf_trace};
pub use session_keys::{session_keys, session_keys_trace};
pub use verify_data::{verify_data, verify_data_trace};

use mpz_circuits::{Circuit, CircuitBuilder, Tracer};
use std::sync::Arc;

/// Builds session key derivation circuit.
#[tracing::instrument(level = "trace")]
pub fn build_session_keys() -> Arc<Circuit> {
let builder = CircuitBuilder::new();
let pms = builder.add_array_input::<u8, 32>();
let client_random = builder.add_array_input::<u8, 32>();
let server_random = builder.add_array_input::<u8, 32>();
let (cwk, swk, civ, siv, outer_state, inner_state) =
session_keys_trace(builder.state(), pms, client_random, server_random);
builder.add_output(cwk);
builder.add_output(swk);
builder.add_output(civ);
builder.add_output(siv);
builder.add_output(outer_state);
builder.add_output(inner_state);
Arc::new(builder.build().expect("session keys should build"))
}

/// Builds a verify data circuit.
#[tracing::instrument(level = "trace")]
pub fn build_verify_data(label: &[u8]) -> Arc<Circuit> {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let handshake_hash = builder.add_array_input::<u8, 32>();
let vd = verify_data_trace(
builder.state(),
outer_state,
inner_state,
&label
.iter()
.map(|v| Tracer::new(builder.state(), builder.get_constant(*v).to_inner()))
.collect::<Vec<_>>(),
handshake_hash,
);
builder.add_output(vd);
Arc::new(builder.build().expect("verify data should build"))
}
@@ -1,227 +0,0 @@
//! This module provides an implementation of the HMAC-SHA256 PRF defined in [RFC 5246](https://www.rfc-editor.org/rfc/rfc5246#section-5).

use std::cell::RefCell;

use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};

use crate::hmac_sha256::{hmac_sha256_finalize, hmac_sha256_finalize_trace};

fn p_hash_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
seed: &[Tracer<'a, U8>],
iterations: usize,
) -> Vec<Tracer<'a, U8>> {
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
a_cache.push(seed.to_vec());

for i in 0..iterations {
let a_i = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_cache[i]);
a_cache.push(a_i.to_vec());
}

// HMAC_hash(secret, A(i) + seed)
let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
for i in 0..iterations {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);

let hash = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_i_seed);
output.extend_from_slice(&hash);
}

output
}

fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: usize) -> Vec<u8> {
// A() is defined as:
//
// A(0) = seed
// A(i) = HMAC_hash(secret, A(i-1))
let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
a_cache.push(seed.to_vec());

for i in 0..iterations {
let a_i = hmac_sha256_finalize(outer_state, inner_state, &a_cache[i]);
a_cache.push(a_i.to_vec());
}

// HMAC_hash(secret, A(i) + seed)
let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
for i in 0..iterations {
let mut a_i_seed = a_cache[i + 1].clone();
a_i_seed.extend_from_slice(seed);

let hash = hmac_sha256_finalize(outer_state, inner_state, &a_i_seed);
output.extend_from_slice(&hash);
}

output
}

/// Computes PRF(secret, label, seed).
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `outer_state` - The outer state of HMAC-SHA256.
/// * `inner_state` - The inner state of HMAC-SHA256.
/// * `seed` - The seed to use.
/// * `label` - The label to use.
/// * `bytes` - The number of bytes to output.
pub fn prf_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
seed: &[Tracer<'a, U8>],
label: &[Tracer<'a, U8>],
bytes: usize,
) -> Vec<Tracer<'a, U8>> {
let iterations = bytes / 32 + (bytes % 32 != 0) as usize;
let mut label_seed = label.to_vec();
label_seed.extend_from_slice(seed);

let mut output = p_hash_trace(
builder_state,
outer_state,
inner_state,
&label_seed,
iterations,
);
output.truncate(bytes);

output
}

/// Reference implementation of PRF(secret, label, seed).
///
/// # Arguments
///
/// * `outer_state` - The outer state of HMAC-SHA256.
/// * `inner_state` - The inner state of HMAC-SHA256.
/// * `seed` - The seed to use.
/// * `label` - The label to use.
/// * `bytes` - The number of bytes to output.
pub fn prf(
outer_state: [u32; 8],
inner_state: [u32; 8],
seed: &[u8],
label: &[u8],
bytes: usize,
) -> Vec<u8> {
let iterations = bytes / 32 + (bytes % 32 != 0) as usize;
let mut label_seed = label.to_vec();
label_seed.extend_from_slice(seed);

let mut output = p_hash(outer_state, inner_state, &label_seed, iterations);
output.truncate(bytes);

output
}

#[cfg(test)]
mod tests {
use mpz_circuits::{evaluate, CircuitBuilder};

use crate::hmac_sha256::hmac_sha256_partial;

use super::*;

#[test]
fn test_p_hash() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let seed = builder.add_array_input::<u8, 64>();
let output = p_hash_trace(builder.state(), outer_state, inner_state, &seed, 2);
builder.add_output(output);
let circ = builder.build().unwrap();

let outer_state = [0u32; 8];
let inner_state = [1u32; 8];
let seed = [42u8; 64];

let expected = p_hash(outer_state, inner_state, &seed, 2);
let actual = evaluate!(circ, fn(outer_state, inner_state, &seed) -> Vec<u8>).unwrap();

assert_eq!(actual, expected);
}

#[test]
fn test_prf() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let seed = builder.add_array_input::<u8, 64>();
let label = builder.add_array_input::<u8, 13>();
let output = prf_trace(builder.state(), outer_state, inner_state, &seed, &label, 48);
builder.add_output(output);
let circ = builder.build().unwrap();

let master_secret = [0u8; 48];
let seed = [43u8; 64];
let label = b"master secret";

let (outer_state, inner_state) = hmac_sha256_partial(&master_secret);

let expected = prf(outer_state, inner_state, &seed, label, 48);
let actual =
evaluate!(circ, fn(outer_state, inner_state, &seed, label) -> Vec<u8>).unwrap();

assert_eq!(actual, expected);

let mut expected_ring = [0u8; 48];
ring_prf::prf(&mut expected_ring, &master_secret, label, &seed);

assert_eq!(actual, expected_ring);
}

// Borrowed from Rustls for testing
// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs
mod ring_prf {
use ring::{hmac, hmac::HMAC_SHA256};

fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag {
let mut ctx = hmac::Context::with_key(key);
ctx.update(a);
ctx.update(b);
ctx.sign()
}

fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) {
let hmac_key = hmac::Key::new(HMAC_SHA256, secret);

// A(1)
let mut current_a = hmac::sign(&hmac_key, seed);
let chunk_size = HMAC_SHA256.digest_algorithm().output_len();
for chunk in out.chunks_mut(chunk_size) {
// P_hash[i] = HMAC_hash(secret, A(i) + seed)
let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed);
chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]);

// A(i+1) = HMAC_hash(secret, A(i))
current_a = hmac::sign(&hmac_key, current_a.as_ref());
}
}

fn concat(a: &[u8], b: &[u8]) -> Vec<u8> {
let mut ret = Vec::new();
ret.extend_from_slice(a);
ret.extend_from_slice(b);
ret
}

pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) {
let joined_seed = concat(label, seed);
p(out, secret, &joined_seed);
}
}
}
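Aside: the P_hash expansion that both the circuit and the reference above implement is short enough to restate outside the circuit world. A standalone sketch using the `hmac` and `sha2` crates (assumed dependencies, not part of this diff; all names are illustrative):

use hmac::{Hmac, Mac};
use sha2::Sha256;

type HmacSha256 = Hmac<Sha256>;

fn hmac(secret: &[u8], msg: &[u8]) -> Vec<u8> {
    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC accepts any key length");
    mac.update(msg);
    mac.finalize().into_bytes().to_vec()
}

/// PRF(secret, label, seed) from RFC 5246: P_hash over label || seed.
fn tls12_prf(secret: &[u8], label: &[u8], seed: &[u8], bytes: usize) -> Vec<u8> {
    let label_seed = [label, seed].concat();
    let mut a = hmac(secret, &label_seed); // A(1)
    let mut out = Vec::with_capacity(bytes);
    while out.len() < bytes {
        // P_hash[i] = HMAC(secret, A(i) || label_seed)
        out.extend_from_slice(&hmac(secret, &[&a[..], &label_seed[..]].concat()));
        a = hmac(secret, &a); // A(i+1)
    }
    out.truncate(bytes);
    out
}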
@@ -1,200 +0,0 @@
use std::cell::RefCell;

use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};

use crate::{
hmac_sha256::{hmac_sha256_partial, hmac_sha256_partial_trace},
prf::{prf, prf_trace},
};

/// Session Keys.
///
/// Computes expanded p1 which consists of client_write_key + server_write_key.
/// Computes expanded p2 which consists of client_IV + server_IV.
///
/// # Arguments
///
/// * `builder_state` - Reference to builder state.
/// * `pms` - 32-byte premaster secret.
/// * `client_random` - 32-byte client random.
/// * `server_random` - 32-byte server random.
///
/// # Returns
///
/// * `client_write_key` - 16-byte client write key.
/// * `server_write_key` - 16-byte server write key.
/// * `client_IV` - 4-byte client IV.
/// * `server_IV` - 4-byte server IV.
/// * `outer_hash_state` - 256-bit master-secret outer HMAC state.
/// * `inner_hash_state` - 256-bit master-secret inner HMAC state.
#[allow(clippy::type_complexity)]
pub fn session_keys_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
pms: [Tracer<'a, U8>; 32],
client_random: [Tracer<'a, U8>; 32],
server_random: [Tracer<'a, U8>; 32],
) -> (
[Tracer<'a, U8>; 16],
[Tracer<'a, U8>; 16],
[Tracer<'a, U8>; 4],
[Tracer<'a, U8>; 4],
[Tracer<'a, U32>; 8],
[Tracer<'a, U32>; 8],
) {
let (pms_outer_state, pms_inner_state) = hmac_sha256_partial_trace(builder_state, &pms);

let master_secret = {
let seed = client_random
.iter()
.chain(&server_random)
.copied()
.collect::<Vec<_>>();

let label = b"master secret"
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));

prf_trace(
builder_state,
pms_outer_state,
pms_inner_state,
&seed,
&label,
48,
)
};

let (master_secret_outer_state, master_secret_inner_state) =
hmac_sha256_partial_trace(builder_state, &master_secret);

let key_material = {
let seed = server_random
.iter()
.chain(&client_random)
.copied()
.collect::<Vec<_>>();

let label = b"key expansion"
.map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v)));

prf_trace(
builder_state,
master_secret_outer_state,
master_secret_inner_state,
&seed,
&label,
40,
)
};

let cwk = key_material[0..16].try_into().unwrap();
let swk = key_material[16..32].try_into().unwrap();
let civ = key_material[32..36].try_into().unwrap();
let siv = key_material[36..40].try_into().unwrap();

(
cwk,
swk,
civ,
siv,
master_secret_outer_state,
master_secret_inner_state,
)
}

/// Reference implementation of session keys derivation.
pub fn session_keys(
pms: [u8; 32],
client_random: [u8; 32],
server_random: [u8; 32],
) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) {
let (pms_outer_state, pms_inner_state) = hmac_sha256_partial(&pms);

let master_secret = {
let seed = client_random
.iter()
.chain(&server_random)
.copied()
.collect::<Vec<_>>();

let label = b"master secret";

prf(pms_outer_state, pms_inner_state, &seed, label, 48)
};

let (master_secret_outer_state, master_secret_inner_state) =
hmac_sha256_partial(&master_secret);

let key_material = {
let seed = server_random
.iter()
.chain(&client_random)
.copied()
.collect::<Vec<_>>();

let label = b"key expansion";

prf(
master_secret_outer_state,
master_secret_inner_state,
&seed,
label,
40,
)
};

let cwk = key_material[0..16].try_into().unwrap();
let swk = key_material[16..32].try_into().unwrap();
let civ = key_material[32..36].try_into().unwrap();
let siv = key_material[36..40].try_into().unwrap();

(cwk, swk, civ, siv)
}

#[cfg(test)]
mod tests {
use mpz_circuits::{evaluate, CircuitBuilder};

use super::*;

#[test]
fn test_session_keys() {
let builder = CircuitBuilder::new();
let pms = builder.add_array_input::<u8, 32>();
let client_random = builder.add_array_input::<u8, 32>();
let server_random = builder.add_array_input::<u8, 32>();
let (cwk, swk, civ, siv, outer_state, inner_state) =
session_keys_trace(builder.state(), pms, client_random, server_random);
builder.add_output(cwk);
builder.add_output(swk);
builder.add_output(civ);
builder.add_output(siv);
builder.add_output(outer_state);
builder.add_output(inner_state);
let circ = builder.build().unwrap();

let pms = [0u8; 32];
let client_random = [42u8; 32];
let server_random = [69u8; 32];

let (expected_cwk, expected_swk, expected_civ, expected_siv) =
session_keys(pms, client_random, server_random);

let (cwk, swk, civ, siv, _, _) = evaluate!(
circ,
fn(
pms,
client_random,
server_random,
) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4], [u32; 8], [u32; 8])
)
.unwrap();

assert_eq!(cwk, expected_cwk);
assert_eq!(swk, expected_swk);
assert_eq!(civ, expected_civ);
assert_eq!(siv, expected_siv);
}
}
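Aside: the 40-byte key block split above (two 16-byte write keys followed by two 4-byte implicit IVs, as used with AES-128-GCM in TLS 1.2) can be restated with the `tls12_prf` sketch from the previous aside. Note the seed order flips for key expansion. All names here are illustrative:

fn derive_keys(
    master_secret: &[u8; 48],
    client_random: [u8; 32],
    server_random: [u8; 32],
) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) {
    // Key expansion seeds with server_random || client_random.
    let seed = [server_random.as_slice(), client_random.as_slice()].concat();
    let block = tls12_prf(master_secret, b"key expansion", &seed, 40);
    (
        block[0..16].try_into().unwrap(),  // client_write_key
        block[16..32].try_into().unwrap(), // server_write_key
        block[32..36].try_into().unwrap(), // client_IV
        block[36..40].try_into().unwrap(), // server_IV
    )
}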
@@ -1,88 +0,0 @@
use std::cell::RefCell;

use mpz_circuits::{
types::{U32, U8},
BuilderState, Tracer,
};

use crate::prf::{prf, prf_trace};

/// Computes verify_data as specified in RFC 5246, Section 7.4.9.
///
/// verify_data = PRF(master_secret, finished_label,
///     Hash(handshake_messages))[0..verify_data_length-1];
///
/// # Arguments
///
/// * `builder_state` - The builder state.
/// * `outer_state` - The outer HMAC state of the master secret.
/// * `inner_state` - The inner HMAC state of the master secret.
/// * `label` - The label to use.
/// * `hs_hash` - The handshake hash.
pub fn verify_data_trace<'a>(
builder_state: &'a RefCell<BuilderState>,
outer_state: [Tracer<'a, U32>; 8],
inner_state: [Tracer<'a, U32>; 8],
label: &[Tracer<'a, U8>],
hs_hash: [Tracer<'a, U8>; 32],
) -> [Tracer<'a, U8>; 12] {
let vd = prf_trace(builder_state, outer_state, inner_state, &hs_hash, label, 12);

vd.try_into().expect("vd is 12 bytes")
}

/// Reference implementation of verify_data as specified in RFC 5246, Section
/// 7.4.9.
///
/// # Arguments
///
/// * `outer_state` - The outer HMAC state of the master secret.
/// * `inner_state` - The inner HMAC state of the master secret.
/// * `label` - The label to use.
/// * `hs_hash` - The handshake hash.
pub fn verify_data(
outer_state: [u32; 8],
inner_state: [u32; 8],
label: &[u8],
hs_hash: [u8; 32],
) -> [u8; 12] {
let vd = prf(outer_state, inner_state, &hs_hash, label, 12);

vd.try_into().expect("vd is 12 bytes")
}

#[cfg(test)]
mod tests {
use super::*;

use mpz_circuits::{evaluate, CircuitBuilder};

const CF_LABEL: &[u8; 15] = b"client finished";

#[test]
fn test_verify_data() {
let builder = CircuitBuilder::new();
let outer_state = builder.add_array_input::<u32, 8>();
let inner_state = builder.add_array_input::<u32, 8>();
let label = builder.add_array_input::<u8, 15>();
let hs_hash = builder.add_array_input::<u8, 32>();
let vd = verify_data_trace(builder.state(), outer_state, inner_state, &label, hs_hash);
builder.add_output(vd);
let circ = builder.build().unwrap();

let outer_state = [0u32; 8];
let inner_state = [1u32; 8];
let hs_hash = [42u8; 32];

let expected = prf(outer_state, inner_state, &hs_hash, CF_LABEL, 12);

let actual = evaluate!(
circ,
fn(outer_state, inner_state, CF_LABEL, hs_hash) -> [u8; 12]
)
.unwrap();

assert_eq!(actual.to_vec(), expected);
}
}
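Aside: in terms of the `tls12_prf` sketch from the earlier aside, the Finished message check is just a 12-byte truncation of the PRF output (labels per RFC 5246; names illustrative):

fn verify_data(
    master_secret: &[u8; 48],
    finished_label: &[u8], // b"client finished" or b"server finished"
    handshake_hash: &[u8; 32],
) -> [u8; 12] {
    tls12_prf(master_secret, finished_label, handshake_hash, 12)
        .try_into()
        .expect("exactly 12 bytes requested")
}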
@@ -5,28 +5,24 @@ description = "A 2PC implementation of TLS HMAC-SHA256 PRF"
keywords = ["tls", "mpc", "2pc", "hmac", "sha256"]
categories = ["cryptography"]
license = "MIT OR Apache-2.0"
version = "0.1.0-alpha.9"
version = "0.1.0-alpha.12"
edition = "2021"

[lints]
workspace = true

[lib]
name = "hmac_sha256"

[features]
default = ["mock"]
rayon = ["mpz-common/rayon"]
mock = []

[dependencies]
tlsn-hmac-sha256-circuits = { workspace = true }

mpz-vm-core = { workspace = true }
mpz-core = { workspace = true }
mpz-circuits = { workspace = true }
mpz-common = { workspace = true, features = ["cpu"] }
mpz-hash = { workspace = true }

derive_builder = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
futures = { workspace = true }
sha2 = { workspace = true }

[dev-dependencies]
mpz-ot = { workspace = true, features = ["ideal"] }
@@ -36,7 +32,8 @@ mpz-common = { workspace = true, features = ["test-utils"] }
criterion = { workspace = true, features = ["async_tokio"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
rand = { workspace = true }
rand06-compat = { workspace = true }
hex = { workspace = true }
ring = { workspace = true }

[[bench]]
name = "prf"

@@ -2,16 +2,16 @@

use criterion::{criterion_group, criterion_main, Criterion};

use hmac_sha256::{MpcPrf, PrfConfig, Role};
use hmac_sha256::{Mode, MpcPrf};
use mpz_common::context::test_mt_context;
use mpz_garble::protocol::semihonest::{Evaluator, Generator};
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
use mpz_ot::ideal::cot::ideal_cot;
use mpz_vm_core::{
memory::{binary::U8, correlated::Delta, Array},
prelude::*,
Execute,
};
use rand::{rngs::StdRng, SeedableRng};
use rand06_compat::Rand0_6CompatExt;

#[allow(clippy::unit_arg)]
fn criterion_benchmark(c: &mut Criterion) {
@@ -19,13 +19,16 @@ fn criterion_benchmark(c: &mut Criterion) {
group.sample_size(10);
let rt = tokio::runtime::Runtime::new().unwrap();

group.bench_function("prf", |b| b.to_async(&rt).iter(prf));
group.bench_function("prf_normal", |b| b.to_async(&rt).iter(|| prf(Mode::Normal)));
group.bench_function("prf_reduced", |b| {
b.to_async(&rt).iter(|| prf(Mode::Reduced))
});
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

async fn prf() {
async fn prf(mode: Mode) {
let mut rng = StdRng::seed_from_u64(0);

let pms = [42u8; 32];
@@ -36,10 +39,10 @@ async fn prf() {
let mut leader_ctx = leader_exec.new_context().await.unwrap();
let mut follower_ctx = follower_exec.new_context().await.unwrap();

let delta = Delta::random(&mut rng.compat_by_ref());
let delta = Delta::random(&mut rng);
let (ot_send, ot_recv) = ideal_cot(delta.into_inner());

let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta);
let mut follower_vm = Evaluator::new(ot_recv);

let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
@@ -52,23 +55,17 @@ async fn prf() {
follower_vm.assign(follower_pms, pms).unwrap();
follower_vm.commit(follower_pms).unwrap();

let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
let mut leader = MpcPrf::new(mode);
let mut follower = MpcPrf::new(mode);

let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();

leader
.set_client_random(&mut leader_vm, Some(client_random))
.unwrap();
follower.set_client_random(&mut follower_vm, None).unwrap();
leader.set_client_random(client_random).unwrap();
follower.set_client_random(client_random).unwrap();

leader
.set_server_random(&mut leader_vm, server_random)
.unwrap();
follower
.set_server_random(&mut follower_vm, server_random)
.unwrap();
leader.set_server_random(server_random).unwrap();
follower.set_server_random(server_random).unwrap();

let _ = leader_vm
.decode(leader_output.keys.client_write_key)
@@ -88,44 +85,61 @@ async fn prf() {
let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap();

futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}

let cf_hs_hash = [1u8; 32];
let sf_hs_hash = [2u8; 32];

leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();
leader.set_cf_hash(cf_hs_hash).unwrap();
follower.set_cf_hash(cf_hs_hash).unwrap();

follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();
while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}

let _ = leader_vm.decode(leader_output.cf_vd).unwrap();
let _ = leader_vm.decode(leader_output.sf_vd).unwrap();

let _ = follower_vm.decode(follower_output.cf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();

futures::join!(
async {
leader_vm.flush(&mut leader_ctx).await.unwrap();
leader_vm.execute(&mut leader_ctx).await.unwrap();
leader_vm.flush(&mut leader_ctx).await.unwrap();
},
async {
follower_vm.flush(&mut follower_ctx).await.unwrap();
follower_vm.execute(&mut follower_ctx).await.unwrap();
follower_vm.flush(&mut follower_ctx).await.unwrap();
}
);
let sf_hs_hash = [2u8; 32];

leader.set_sf_hash(sf_hs_hash).unwrap();
follower.set_sf_hash(sf_hs_hash).unwrap();

while leader.wants_flush() || follower.wants_flush() {
tokio::try_join!(
async {
leader.flush(&mut leader_vm).unwrap();
leader_vm.execute_all(&mut leader_ctx).await
},
async {
follower.flush(&mut follower_vm).unwrap();
follower_vm.execute_all(&mut follower_ctx).await
}
)
.unwrap();
}

let _ = leader_vm.decode(leader_output.sf_vd).unwrap();
let _ = follower_vm.decode(follower_output.sf_vd).unwrap();
}

@@ -1,24 +1,10 @@
use derive_builder::Builder;
//! PRF modes.

/// Role of this party in the PRF.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// The leader provides the private inputs to the PRF.
Leader,
/// The follower is blind to the inputs to the PRF.
Follower,
}

/// Configuration for the PRF.
#[derive(Debug, Builder)]
pub struct PrfConfig {
/// The role of this party in the PRF.
pub(crate) role: Role,
}

impl PrfConfig {
/// Creates a new builder.
pub fn builder() -> PrfConfigBuilder {
PrfConfigBuilder::default()
}
/// Modes for the PRF.
#[derive(Debug, Clone, Copy)]
pub enum Mode {
/// Computes some hashes locally.
Reduced,
/// Computes the whole PRF in MPC.
Normal,
}

@@ -1,6 +1,8 @@
use core::fmt;
use std::error::Error;

use mpz_hash::sha256::Sha256Error;

/// A PRF error.
#[derive(Debug, thiserror::Error)]
pub struct PrfError {
@@ -20,22 +22,21 @@ impl PrfError {
}
}

pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
}

pub(crate) fn state(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::State,
source: Some(msg.into().into()),
}
}
}

pub(crate) fn role(msg: impl Into<String>) -> Self {
Self {
kind: ErrorKind::Role,
source: Some(msg.into().into()),
}
}

pub(crate) fn vm<E: Into<Box<dyn Error + Send + Sync>>>(err: E) -> Self {
Self::new(ErrorKind::Vm, err)
impl From<Sha256Error> for PrfError {
fn from(value: Sha256Error) -> Self {
Self::new(ErrorKind::Hash, value)
}
}

@@ -43,7 +44,7 @@ impl PrfError {
pub(crate) enum ErrorKind {
Vm,
State,
Role,
Hash,
}

impl fmt::Display for PrfError {
@@ -51,7 +52,7 @@ impl fmt::Display for PrfError {
match self.kind {
ErrorKind::Vm => write!(f, "vm error")?,
ErrorKind::State => write!(f, "state error")?,
ErrorKind::Role => write!(f, "role error")?,
ErrorKind::Hash => write!(f, "hash error")?,
}

if let Some(ref source) = self.source {
@@ -61,9 +62,3 @@ impl fmt::Display for PrfError {
Ok(())
}
}

impl From<mpz_common::ContextError> for PrfError {
fn from(error: mpz_common::ContextError) -> Self {
Self::new(ErrorKind::Vm, error)
}
}

crates/components/hmac-sha256/src/hmac.rs (new file, 177 lines)
@@ -0,0 +1,177 @@
//! Computation of HMAC-SHA256.
//!
//! HMAC-SHA256 is defined as
//!
//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m))
//!
//! * H - SHA256 hash function
//! * key' - key padded with zero bytes to 64 bytes (we do not support longer
//!   keys)
//! * opad - 64 bytes of 0x5c
//! * ipad - 64 bytes of 0x36
//! * m - message
//!
//! This implementation computes HMAC-SHA256 using intermediate results
//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial ||
//! inner_local)
//!
//! * `outer_partial` - key' xor opad
//! * `inner_local` - H((key' xor ipad) || m)

use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{Binary, U8},
Array,
},
Vm,
};

use crate::PrfError;

pub(crate) const IPAD: [u8; 64] = [0x36; 64];
pub(crate) const OPAD: [u8; 64] = [0x5c; 64];

/// Computes HMAC-SHA256
///
/// # Arguments
///
/// * `vm` - The virtual machine.
/// * `outer_partial` - (key' xor opad)
/// * `inner_local` - H((key' xor ipad) || m)
pub(crate) fn hmac_sha256(
vm: &mut dyn Vm<Binary>,
mut outer_partial: Sha256,
inner_local: Array<U8, 32>,
) -> Result<Array<U8, 32>, PrfError> {
outer_partial.update(&inner_local.into());
outer_partial.compress(vm)?;
outer_partial.finalize(vm).map_err(PrfError::from)
}

#[cfg(test)]
mod tests {
use crate::{
hmac::hmac_sha256,
sha256, state_to_bytes,
test_utils::{compute_inner_local, compute_outer_partial, mock_vm},
};
use mpz_common::context::test_st_context;
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
memory::{
binary::{U32, U8},
Array, MemoryExt, ViewExt,
},
Execute,
};

#[test]
fn test_hmac_reference() {
let (inputs, references) = test_fixtures();

for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);

let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));

assert_eq!(state_to_bytes(hmac), reference);
}
}

#[tokio::test]
async fn test_hmac_circuit() {
let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (mut leader, mut follower) = mock_vm();

let (inputs, references) = test_fixtures();
for (input, &reference) in inputs.iter().zip(references.iter()) {
let outer_partial = compute_outer_partial(input.0.clone());
let inner_local = compute_inner_local(input.0.clone(), &input.1);

let outer_partial_leader: Array<U32, 8> = leader.alloc().unwrap();
leader.mark_public(outer_partial_leader).unwrap();
leader.assign(outer_partial_leader, outer_partial).unwrap();
leader.commit(outer_partial_leader).unwrap();

let inner_local_leader: Array<U8, 32> = leader.alloc().unwrap();
leader.mark_public(inner_local_leader).unwrap();
leader
.assign(inner_local_leader, state_to_bytes(inner_local))
.unwrap();
leader.commit(inner_local_leader).unwrap();

let hmac_leader = hmac_sha256(
&mut leader,
Sha256::new_from_state(outer_partial_leader, 1),
inner_local_leader,
)
.unwrap();
let hmac_leader = leader.decode(hmac_leader).unwrap();

let outer_partial_follower: Array<U32, 8> = follower.alloc().unwrap();
follower.mark_public(outer_partial_follower).unwrap();
follower
.assign(outer_partial_follower, outer_partial)
.unwrap();
follower.commit(outer_partial_follower).unwrap();

let inner_local_follower: Array<U8, 32> = follower.alloc().unwrap();
follower.mark_public(inner_local_follower).unwrap();
follower
.assign(inner_local_follower, state_to_bytes(inner_local))
.unwrap();
follower.commit(inner_local_follower).unwrap();

let hmac_follower = hmac_sha256(
&mut follower,
Sha256::new_from_state(outer_partial_follower, 1),
inner_local_follower,
)
.unwrap();
let hmac_follower = follower.decode(hmac_follower).unwrap();

let (hmac_leader, hmac_follower) = tokio::try_join!(
async {
leader.execute_all(&mut ctx_a).await.unwrap();
hmac_leader.await
},
async {
follower.execute_all(&mut ctx_b).await.unwrap();
hmac_follower.await
}
)
.unwrap();

assert_eq!(hmac_leader, hmac_follower);
assert_eq!(hmac_leader, reference);
}
}

#[allow(clippy::type_complexity)]
fn test_fixtures() -> (Vec<(Vec<u8>, Vec<u8>)>, Vec<[u8; 32]>) {
let test_vectors: Vec<(Vec<u8>, Vec<u8>)> = vec![
(
hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(),
hex::decode("4869205468657265").unwrap(),
),
(
hex::decode("4a656665").unwrap(),
hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(),
),
];
let expected: Vec<[u8; 32]> = vec![
hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7")
.unwrap()
.try_into()
.unwrap(),
hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843")
.unwrap()
.try_into()
.unwrap(),
];

(test_vectors, expected)
}
}
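Aside: the `outer_partial`/`inner_local` decomposition described in the module doc of this new file can be restated outside the VM. A hedged sketch that reuses the `sha256` state-resumption helper and `state_to_bytes` introduced in `lib.rs` just below, plus `sha2`'s exported `compress256` (the `compress_block` helper and `IV` constant here are illustrative, not part of the diff):

use sha2::{compress256, digest::generic_array::GenericArray};

const IV: [u32; 8] = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];

// One SHA-256 compression of a single 64-byte block, without length padding.
fn compress_block(mut state: [u32; 8], block: &[u8; 64]) -> [u32; 8] {
    compress256(&mut state, &[GenericArray::clone_from_slice(block)]);
    state
}

fn hmac_sha256_reference(key: &[u8], msg: &[u8]) -> [u8; 32] {
    assert!(key.len() <= 64, "longer keys are not supported here");
    let (mut opad, mut ipad) = ([0x5cu8; 64], [0x36u8; 64]);
    for (i, k) in key.iter().enumerate() {
        opad[i] ^= k;
        ipad[i] ^= k;
    }
    // outer_partial / inner_partial: one compression over each padded key block.
    let outer_partial = compress_block(IV, &opad);
    let inner_partial = compress_block(IV, &ipad);
    // inner_local = H((key' xor ipad) || m), resuming from inner_partial at byte offset 64.
    let inner_local = sha256(inner_partial, 64, msg);
    // HMAC(m) = H((key' xor opad) || inner_local), again resuming at offset 64.
    state_to_bytes(sha256(outer_partial, 64, &state_to_bytes(inner_local)))
}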
|
||||
@@ -1,30 +1,24 @@
|
||||
//! This module contains the protocol for computing TLS SHA-256 HMAC PRF.
|
||||
//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF.
|
||||
|
||||
#![deny(missing_docs, unreachable_pub, unused_must_use)]
|
||||
#![deny(clippy::all)]
|
||||
#![forbid(unsafe_code)]
|
||||
|
||||
mod config;
|
||||
mod error;
|
||||
mod prf;
|
||||
mod hmac;
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
|
||||
pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role};
|
||||
mod config;
|
||||
pub use config::Mode;
|
||||
|
||||
mod error;
|
||||
pub use error::PrfError;
|
||||
|
||||
mod prf;
|
||||
pub use prf::MpcPrf;
|
||||
|
||||
use mpz_vm_core::memory::{binary::U8, Array};
|
||||
|
||||
pub(crate) static CF_LABEL: &[u8] = b"client finished";
|
||||
pub(crate) static SF_LABEL: &[u8] = b"server finished";
|
||||
|
||||
/// Builds the circuits for the PRF.
|
||||
///
|
||||
/// This function can be used ahead of time to build the circuits for the PRF,
|
||||
/// which at the moment is CPU and memory intensive.
|
||||
pub async fn build_circuits() {
|
||||
prf::Circuits::get().await;
|
||||
}
|
||||
|
||||
/// PRF output.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct PrfOutput {
|
||||
@@ -49,176 +43,227 @@ pub struct SessionKeys {
|
||||
pub server_iv: Array<U8, 4>,
|
||||
}
|
||||
|
||||
fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] {
|
||||
use sha2::{
|
||||
compress256,
|
||||
digest::{
|
||||
block_buffer::{BlockBuffer, Eager},
|
||||
generic_array::typenum::U64,
|
||||
},
|
||||
};
|
||||
|
||||
let mut buffer = BlockBuffer::<U64, Eager>::default();
|
||||
buffer.digest_blocks(msg, |b| compress256(&mut state, b));
|
||||
buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| {
|
||||
compress256(&mut state, &[*b])
|
||||
});
|
||||
state
|
||||
}
|
||||
|
||||
fn state_to_bytes(input: [u32; 8]) -> [u8; 32] {
|
||||
let mut output = [0_u8; 32];
|
||||
for (k, byte_chunk) in input.iter().enumerate() {
|
||||
let byte_chunk = byte_chunk.to_be_bytes();
|
||||
output[4 * k..4 * (k + 1)].copy_from_slice(&byte_chunk);
|
||||
}
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::{
        test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd},
        Mode, MpcPrf, SessionKeys,
    };
    use mpz_common::context::test_st_context;
    use mpz_garble::protocol::semihonest::{Evaluator, Generator};
    use mpz_vm_core::{
        memory::{binary::U8, Array, MemoryExt, ViewExt},
        Execute,
    };
    use rand::{rngs::StdRng, Rng, SeedableRng};

    use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys};
    use mpz_ot::ideal::cot::ideal_cot;
    use mpz_vm_core::{memory::correlated::Delta, prelude::*};
    use rand::{rngs::StdRng, SeedableRng};
    use rand06_compat::Rand0_6CompatExt;

    use super::*;

    fn compute_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
        let (outer_state, inner_state) = hmac_sha256_partial(&pms);
        let seed = client_random
            .iter()
            .chain(&server_random)
            .copied()
            .collect::<Vec<_>>();
        let ms = prf(outer_state, inner_state, &seed, b"master secret", 48);
        ms.try_into().unwrap()
    }

    fn compute_vd(ms: [u8; 48], label: &[u8], hs_hash: [u8; 32]) -> [u8; 12] {
        let (outer_state, inner_state) = hmac_sha256_partial(&ms);
        let vd = prf(outer_state, inner_state, &hs_hash, label, 12);
        vd.try_into().unwrap()
    }

    #[ignore = "expensive"]
    #[tokio::test]
    async fn test_prf() {
        let mut rng = StdRng::seed_from_u64(0);
    async fn test_prf_reduced() {
        let mode = Mode::Reduced;
        test_prf(mode).await;
    }

        let pms = [42u8; 32];
        let client_random = [69u8; 32];
        let server_random: [u8; 32] = [96u8; 32];
        let ms = compute_ms(pms, client_random, server_random);
    #[tokio::test]
    async fn test_prf_normal() {
        let mode = Mode::Normal;
        test_prf(mode).await;
    }

        let (mut leader_ctx, mut follower_ctx) = test_st_context(128);
    async fn test_prf(mode: Mode) {
        let mut rng = StdRng::seed_from_u64(1);
        // Test input
        let pms: [u8; 32] = rng.random();
        let client_random: [u8; 32] = rng.random();
        let server_random: [u8; 32] = rng.random();

        let delta = Delta::random(&mut rng.compat_by_ref());
        let (ot_send, ot_recv) = ideal_cot(delta.into_inner());
        let cf_hs_hash: [u8; 32] = rng.random();
        let sf_hs_hash: [u8; 32] = rng.random();

        let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta);
        let mut follower_vm = Evaluator::new(ot_recv);
        // Expected output
        let ms_expected = prf_ms(pms, client_random, server_random);

        let leader_pms: Array<U8, 32> = leader_vm.alloc().unwrap();
        leader_vm.mark_public(leader_pms).unwrap();
        leader_vm.assign(leader_pms, pms).unwrap();
        leader_vm.commit(leader_pms).unwrap();
        let [cwk_expected, swk_expected, civ_expected, siv_expected] =
            prf_keys(ms_expected, client_random, server_random);

        let follower_pms: Array<U8, 32> = follower_vm.alloc().unwrap();
        follower_vm.mark_public(follower_pms).unwrap();
        follower_vm.assign(follower_pms, pms).unwrap();
        follower_vm.commit(follower_pms).unwrap();
        let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap();
        let swk_expected: [u8; 16] = swk_expected.try_into().unwrap();
        let civ_expected: [u8; 4] = civ_expected.try_into().unwrap();
        let siv_expected: [u8; 4] = siv_expected.try_into().unwrap();

        let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap());
        let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap());
        let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash);
        let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash);

        let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap();
        let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap();
        let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap();
        let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap();

        leader
            .set_client_random(&mut leader_vm, Some(client_random))
        // Set up vm and prf
        let (mut ctx_a, mut ctx_b) = test_st_context(128);
        let (mut leader, mut follower) = mock_vm();

        let leader_pms: Array<U8, 32> = leader.alloc().unwrap();
        leader.mark_public(leader_pms).unwrap();
        leader.assign(leader_pms, pms).unwrap();
        leader.commit(leader_pms).unwrap();

        let follower_pms: Array<U8, 32> = follower.alloc().unwrap();
        follower.mark_public(follower_pms).unwrap();
        follower.assign(follower_pms, pms).unwrap();
        follower.commit(follower_pms).unwrap();

        let mut prf_leader = MpcPrf::new(mode);
        let mut prf_follower = MpcPrf::new(mode);

        let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap();
        let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap();

        // client_random and server_random
        prf_leader.set_client_random(client_random).unwrap();
        prf_follower.set_client_random(client_random).unwrap();

        prf_leader.set_server_random(server_random).unwrap();
        prf_follower.set_server_random(server_random).unwrap();

        let SessionKeys {
            client_write_key: cwk_leader,
            server_write_key: swk_leader,
            client_iv: civ_leader,
            server_iv: siv_leader,
        } = leader_prf_out.keys;

        let mut cwk_leader = leader.decode(cwk_leader).unwrap();
        let mut swk_leader = leader.decode(swk_leader).unwrap();
        let mut civ_leader = leader.decode(civ_leader).unwrap();
        let mut siv_leader = leader.decode(siv_leader).unwrap();

        let SessionKeys {
            client_write_key: cwk_follower,
            server_write_key: swk_follower,
            client_iv: civ_follower,
            server_iv: siv_follower,
        } = follower_prf_out.keys;

        let mut cwk_follower = follower.decode(cwk_follower).unwrap();
        let mut swk_follower = follower.decode(swk_follower).unwrap();
        let mut civ_follower = follower.decode(civ_follower).unwrap();
        let mut siv_follower = follower.decode(siv_follower).unwrap();

        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        follower.set_client_random(&mut follower_vm, None).unwrap();
        }

        leader
            .set_server_random(&mut leader_vm, server_random)
        let cwk_leader = cwk_leader.try_recv().unwrap().unwrap();
        let swk_leader = swk_leader.try_recv().unwrap().unwrap();
        let civ_leader = civ_leader.try_recv().unwrap().unwrap();
        let siv_leader = siv_leader.try_recv().unwrap().unwrap();

        let cwk_follower = cwk_follower.try_recv().unwrap().unwrap();
        let swk_follower = swk_follower.try_recv().unwrap().unwrap();
        let civ_follower = civ_follower.try_recv().unwrap().unwrap();
        let siv_follower = siv_follower.try_recv().unwrap().unwrap();

        assert_eq!(cwk_leader, cwk_follower);
        assert_eq!(swk_leader, swk_follower);
        assert_eq!(civ_leader, civ_follower);
        assert_eq!(siv_leader, siv_follower);

        assert_eq!(cwk_leader, cwk_expected);
        assert_eq!(swk_leader, swk_expected);
        assert_eq!(civ_leader, civ_expected);
        assert_eq!(siv_leader, siv_expected);

        // client finished
        prf_leader.set_cf_hash(cf_hs_hash).unwrap();
        prf_follower.set_cf_hash(cf_hs_hash).unwrap();

        let cf_vd_leader = leader_prf_out.cf_vd;
        let cf_vd_follower = follower_prf_out.cf_vd;

        let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap();
        let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap();

        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        follower
            .set_server_random(&mut follower_vm, server_random)
        }

        let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap();
        let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap();

        assert_eq!(cf_vd_leader, cf_vd_follower);
        assert_eq!(cf_vd_leader, cf_vd_expected);

        // server finished
        prf_leader.set_sf_hash(sf_hs_hash).unwrap();
        prf_follower.set_sf_hash(sf_hs_hash).unwrap();

        let sf_vd_leader = leader_prf_out.sf_vd;
        let sf_vd_follower = follower_prf_out.sf_vd;

        let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap();
        let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap();

        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        }

        let leader_cwk = leader_vm
            .decode(leader_output.keys.client_write_key)
            .unwrap();
        let leader_swk = leader_vm
            .decode(leader_output.keys.server_write_key)
            .unwrap();
        let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap();
        let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap();
        let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap();
        let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap();

        let follower_cwk = follower_vm
            .decode(follower_output.keys.client_write_key)
            .unwrap();
        let follower_swk = follower_vm
            .decode(follower_output.keys.server_write_key)
            .unwrap();
        let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap();
        let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap();

        futures::join!(
            async {
                leader_vm.flush(&mut leader_ctx).await.unwrap();
                leader_vm.execute(&mut leader_ctx).await.unwrap();
                leader_vm.flush(&mut leader_ctx).await.unwrap();
            },
            async {
                follower_vm.flush(&mut follower_ctx).await.unwrap();
                follower_vm.execute(&mut follower_ctx).await.unwrap();
                follower_vm.flush(&mut follower_ctx).await.unwrap();
            }
        );

        let leader_cwk = leader_cwk.await.unwrap();
        let leader_swk = leader_swk.await.unwrap();
        let leader_civ = leader_civ.await.unwrap();
        let leader_siv = leader_siv.await.unwrap();

        let follower_cwk = follower_cwk.await.unwrap();
        let follower_swk = follower_swk.await.unwrap();
        let follower_civ = follower_civ.await.unwrap();
        let follower_siv = follower_siv.await.unwrap();

        let (expected_cwk, expected_swk, expected_civ, expected_siv) =
            session_keys(pms, client_random, server_random);

        assert_eq!(leader_cwk, expected_cwk);
        assert_eq!(leader_swk, expected_swk);
        assert_eq!(leader_civ, expected_civ);
        assert_eq!(leader_siv, expected_siv);

        assert_eq!(follower_cwk, expected_cwk);
        assert_eq!(follower_swk, expected_swk);
        assert_eq!(follower_civ, expected_civ);
        assert_eq!(follower_siv, expected_siv);

        let cf_hs_hash = [1u8; 32];
        let sf_hs_hash = [2u8; 32];

        leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap();
        leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap();

        follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap();
        follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap();

        let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap();
        let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap();

        let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap();
        let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap();

        futures::join!(
            async {
                leader_vm.flush(&mut leader_ctx).await.unwrap();
                leader_vm.execute(&mut leader_ctx).await.unwrap();
                leader_vm.flush(&mut leader_ctx).await.unwrap();
            },
            async {
                follower_vm.flush(&mut follower_ctx).await.unwrap();
                follower_vm.execute(&mut follower_ctx).await.unwrap();
                follower_vm.flush(&mut follower_ctx).await.unwrap();
            }
        );

        let leader_cf_vd = leader_cf_vd.await.unwrap();
        let leader_sf_vd = leader_sf_vd.await.unwrap();

        let follower_cf_vd = follower_cf_vd.await.unwrap();
        let follower_sf_vd = follower_sf_vd.await.unwrap();

        let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash);
        let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash);

        assert_eq!(leader_cf_vd, expected_cf_vd);
        assert_eq!(leader_sf_vd, expected_sf_vd);
        assert_eq!(follower_cf_vd, expected_cf_vd);
        assert_eq!(follower_sf_vd, expected_sf_vd);
        assert_eq!(sf_vd_leader, sf_vd_follower);
        assert_eq!(sf_vd_leader, sf_vd_expected);
    }
}

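For reference, both test suites above check against the TLS 1.2 PRF from RFC 5246, where A(0) = seed, A(i) = HMAC_SHA256(secret, A(i-1)), and the output is HMAC(secret, A(1) || seed) || HMAC(secret, A(2) || seed) || ... . A minimal standalone sketch using the `sha2` crate (the names `hmac_sha256_ref` and `p_hash` are illustrative and not part of this crate's API):

use sha2::{Digest, Sha256};

// HMAC-SHA256 per RFC 2104 (keys up to one block), built on the `sha2` crate.
fn hmac_sha256_ref(key: &[u8], msg: &[u8]) -> [u8; 32] {
    assert!(key.len() <= 64, "longer keys must be hashed first");
    let (mut ipad, mut opad) = ([0x36u8; 64], [0x5cu8; 64]);
    for (i, b) in key.iter().enumerate() {
        ipad[i] ^= b;
        opad[i] ^= b;
    }
    let inner = Sha256::new().chain_update(ipad).chain_update(msg).finalize();
    Sha256::new().chain_update(opad).chain_update(inner).finalize().into()
}

// P_hash(secret, seed) from RFC 5246, truncated to `len` bytes. The TLS 1.2
// PRF is PRF(secret, label, seed) = P_hash(secret, label || seed).
fn p_hash(secret: &[u8], seed: &[u8], len: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(len);
    let mut a = hmac_sha256_ref(secret, seed); // A(1)
    while out.len() < len {
        let mut msg = a.to_vec();
        msg.extend_from_slice(seed); // A(i) || seed
        out.extend_from_slice(&hmac_sha256_ref(secret, &msg));
        a = hmac_sha256_ref(secret, &a); // A(i+1)
    }
    out.truncate(len);
    out
}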
@@ -1,98 +1,41 @@
use std::{
    fmt::Debug,
    sync::{Arc, OnceLock},
use crate::{
    hmac::{IPAD, OPAD},
    Mode, PrfError, PrfOutput,
};

use hmac_sha256_circuits::{build_session_keys, build_verify_data};
use mpz_circuits::Circuit;
use mpz_common::cpu::CpuBackend;
use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
    memory::{
        binary::{Binary, U32, U8},
        Array,
        binary::{Binary, U8},
        Array, MemoryExt, StaticSize, Vector, ViewExt,
    },
    prelude::*,
    Call, Vm,
    Call, CallableExt, Vm,
};
use std::{fmt::Debug, sync::Arc};
use tracing::instrument;

use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL};
mod state;
use state::State;

pub(crate) struct Circuits {
    session_keys: Arc<Circuit>,
    client_vd: Arc<Circuit>,
    server_vd: Arc<Circuit>,
}

impl Circuits {
    pub(crate) async fn get() -> &'static Self {
        static CIRCUITS: OnceLock<Circuits> = OnceLock::new();
        if let Some(circuits) = CIRCUITS.get() {
            return circuits;
        }

        let (session_keys, client_vd, server_vd) = futures::join!(
            CpuBackend::blocking(build_session_keys),
            CpuBackend::blocking(|| build_verify_data(CF_LABEL)),
            CpuBackend::blocking(|| build_verify_data(SF_LABEL)),
        );

        _ = CIRCUITS.set(Circuits {
            session_keys,
            client_vd,
            server_vd,
        });

        CIRCUITS.get().unwrap()
    }
}
mod function;
use function::Prf;

/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF.
#[derive(Debug)]
pub(crate) enum State {
    Initialized,
    SessionKeys {
        client_random: Array<U8, 32>,
        server_random: Array<U8, 32>,
        cf_hash: Array<U8, 32>,
        sf_hash: Array<U8, 32>,
    },
    ClientFinished {
        cf_hash: Array<U8, 32>,
        sf_hash: Array<U8, 32>,
    },
    ServerFinished {
        sf_hash: Array<U8, 32>,
    },
    Complete,
    Error,
}

impl State {
    fn take(&mut self) -> State {
        std::mem::replace(self, State::Error)
    }
}

/// MPC PRF for computing TLS HMAC-SHA256 PRF.
pub struct MpcPrf {
    config: PrfConfig,
    mode: Mode,
    state: State,
}

impl Debug for MpcPrf {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MpcPrf")
            .field("config", &self.config)
            .field("state", &self.state)
            .finish()
    }
}

impl MpcPrf {
    /// Creates a new instance of the PRF.
    pub fn new(config: PrfConfig) -> MpcPrf {
        MpcPrf {
            config,
    ///
    /// # Arguments
    ///
    /// `mode` - The PRF mode.
    pub fn new(mode: Mode) -> MpcPrf {
        Self {
            mode,
            state: State::Initialized,
        }
    }
@@ -113,122 +56,58 @@ impl MpcPrf {
            return Err(PrfError::state("PRF not in initialized state"));
        };

        let circuits = futures::executor::block_on(Circuits::get());
        let mode = self.mode;
        let pms: Vector<U8> = pms.into();

        let client_random = vm.alloc().map_err(PrfError::vm)?;
        let server_random = vm.alloc().map_err(PrfError::vm)?;
        let outer_partial_pms = compute_partial(vm, pms, OPAD)?;
        let inner_partial_pms = compute_partial(vm, pms, IPAD)?;

        // The client random is kept private so that the handshake transcript
        // hashes do not leak information about the server's identity.
        match self.config.role {
            Role::Leader => vm.mark_private(client_random),
            Role::Follower => vm.mark_blind(client_random),
        }
        .map_err(PrfError::vm)?;
        let master_secret =
            Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?;
        let ms = master_secret.output();
        let ms = merge_outputs(vm, ms, 48)?;

        vm.mark_public(server_random).map_err(PrfError::vm)?;
        let outer_partial_ms = compute_partial(vm, ms, OPAD)?;
        let inner_partial_ms = compute_partial(vm, ms, IPAD)?;

        #[allow(clippy::type_complexity)]
        let (
            client_write_key,
            server_write_key,
            client_iv,
            server_iv,
            ms_outer_hash_state,
            ms_inner_hash_state,
        ): (
            Array<U8, 16>,
            Array<U8, 16>,
            Array<U8, 4>,
            Array<U8, 4>,
            Array<U32, 8>,
            Array<U32, 8>,
        ) = vm
            .call(
                Call::builder(circuits.session_keys.clone())
                    .arg(pms)
                    .arg(client_random)
                    .arg(server_random)
                    .build()
                    .map_err(PrfError::vm)?,
            )
            .map_err(PrfError::vm)?;

        let keys = SessionKeys {
            client_write_key,
            server_write_key,
            client_iv,
            server_iv,
        };

        let cf_hash = vm.alloc().map_err(PrfError::vm)?;
        vm.mark_public(cf_hash).map_err(PrfError::vm)?;

        let cf_vd = vm
            .call(
                Call::builder(circuits.client_vd.clone())
                    .arg(ms_outer_hash_state)
                    .arg(ms_inner_hash_state)
                    .arg(cf_hash)
                    .build()
                    .map_err(PrfError::vm)?,
            )
            .map_err(PrfError::vm)?;

        let sf_hash = vm.alloc().map_err(PrfError::vm)?;
        vm.mark_public(sf_hash).map_err(PrfError::vm)?;

        let sf_vd = vm
            .call(
                Call::builder(circuits.server_vd.clone())
                    .arg(ms_outer_hash_state)
                    .arg(ms_inner_hash_state)
                    .arg(sf_hash)
                    .build()
                    .map_err(PrfError::vm)?,
            )
            .map_err(PrfError::vm)?;
        let key_expansion =
            Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?;
        let client_finished = Prf::alloc_client_finished(
            mode,
            vm,
            outer_partial_ms.clone(),
            inner_partial_ms.clone(),
        )?;
        let server_finished = Prf::alloc_server_finished(
            mode,
            vm,
            outer_partial_ms.clone(),
            inner_partial_ms.clone(),
        )?;

        self.state = State::SessionKeys {
            client_random,
            server_random,
            cf_hash,
            sf_hash,
            client_random: None,
            master_secret,
            key_expansion,
            client_finished,
            server_finished,
        };

        Ok(PrfOutput { keys, cf_vd, sf_vd })
        self.state.prf_output(vm)
    }

    /// Sets the client random.
    ///
    /// Only the leader can provide the client random.
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `client_random` - The client random.
    /// * `random` - The client random.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_client_random(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        random: Option<[u8; 32]>,
    ) -> Result<(), PrfError> {
        let State::SessionKeys { client_random, .. } = &self.state else {
    pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
        let State::SessionKeys { client_random, .. } = &mut self.state else {
            return Err(PrfError::state("PRF not set up"));
        };

        if self.config.role == Role::Leader {
            let Some(random) = random else {
                return Err(PrfError::role("leader must provide client random"));
            };

            vm.assign(*client_random, random).map_err(PrfError::vm)?;
        } else if random.is_some() {
            return Err(PrfError::role("only leader can set client random"));
        }

        vm.commit(*client_random).map_err(PrfError::vm)?;

        *client_random = Some(random);
        Ok(())
    }

@@ -236,28 +115,29 @@ impl MpcPrf {
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `server_random` - The server random.
    /// * `random` - The server random.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_server_random(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        random: [u8; 32],
    ) -> Result<(), PrfError> {
    pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> {
        let State::SessionKeys {
            server_random,
            cf_hash,
            sf_hash,
            client_random,
            master_secret,
            key_expansion,
            ..
        } = self.state.take()
        } = &mut self.state
        else {
            return Err(PrfError::state("PRF not set up"));
        };

        vm.assign(server_random, random).map_err(PrfError::vm)?;
        vm.commit(server_random).map_err(PrfError::vm)?;
        let client_random = client_random.expect("Client random should have been set by now");
        let server_random = random;

        self.state = State::ClientFinished { cf_hash, sf_hash };
        let mut seed_ms = client_random.to_vec();
        seed_ms.extend_from_slice(&server_random);
        master_secret.set_start_seed(seed_ms);

        let mut seed_ke = server_random.to_vec();
        seed_ke.extend_from_slice(&client_random);
        key_expansion.set_start_seed(seed_ke);

        Ok(())
    }
@@ -266,22 +146,18 @@ impl MpcPrf {
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `handshake_hash` - The handshake transcript hash.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_cf_hash(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        handshake_hash: [u8; 32],
    ) -> Result<(), PrfError> {
        let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else {
    pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
        let State::ClientFinished {
            client_finished, ..
        } = &mut self.state
        else {
            return Err(PrfError::state("PRF not in client finished state"));
        };

        vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?;
        vm.commit(cf_hash).map_err(PrfError::vm)?;

        self.state = State::ServerFinished { sf_hash };
        let seed_cf = handshake_hash.to_vec();
        client_finished.set_start_seed(seed_cf);

        Ok(())
    }
@@ -290,23 +166,242 @@ impl MpcPrf {
    ///
    /// # Arguments
    ///
    /// * `vm` - Virtual machine.
    /// * `handshake_hash` - The handshake transcript hash.
    #[instrument(level = "debug", skip_all, err)]
    pub fn set_sf_hash(
        &mut self,
        vm: &mut dyn Vm<Binary>,
        handshake_hash: [u8; 32],
    ) -> Result<(), PrfError> {
        let State::ServerFinished { sf_hash } = self.state.take() else {
    pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> {
        let State::ServerFinished { server_finished } = &mut self.state else {
            return Err(PrfError::state("PRF not in server finished state"));
        };

        vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?;
        vm.commit(sf_hash).map_err(PrfError::vm)?;
        let seed_sf = handshake_hash.to_vec();
        server_finished.set_start_seed(seed_sf);

        self.state = State::Complete;
        Ok(())
    }

    /// Returns if the PRF needs to be flushed.
    pub fn wants_flush(&self) -> bool {
        match &self.state {
            State::Initialized => false,
            State::SessionKeys {
                master_secret,
                key_expansion,
                ..
            } => master_secret.wants_flush() || key_expansion.wants_flush(),
            State::ClientFinished {
                client_finished, ..
            } => client_finished.wants_flush(),
            State::ServerFinished { server_finished } => server_finished.wants_flush(),
            State::Complete => false,
            State::Error => false,
        }
    }

    /// Flushes the PRF.
    pub fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
        self.state = match self.state.take() {
            State::SessionKeys {
                client_random,
                mut master_secret,
                mut key_expansion,
                client_finished,
                server_finished,
            } => {
                master_secret.flush(vm)?;
                key_expansion.flush(vm)?;

                if !master_secret.wants_flush() && !key_expansion.wants_flush() {
                    State::ClientFinished {
                        client_finished,
                        server_finished,
                    }
                } else {
                    State::SessionKeys {
                        client_random,
                        master_secret,
                        key_expansion,
                        client_finished,
                        server_finished,
                    }
                }
            }
            State::ClientFinished {
                mut client_finished,
                server_finished,
            } => {
                client_finished.flush(vm)?;

                if !client_finished.wants_flush() {
                    State::ServerFinished { server_finished }
                } else {
                    State::ClientFinished {
                        client_finished,
                        server_finished,
                    }
                }
            }
            State::ServerFinished {
                mut server_finished,
            } => {
                server_finished.flush(vm)?;

                if !server_finished.wants_flush() {
                    State::Complete
                } else {
                    State::ServerFinished { server_finished }
                }
            }
            other => other,
        };

        Ok(())
    }
}

/// Depending on the provided `mask` computes and returns `outer_partial` or
/// `inner_partial` for HMAC-SHA256.
///
/// # Arguments
///
/// * `vm` - Virtual machine.
/// * `key` - Key to pad and xor.
/// * `mask` - Mask used for padding.
fn compute_partial(
    vm: &mut dyn Vm<Binary>,
    key: Vector<U8>,
    mask: [u8; 64],
) -> Result<Sha256, PrfError> {
    let xor = Arc::new(xor(8 * 64));

    let additional_len = 64 - key.len();
    let padding = vec![0_u8; additional_len];

    let padding_ref: Vector<U8> = vm.alloc_vec(additional_len).map_err(PrfError::vm)?;
    vm.mark_public(padding_ref).map_err(PrfError::vm)?;
    vm.assign(padding_ref, padding).map_err(PrfError::vm)?;
    vm.commit(padding_ref).map_err(PrfError::vm)?;

    let mask_ref: Array<U8, 64> = vm.alloc().map_err(PrfError::vm)?;
    vm.mark_public(mask_ref).map_err(PrfError::vm)?;
    vm.assign(mask_ref, mask).map_err(PrfError::vm)?;
    vm.commit(mask_ref).map_err(PrfError::vm)?;

    let xor = Call::builder(xor)
        .arg(key)
        .arg(padding_ref)
        .arg(mask_ref)
        .build()
        .map_err(PrfError::vm)?;
    let key_padded: Vector<U8> = vm.call(xor).map_err(PrfError::vm)?;

    let mut sha = Sha256::new_with_init(vm)?;
    sha.update(&key_padded);
    sha.compress(vm)?;
    Ok(sha)
}
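In plain terms, `compute_partial` performs the standard HMAC key-schedule step: the key is zero-padded to one 64-byte block, XORed with IPAD or OPAD, and that single block is compressed from the SHA-256 IV, yielding a reusable partial state (HMAC(k, m) = H((k XOR opad) || H((k XOR ipad) || m))). A local, non-MPC sketch of the same computation, assuming the `sha2` crate's low-level `compress256` (already used by the `sha256` helper in lib.rs); the name `partial_state` is illustrative:

use sha2::{compress256, digest::generic_array::GenericArray};

const SHA256_IV: [u32; 8] = [
    0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
    0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
];

// SHA-256 state after compressing the single block `key XOR mask`, where
// `mask` is 0x36 (IPAD) or 0x5c (OPAD); the short key is zero-padded to
// 64 bytes, exactly as `compute_partial` does inside the VM.
fn partial_state(key: &[u8], mask: u8) -> [u32; 8] {
    assert!(key.len() <= 64);
    let mut block = [mask; 64];
    for (b, k) in block.iter_mut().zip(key) {
        *b ^= *k;
    }
    let mut state = SHA256_IV;
    compress256(&mut state, &[GenericArray::clone_from_slice(&block)]);
    state
}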

fn merge_outputs(
    vm: &mut dyn Vm<Binary>,
    inputs: Vec<Array<U8, 32>>,
    output_bytes: usize,
) -> Result<Vector<U8>, PrfError> {
    assert!(output_bytes <= 32 * inputs.len());

    let bits = Array::<U8, 32>::SIZE * inputs.len();
    let circ = gen_merge_circ(bits);

    let mut builder = Call::builder(circ);
    for &input in inputs.iter() {
        builder = builder.arg(input);
    }
    let call = builder.build().map_err(PrfError::vm)?;

    let mut output: Vector<U8> = vm.call(call).map_err(PrfError::vm)?;
    output.truncate(output_bytes);
    Ok(output)
}

fn gen_merge_circ(size: usize) -> Arc<Circuit> {
    let mut builder = CircuitBuilder::new();
    let inputs = (0..size).map(|_| builder.add_input()).collect::<Vec<_>>();

    for input in inputs.chunks_exact(8) {
        for byte in input.chunks_exact(8) {
            for &feed in byte.iter() {
                let output = builder.add_id_gate(feed);
                builder.add_output(output);
            }
        }
    }

    Arc::new(builder.build().expect("merge circuit is valid"))
}

#[cfg(test)]
mod tests {
    use crate::{prf::merge_outputs, test_utils::mock_vm};
    use mpz_common::context::test_st_context;
    use mpz_vm_core::{
        memory::{binary::U8, Array, MemoryExt, ViewExt},
        Execute,
    };

    #[tokio::test]
    async fn test_merge_outputs() {
        let (mut ctx_a, mut ctx_b) = test_st_context(8);
        let (mut leader, mut follower) = mock_vm();

        let input1: [u8; 32] = std::array::from_fn(|i| i as u8);
        let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32);

        let mut expected = input1.to_vec();
        expected.extend_from_slice(&input2);
        expected.truncate(48);

        // leader
        let input1_leader: Array<U8, 32> = leader.alloc().unwrap();
        let input2_leader: Array<U8, 32> = leader.alloc().unwrap();

        leader.mark_public(input1_leader).unwrap();
        leader.mark_public(input2_leader).unwrap();

        leader.assign(input1_leader, input1).unwrap();
        leader.assign(input2_leader, input2).unwrap();

        leader.commit(input1_leader).unwrap();
        leader.commit(input2_leader).unwrap();

        let merged_leader =
            merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap();
        let mut merged_leader = leader.decode(merged_leader).unwrap();

        // follower
        let input1_follower: Array<U8, 32> = follower.alloc().unwrap();
        let input2_follower: Array<U8, 32> = follower.alloc().unwrap();

        follower.mark_public(input1_follower).unwrap();
        follower.mark_public(input2_follower).unwrap();

        follower.assign(input1_follower, input1).unwrap();
        follower.assign(input2_follower, input2).unwrap();

        follower.commit(input1_follower).unwrap();
        follower.commit(input2_follower).unwrap();

        let merged_follower =
            merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap();
        let mut merged_follower = follower.decode(merged_follower).unwrap();

        tokio::try_join!(
            leader.execute_all(&mut ctx_a),
            follower.execute_all(&mut ctx_b)
        )
        .unwrap();

        let merged_leader = merged_leader.try_recv().unwrap().unwrap();
        let merged_follower = merged_follower.try_recv().unwrap().unwrap();

        assert_eq!(merged_leader, merged_follower);
        assert_eq!(merged_leader, expected);
    }
}

257
crates/components/hmac-sha256/src/prf/function.rs
Normal file
@@ -0,0 +1,257 @@
//! Provides [`Prf`], for computing the TLS 1.2 PRF.

use crate::{Mode, PrfError};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
    memory::{
        binary::{Binary, U8},
        Array,
    },
    Vm,
};

mod normal;
mod reduced;

#[derive(Debug)]
pub(crate) enum Prf {
    Reduced(reduced::PrfFunction),
    Normal(normal::PrfFunction),
}

impl Prf {
    pub(crate) fn alloc_master_secret(
        mode: Mode,
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        let prf = match mode {
            Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret(
                vm,
                outer_partial,
                inner_partial,
            )?),
            Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret(
                vm,
                outer_partial,
                inner_partial,
            )?),
        };
        Ok(prf)
    }

    pub(crate) fn alloc_key_expansion(
        mode: Mode,
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        let prf = match mode {
            Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion(
                vm,
                outer_partial,
                inner_partial,
            )?),
            Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion(
                vm,
                outer_partial,
                inner_partial,
            )?),
        };
        Ok(prf)
    }

    pub(crate) fn alloc_client_finished(
        config: Mode,
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        let prf = match config {
            Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished(
                vm,
                outer_partial,
                inner_partial,
            )?),
            Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished(
                vm,
                outer_partial,
                inner_partial,
            )?),
        };
        Ok(prf)
    }

    pub(crate) fn alloc_server_finished(
        config: Mode,
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        let prf = match config {
            Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished(
                vm,
                outer_partial,
                inner_partial,
            )?),
            Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished(
                vm,
                outer_partial,
                inner_partial,
            )?),
        };
        Ok(prf)
    }

    pub(crate) fn wants_flush(&self) -> bool {
        match self {
            Prf::Reduced(prf) => prf.wants_flush(),
            Prf::Normal(prf) => prf.wants_flush(),
        }
    }

    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
        match self {
            Prf::Reduced(prf) => prf.flush(vm),
            Prf::Normal(prf) => prf.flush(vm),
        }
    }

    pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
        match self {
            Prf::Reduced(prf) => prf.set_start_seed(seed),
            Prf::Normal(prf) => prf.set_start_seed(seed),
        }
    }

    pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
        match self {
            Prf::Reduced(prf) => prf.output(),
            Prf::Normal(prf) => prf.output(),
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        prf::{compute_partial, function::Prf},
        test_utils::{mock_vm, phash},
        Mode,
    };
    use mpz_common::context::test_st_context;
    use mpz_vm_core::{
        memory::{binary::U8, Array, MemoryExt, ViewExt},
        Execute,
    };
    use rand::{rngs::ThreadRng, Rng};

    const IPAD: [u8; 64] = [0x36; 64];
    const OPAD: [u8; 64] = [0x5c; 64];

    #[tokio::test]
    async fn test_phash_reduced() {
        let mode = Mode::Reduced;
        test_phash(mode).await;
    }

    #[tokio::test]
    async fn test_phash_normal() {
        let mode = Mode::Normal;
        test_phash(mode).await;
    }

    async fn test_phash(mode: Mode) {
        let mut rng = ThreadRng::default();

        let (mut ctx_a, mut ctx_b) = test_st_context(8);
        let (mut leader, mut follower) = mock_vm();

        let key: [u8; 32] = rng.random();
        let start_seed: Vec<u8> = vec![42; 64];

        let mut label_seed = b"master secret".to_vec();
        label_seed.extend_from_slice(&start_seed);
        let iterations = 2;

        let leader_key: Array<U8, 32> = leader.alloc().unwrap();
        leader.mark_public(leader_key).unwrap();
        leader.assign(leader_key, key).unwrap();
        leader.commit(leader_key).unwrap();

        let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap();
        let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap();

        let mut prf_leader = Prf::alloc_master_secret(
            mode,
            &mut leader,
            outer_partial_leader,
            inner_partial_leader,
        )
        .unwrap();
        prf_leader.set_start_seed(start_seed.clone());

        let mut prf_out_leader = vec![];
        for p in prf_leader.output() {
            let p_out = leader.decode(p).unwrap();
            prf_out_leader.push(p_out)
        }

        let follower_key: Array<U8, 32> = follower.alloc().unwrap();
        follower.mark_public(follower_key).unwrap();
        follower.assign(follower_key, key).unwrap();
        follower.commit(follower_key).unwrap();

        let outer_partial_follower =
            compute_partial(&mut follower, follower_key.into(), OPAD).unwrap();
        let inner_partial_follower =
            compute_partial(&mut follower, follower_key.into(), IPAD).unwrap();

        let mut prf_follower = Prf::alloc_master_secret(
            mode,
            &mut follower,
            outer_partial_follower,
            inner_partial_follower,
        )
        .unwrap();
        prf_follower.set_start_seed(start_seed.clone());

        let mut prf_out_follower = vec![];
        for p in prf_follower.output() {
            let p_out = follower.decode(p).unwrap();
            prf_out_follower.push(p_out)
        }

        while prf_leader.wants_flush() || prf_follower.wants_flush() {
            tokio::try_join!(
                async {
                    prf_leader.flush(&mut leader).unwrap();
                    leader.execute_all(&mut ctx_a).await
                },
                async {
                    prf_follower.flush(&mut follower).unwrap();
                    follower.execute_all(&mut ctx_b).await
                }
            )
            .unwrap();
        }

        assert_eq!(prf_out_leader.len(), 2);
        assert_eq!(prf_out_leader.len(), prf_out_follower.len());

        let prf_result_leader: Vec<u8> = prf_out_leader
            .iter_mut()
            .flat_map(|p| p.try_recv().unwrap().unwrap())
            .collect();
        let prf_result_follower: Vec<u8> = prf_out_follower
            .iter_mut()
            .flat_map(|p| p.try_recv().unwrap().unwrap())
            .collect();

        let expected = phash(key.to_vec(), &label_seed, iterations);

        assert_eq!(prf_result_leader, prf_result_follower);
        assert_eq!(prf_result_leader, expected)
    }
}
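The mode chosen at construction fixes which variant every allocation above dispatches to: `Normal` (next file) evaluates all of the PRF's SHA-256 compressions inside the VM, while `Reduced` decodes the inner-partial state and computes the inner hashes locally, leaving only the outer HMAC applications in MPC. A hypothetical call site, mirroring `test_phash` above (`vm` and `pms` are assumed to be set up as in the tests):

// Sketch only, not code from this diff:
let mut prf = MpcPrf::new(Mode::Reduced); // or Mode::Normal
// let output = prf.alloc(&mut vm, pms)?;
// prf.set_client_random(client_random)?;
// prf.set_server_random(server_random)?;
// while prf.wants_flush() { prf.flush(&mut vm)?; /* then execute the VM */ }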
174
crates/components/hmac-sha256/src/prf/function/normal.rs
Normal file
@@ -0,0 +1,174 @@
//! Computes the whole PRF in MPC.

use crate::{hmac::hmac_sha256, PrfError};
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
    memory::{
        binary::{Binary, U8},
        Array, MemoryExt, Vector, ViewExt,
    },
    Vm,
};

#[derive(Debug)]
pub(crate) struct PrfFunction {
    // The label, e.g. "master secret".
    label: &'static [u8],
    state: State,
    // The start seed and the label, e.g. client_random + server_random + "master_secret".
    start_seed_label: Option<Vec<u8>>,
    a: Vec<PHash>,
    p: Vec<PHash>,
}

impl PrfFunction {
    const MS_LABEL: &[u8] = b"master secret";
    const KEY_LABEL: &[u8] = b"key expansion";
    const CF_LABEL: &[u8] = b"client finished";
    const SF_LABEL: &[u8] = b"server finished";

    pub(crate) fn alloc_master_secret(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64)
    }

    pub(crate) fn alloc_key_expansion(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64)
    }

    pub(crate) fn alloc_client_finished(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32)
    }

    pub(crate) fn alloc_server_finished(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32)
    }

    pub(crate) fn wants_flush(&self) -> bool {
        let is_computing = match self.state {
            State::Computing => true,
            State::Finished => false,
        };
        is_computing && self.start_seed_label.is_some()
    }

    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
        if let State::Computing = self.state {
            let a = self.a.first().expect("prf should be allocated");
            let msg = *a.msg.first().expect("message for prf should be present");

            let msg_value = self
                .start_seed_label
                .clone()
                .expect("Start seed should have been set");

            vm.assign(msg, msg_value).map_err(PrfError::vm)?;
            vm.commit(msg).map_err(PrfError::vm)?;

            self.state = State::Finished;
        }
        Ok(())
    }

    pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
        let mut start_seed_label = self.label.to_vec();
        start_seed_label.extend_from_slice(&seed);

        self.start_seed_label = Some(start_seed_label);
    }

    pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
        self.p.iter().map(|p| p.output).collect()
    }

    fn alloc(
        vm: &mut dyn Vm<Binary>,
        label: &'static [u8],
        outer_partial: Sha256,
        inner_partial: Sha256,
        output_len: usize,
        seed_len: usize,
    ) -> Result<Self, PrfError> {
        let mut prf = Self {
            label,
            state: State::Computing,
            start_seed_label: None,
            a: vec![],
            p: vec![],
        };

        assert!(output_len > 0, "cannot compute 0 bytes for prf");

        let iterations = output_len.div_ceil(32);

        let msg_len_a = label.len() + seed_len;
        let seed_label_ref: Vector<U8> = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?;
        vm.mark_public(seed_label_ref).map_err(PrfError::vm)?;

        let mut msg_a = seed_label_ref;
        for _ in 0..iterations {
            let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?;
            msg_a = Vector::<U8>::from(a.output);
            prf.a.push(a);

            let p = PHash::alloc(
                vm,
                outer_partial.clone(),
                inner_partial.clone(),
                &[msg_a, seed_label_ref],
            )?;
            prf.p.push(p);
        }

        Ok(prf)
    }
}

#[derive(Debug, Clone, Copy)]
enum State {
    Computing,
    Finished,
}

#[derive(Debug, Clone)]
struct PHash {
    msg: Vec<Vector<U8>>,
    output: Array<U8, 32>,
}

impl PHash {
    fn alloc(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
        msg: &[Vector<U8>],
    ) -> Result<Self, PrfError> {
        let mut inner_local = inner_partial;

        msg.iter().for_each(|m| inner_local.update(m));
        inner_local.compress(vm)?;
        let inner_local = inner_local.finalize(vm)?;

        let output = hmac_sha256(vm, outer_partial, inner_local)?;
        let p_hash = Self {
            msg: msg.to_vec(),
            output,
        };
        Ok(p_hash)
    }
}
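The chaining that `PrfFunction::alloc` wires up above follows the RFC 5246 expansion directly. Written out for `iterations = output_len.div_ceil(32)` (a worked expansion of the loop, not additional code from this diff):

// msg_a starts as the public `label || seed` vector (seed_label_ref):
//   a[1] = HMAC(key, label || seed)
//   p[1] = HMAC(key, a[1] || label || seed)
//   a[2] = HMAC(key, a[1])
//   p[2] = HMAC(key, a[2] || label || seed)
//   ...
// The concatenation p[1] || p[2] || ... is later truncated to `output_len`
// bytes by `merge_outputs`.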
247
crates/components/hmac-sha256/src/prf/function/reduced.rs
Normal file
@@ -0,0 +1,247 @@
//! Computes some hashes of the PRF locally.

use std::collections::VecDeque;

use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError};
use mpz_core::bitvec::BitVec;
use mpz_hash::sha256::Sha256;
use mpz_vm_core::{
    memory::{
        binary::{Binary, U8},
        Array, DecodeFutureTyped, MemoryExt, ViewExt,
    },
    Vm,
};

#[derive(Debug)]
pub(crate) struct PrfFunction {
    // The label, e.g. "master secret".
    label: &'static [u8],
    // The start seed and the label, e.g. client_random + server_random + "master_secret".
    start_seed_label: Option<Vec<u8>>,
    iterations: usize,
    state: PrfState,
    a: VecDeque<AHash>,
    p: VecDeque<PHash>,
}

#[derive(Debug)]
enum PrfState {
    InnerPartial {
        inner_partial: DecodeFutureTyped<BitVec, [u32; 8]>,
    },
    ComputeA {
        iter: usize,
        inner_partial: [u32; 8],
        msg: Vec<u8>,
    },
    ComputeP {
        iter: usize,
        inner_partial: [u32; 8],
        a_output: DecodeFutureTyped<BitVec, [u8; 32]>,
    },
    FinishLastP,
    Done,
}

impl PrfFunction {
    const MS_LABEL: &[u8] = b"master secret";
    const KEY_LABEL: &[u8] = b"key expansion";
    const CF_LABEL: &[u8] = b"client finished";
    const SF_LABEL: &[u8] = b"server finished";

    pub(crate) fn alloc_master_secret(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48)
    }

    pub(crate) fn alloc_key_expansion(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40)
    }

    pub(crate) fn alloc_client_finished(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12)
    }

    pub(crate) fn alloc_server_finished(
        vm: &mut dyn Vm<Binary>,
        outer_partial: Sha256,
        inner_partial: Sha256,
    ) -> Result<Self, PrfError> {
        Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12)
    }

    pub(crate) fn wants_flush(&self) -> bool {
        !matches!(self.state, PrfState::Done) && self.start_seed_label.is_some()
    }

    pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
        match &mut self.state {
            PrfState::InnerPartial { inner_partial } => {
                let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else {
                    return Ok(());
                };

                self.state = PrfState::ComputeA {
                    iter: 1,
                    inner_partial,
                    msg: self
                        .start_seed_label
                        .clone()
                        .expect("Start seed should have been set"),
                };
                self.flush(vm)?;
            }
            PrfState::ComputeA {
                iter,
                inner_partial,
                msg,
            } => {
                let a = self.a.pop_front().expect("Prf AHash should be present");
                assign_inner_local(vm, a.inner_local, *inner_partial, msg)?;

                self.state = PrfState::ComputeP {
                    iter: *iter,
                    inner_partial: *inner_partial,
                    a_output: a.output,
                };
            }
            PrfState::ComputeP {
                iter,
                inner_partial,
                a_output,
            } => {
                let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else {
                    return Ok(());
                };
                let p = self.p.pop_front().expect("Prf PHash should be present");

                let mut msg = output.to_vec();
                msg.extend_from_slice(
                    self.start_seed_label
                        .as_ref()
                        .expect("Start seed should have been set"),
                );

                assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?;

                if *iter == self.iterations {
                    self.state = PrfState::FinishLastP;
                } else {
                    self.state = PrfState::ComputeA {
                        iter: *iter + 1,
                        inner_partial: *inner_partial,
                        msg: output.to_vec(),
                    }
                };
            }
            PrfState::FinishLastP => self.state = PrfState::Done,
            _ => (),
        }

        Ok(())
    }

    pub(crate) fn set_start_seed(&mut self, seed: Vec<u8>) {
        let mut start_seed_label = self.label.to_vec();
        start_seed_label.extend_from_slice(&seed);

        self.start_seed_label = Some(start_seed_label);
    }

    pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
        self.p.iter().map(|p| p.output).collect()
    }

    fn alloc(
        vm: &mut dyn Vm<Binary>,
        label: &'static [u8],
        outer_partial: Sha256,
        inner_partial: Sha256,
        len: usize,
    ) -> Result<Self, PrfError> {
        assert!(len > 0, "cannot compute 0 bytes for prf");

        let iterations = len.div_ceil(32);

        let (inner_partial, _) = inner_partial
            .state()
            .expect("state should be set for inner_partial");
        let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?;

        let mut prf = Self {
            label,
            start_seed_label: None,
            iterations,
            state: PrfState::InnerPartial { inner_partial },
            a: VecDeque::new(),
            p: VecDeque::new(),
        };

        for _ in 0..iterations {
            // setup A[i]
            let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
            let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;

            let output = vm.decode(output).map_err(PrfError::vm)?;
            let a_hash = AHash {
                inner_local,
                output,
            };

            prf.a.push_front(a_hash);

            // setup P[i]
            let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
            let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
            let p_hash = PHash {
                inner_local,
                output,
            };
            prf.p.push_front(p_hash);
        }

        Ok(prf)
    }
}

fn assign_inner_local(
    vm: &mut dyn Vm<Binary>,
    inner_local: Array<U8, 32>,
    inner_partial: [u32; 8],
    msg: &[u8],
) -> Result<(), PrfError> {
    let inner_local_value = sha256(inner_partial, 64, msg);

    vm.mark_public(inner_local).map_err(PrfError::vm)?;
    vm.assign(inner_local, state_to_bytes(inner_local_value))
        .map_err(PrfError::vm)?;
    vm.commit(inner_local).map_err(PrfError::vm)?;

    Ok(())
}

/// Like PHash but stores the output as the decoding future because in the
/// reduced Prf we need to decode this output.
#[derive(Debug)]
struct AHash {
    inner_local: Array<U8, 32>,
    output: DecodeFutureTyped<BitVec, [u8; 32]>,
}

#[derive(Debug, Clone, Copy)]
struct PHash {
    inner_local: Array<U8, 32>,
    output: Array<U8, 32>,
}
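The saving in `Reduced` mode is visible in `assign_inner_local` above: the inner hash over the seed material is computed in the clear on the decoded `inner_partial` state, and only the outer HMAC application (`hmac_sha256(vm, outer_partial, inner_local)`) remains inside the VM. One local step amounts to the following sketch, reusing the `sha256` and `state_to_bytes` helpers from lib.rs (the name `local_inner_hash` is illustrative):

// Resume SHA-256 from `inner_partial` (the block `key XOR ipad` is already
// compressed, hence the starting offset of 64 bytes) over `msg`, which is
// either `label || seed` or a previously decoded A(i) output.
fn local_inner_hash(inner_partial: [u32; 8], msg: &[u8]) -> [u8; 32] {
    state_to_bytes(sha256(inner_partial, 64, msg))
}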
103
crates/components/hmac-sha256/src/prf/state.rs
Normal file
@@ -0,0 +1,103 @@
use crate::{
    prf::{function::Prf, merge_outputs},
    PrfError, PrfOutput, SessionKeys,
};
use mpz_vm_core::{
    memory::{
        binary::{Binary, U8},
        Array, FromRaw, ToRaw,
    },
    Vm,
};

#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub(crate) enum State {
    Initialized,
    SessionKeys {
        client_random: Option<[u8; 32]>,
        master_secret: Prf,
        key_expansion: Prf,
        client_finished: Prf,
        server_finished: Prf,
    },
    ClientFinished {
        client_finished: Prf,
        server_finished: Prf,
    },
    ServerFinished {
        server_finished: Prf,
    },
    Complete,
    Error,
}

impl State {
    pub(crate) fn take(&mut self) -> State {
        std::mem::replace(self, State::Error)
    }

    pub(crate) fn prf_output(&self, vm: &mut dyn Vm<Binary>) -> Result<PrfOutput, PrfError> {
        let State::SessionKeys {
            key_expansion,
            client_finished,
            server_finished,
            ..
        } = self
        else {
            return Err(PrfError::state(
                "Prf output can only be computed while in \"SessionKeys\" state",
            ));
        };

        let keys = get_session_keys(key_expansion.output(), vm)?;
        let cf_vd = get_client_finished_vd(client_finished.output(), vm)?;
        let sf_vd = get_server_finished_vd(server_finished.output(), vm)?;

        let output = PrfOutput { keys, cf_vd, sf_vd };

        Ok(output)
    }
}

fn get_session_keys(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<SessionKeys, PrfError> {
    let mut keys = merge_outputs(vm, output, 40)?;
    debug_assert!(keys.len() == 40, "session keys len should be 40");

    let server_iv = Array::<U8, 4>::try_from(keys.split_off(36)).unwrap();
    let client_iv = Array::<U8, 4>::try_from(keys.split_off(32)).unwrap();
    let server_write_key = Array::<U8, 16>::try_from(keys.split_off(16)).unwrap();
    let client_write_key = Array::<U8, 16>::try_from(keys).unwrap();

    let session_keys = SessionKeys {
        client_write_key,
        server_write_key,
        client_iv,
        server_iv,
    };

    Ok(session_keys)
}

fn get_client_finished_vd(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
    let cf_vd = merge_outputs(vm, output, 12)?;
    let cf_vd = <Array<U8, 12> as FromRaw<Binary>>::from_raw(cf_vd.to_raw());

    Ok(cf_vd)
}

fn get_server_finished_vd(
    output: Vec<Array<U8, 32>>,
    vm: &mut dyn Vm<Binary>,
) -> Result<Array<U8, 12>, PrfError> {
    let sf_vd = merge_outputs(vm, output, 12)?;
    let sf_vd = <Array<U8, 12> as FromRaw<Binary>>::from_raw(sf_vd.to_raw());

    Ok(sf_vd)
}
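Note the order of the `split_off` calls in `get_session_keys` above: the 40-byte key block is laid out as client_write_key(16) || server_write_key(16) || client_iv(4) || server_iv(4), and `Vec::split_off` peels suffixes, so the offsets must run 36, 32, 16. A standalone illustration of that layout:

let mut keys: Vec<u8> = (0u8..40).collect();
let siv = keys.split_off(36); // bytes 36..40 (server_iv)
let civ = keys.split_off(32); // bytes 32..36 (client_iv)
let swk = keys.split_off(16); // bytes 16..32 (server_write_key)
let cwk = keys;               // bytes  0..16 (client_write_key)
assert_eq!((cwk.len(), swk.len(), civ.len(), siv.len()), (16, 16, 4, 4));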
261
crates/components/hmac-sha256/src/test_utils.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use crate::{sha256, state_to_bytes};
|
||||
use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
|
||||
use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
|
||||
use mpz_vm_core::memory::correlated::Delta;
|
||||
use rand::{rngs::StdRng, Rng, SeedableRng};
|
||||
|
||||
pub(crate) const SHA256_IV: [u32; 8] = [
|
||||
0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
|
||||
];
|
||||
|
||||
pub(crate) fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
|
||||
let mut rng = StdRng::seed_from_u64(0);
|
||||
let delta = Delta::random(&mut rng);
|
||||
|
||||
let (cot_send, cot_recv) = ideal_cot(delta.into_inner());
|
||||
|
||||
let gen = Garbler::new(cot_send, [0u8; 16], delta);
|
||||
let ev = Evaluator::new(cot_recv);
|
||||
|
||||
(gen, ev)
|
||||
}
|
||||
|
||||
pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] {
    let mut label_start_seed = b"master secret".to_vec();
    label_start_seed.extend_from_slice(&client_random);
    label_start_seed.extend_from_slice(&server_random);

    let ms = phash(pms.to_vec(), &label_start_seed, 2)[..48].to_vec();

    ms.try_into().unwrap()
}

pub(crate) fn prf_keys(
    ms: [u8; 48],
    client_random: [u8; 32],
    server_random: [u8; 32],
) -> [Vec<u8>; 4] {
    let mut label_start_seed = b"key expansion".to_vec();
    label_start_seed.extend_from_slice(&server_random);
    label_start_seed.extend_from_slice(&client_random);

    let mut session_keys = phash(ms.to_vec(), &label_start_seed, 2)[..40].to_vec();

    let server_iv = session_keys.split_off(36);
    let client_iv = session_keys.split_off(32);
    let server_write_key = session_keys.split_off(16);
    let client_write_key = session_keys;

    [client_write_key, server_write_key, client_iv, server_iv]
}

pub(crate) fn prf_cf_vd(ms: [u8; 48], handshake_hash: [u8; 32]) -> Vec<u8> {
    let mut label_start_seed = b"client finished".to_vec();
    label_start_seed.extend_from_slice(&handshake_hash);

    phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec()
}

pub(crate) fn prf_sf_vd(ms: [u8; 48], handshake_hash: [u8; 32]) -> Vec<u8> {
    let mut label_start_seed = b"server finished".to_vec();
    label_start_seed.extend_from_slice(&handshake_hash);

    phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec()
}
pub(crate) fn phash(key: Vec<u8>, seed: &[u8], iterations: usize) -> Vec<u8> {
    // A() is defined as:
    //
    // A(0) = seed
    // A(i) = HMAC_hash(secret, A(i-1))
    let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1);
    a_cache.push(seed.to_vec());

    for i in 0..iterations {
        let a_i = hmac_sha256(key.clone(), &a_cache[i]);
        a_cache.push(a_i.to_vec());
    }

    // HMAC_hash(secret, A(i) + seed)
    let mut output: Vec<_> = Vec::with_capacity(iterations * 32);
    for i in 0..iterations {
        let mut a_i_seed = a_cache[i + 1].clone();
        a_i_seed.extend_from_slice(seed);

        let hash = hmac_sha256(key.clone(), &a_i_seed);
        output.extend_from_slice(&hash);
    }

    output
}
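For cross-checking outside this crate, the same TLS 1.2 P_hash expansion can be written against the widely used hmac and sha2 crates. This is a hedged standalone sketch assuming those crates as dev-dependencies (they are not part of this diff), mirroring the phash utility above:

use hmac::{Hmac, Mac};
use sha2::Sha256;

// Standalone P_hash(secret, seed) producing `iterations * 32` bytes.
fn p_hash(secret: &[u8], seed: &[u8], iterations: usize) -> Vec<u8> {
    let hmac = |msg: &[u8]| -> Vec<u8> {
        let mut mac = Hmac::<Sha256>::new_from_slice(secret).expect("HMAC accepts any key size");
        mac.update(msg);
        mac.finalize().into_bytes().to_vec()
    };

    let mut output = Vec::with_capacity(iterations * 32);
    let mut a = hmac(seed); // A(1)
    for _ in 0..iterations {
        // P_hash[i] = HMAC(secret, A(i) + seed)
        let mut a_seed = a.clone();
        a_seed.extend_from_slice(seed);
        output.extend_from_slice(&hmac(&a_seed));

        a = hmac(&a); // A(i+1) = HMAC(secret, A(i))
    }
    output
}

fn main() {
    let out = p_hash(b"secret", b"label and seed", 2);
    assert_eq!(out.len(), 64);
}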
pub(crate) fn hmac_sha256(key: Vec<u8>, msg: &[u8]) -> [u8; 32] {
    let outer_partial = compute_outer_partial(key.clone());
    let inner_local = compute_inner_local(key, msg);

    let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local));
    state_to_bytes(hmac)
}

pub(crate) fn compute_outer_partial(mut key: Vec<u8>) -> [u32; 8] {
    assert!(key.len() <= 64);

    key.resize(64, 0_u8);
    let key_padded: [u8; 64] = key
        .into_iter()
        .map(|b| b ^ 0x5c)
        .collect::<Vec<u8>>()
        .try_into()
        .unwrap();

    compress_256(SHA256_IV, &key_padded)
}

pub(crate) fn compute_inner_local(mut key: Vec<u8>, msg: &[u8]) -> [u32; 8] {
    assert!(key.len() <= 64);

    key.resize(64, 0_u8);
    let key_padded: [u8; 64] = key
        .into_iter()
        .map(|b| b ^ 0x36)
        .collect::<Vec<u8>>()
        .try_into()
        .unwrap();

    let state = compress_256(SHA256_IV, &key_padded);
    sha256(state, 64, msg)
}

pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] {
    use sha2::{
        compress256,
        digest::{
            block_buffer::{BlockBuffer, Eager},
            generic_array::typenum::U64,
        },
    };

    let mut buffer = BlockBuffer::<U64, Eager>::default();
    buffer.digest_blocks(msg, |b| compress256(&mut state, b));
    state
}
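compute_outer_partial and compute_inner_local precompute the two keyed compression states of HMAC, i.e. HMAC(k, m) = H((k' XOR opad) || H((k' XOR ipad) || m)) with the key zero-padded to one block. A hedged standalone check of that identity against the hmac and sha2 crates (again assumed as dev-dependencies, not part of this diff):

use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

fn main() {
    let key = [7u8; 32];
    let msg = b"hello";

    // Direct HMAC via the hmac crate.
    let mut mac = Hmac::<Sha256>::new_from_slice(&key).unwrap();
    mac.update(msg);
    let expected = mac.finalize().into_bytes();

    // Manual construction: H((K' ^ opad) || H((K' ^ ipad) || msg)).
    let mut k = key.to_vec();
    k.resize(64, 0); // zero-pad the key to the SHA-256 block size
    let ipad: Vec<u8> = k.iter().map(|b| b ^ 0x36).collect();
    let opad: Vec<u8> = k.iter().map(|b| b ^ 0x5c).collect();

    let inner = Sha256::new().chain_update(&ipad).chain_update(msg).finalize();
    let manual = Sha256::new().chain_update(&opad).chain_update(inner).finalize();

    assert_eq!(manual, expected);
}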
// Borrowed from Rustls for testing
// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs
mod ring_prf {
    use ring::{hmac, hmac::HMAC_SHA256};

    fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag {
        let mut ctx = hmac::Context::with_key(key);
        ctx.update(a);
        ctx.update(b);
        ctx.sign()
    }

    fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) {
        let hmac_key = hmac::Key::new(HMAC_SHA256, secret);

        // A(1)
        let mut current_a = hmac::sign(&hmac_key, seed);
        let chunk_size = HMAC_SHA256.digest_algorithm().output_len();
        for chunk in out.chunks_mut(chunk_size) {
            // P_hash[i] = HMAC_hash(secret, A(i) + seed)
            let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed);
            chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]);

            // A(i+1) = HMAC_hash(secret, A(i))
            current_a = hmac::sign(&hmac_key, current_a.as_ref());
        }
    }

    fn concat(a: &[u8], b: &[u8]) -> Vec<u8> {
        let mut ret = Vec::new();
        ret.extend_from_slice(a);
        ret.extend_from_slice(b);
        ret
    }

    pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) {
        let joined_seed = concat(label, seed);
        p(out, secret, &joined_seed);
    }
}
#[test]
fn test_prf_reference_ms() {
    use ring_prf::prf as prf_ref;

    let mut rng = StdRng::from_seed([1; 32]);

    let pms: [u8; 32] = rng.random();
    let label: &[u8] = b"master secret";
    let client_random: [u8; 32] = rng.random();
    let server_random: [u8; 32] = rng.random();
    let mut seed = Vec::from(client_random);
    seed.extend_from_slice(&server_random);

    let ms = prf_ms(pms, client_random, server_random);

    let mut expected_ms: [u8; 48] = [0; 48];
    prf_ref(&mut expected_ms, &pms, label, &seed);

    assert_eq!(ms, expected_ms);
}

#[test]
fn test_prf_reference_ke() {
    use ring_prf::prf as prf_ref;

    let mut rng = StdRng::from_seed([2; 32]);

    let ms: [u8; 48] = rng.random();
    let label: &[u8] = b"key expansion";
    let client_random: [u8; 32] = rng.random();
    let server_random: [u8; 32] = rng.random();
    let mut seed = Vec::from(server_random);
    seed.extend_from_slice(&client_random);

    let keys = prf_keys(ms, client_random, server_random);
    let keys: Vec<u8> = keys.into_iter().flatten().collect();

    let mut expected_keys: [u8; 40] = [0; 40];
    prf_ref(&mut expected_keys, &ms, label, &seed);

    assert_eq!(keys, expected_keys);
}

#[test]
fn test_prf_reference_cf() {
    use ring_prf::prf as prf_ref;

    let mut rng = StdRng::from_seed([3; 32]);

    let ms: [u8; 48] = rng.random();
    let label: &[u8] = b"client finished";
    let handshake_hash: [u8; 32] = rng.random();

    let cf_vd = prf_cf_vd(ms, handshake_hash);

    let mut expected_cf_vd: [u8; 12] = [0; 12];
    prf_ref(&mut expected_cf_vd, &ms, label, &handshake_hash);

    assert_eq!(cf_vd, expected_cf_vd);
}

#[test]
fn test_prf_reference_sf() {
    use ring_prf::prf as prf_ref;

    let mut rng = StdRng::from_seed([4; 32]);

    let ms: [u8; 48] = rng.random();
    let label: &[u8] = b"server finished";
    let handshake_hash: [u8; 32] = rng.random();

    let sf_vd = prf_sf_vd(ms, handshake_hash);

    let mut expected_sf_vd: [u8; 12] = [0; 12];
    prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash);

    assert_eq!(sf_vd, expected_sf_vd);
}
@@ -5,9 +5,12 @@ description = "Implementation of the 3-party key-exchange protocol"
 keywords = ["tls", "mpc", "2pc", "pms", "key-exchange"]
 categories = ["cryptography"]
 license = "MIT OR Apache-2.0"
-version = "0.1.0-alpha.9"
+version = "0.1.0-alpha.12"
 edition = "2021"

+[lints]
+workspace = true
+
 [lib]
 name = "key_exchange"
@@ -1,14 +1,8 @@
 //! This module provides the circuits used in the key exchange protocol.

-use mpz_circuits::{circuits::big_num::nbyte_add_mod_trace, Circuit, CircuitBuilder};
+use mpz_circuits::{ops::add_mod, Circuit, CircuitBuilder, Feed, Node};
 use std::sync::Arc;

-/// NIST P-256 prime big-endian.
-static P: [u8; 32] = [
-    0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
-    0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
-];
-
 /// Circuit for combining additive shares of the PMS, twice
 ///
 /// # Inputs
@@ -17,26 +11,65 @@ static P: [u8; 32] = [
 /// 1. PMS_SHARE_B0: 32 bytes PMS Additive Share
 /// 2. PMS_SHARE_A1: 32 bytes PMS Additive Share
 /// 3. PMS_SHARE_B1: 32 bytes PMS Additive Share
+/// 4. MODULUS: 32 bytes field modulus
 ///
 /// # Outputs
 /// 0. PMS_0: Pre-master Secret = PMS_SHARE_A0 + PMS_SHARE_B0
 /// 1. PMS_1: Pre-master Secret = PMS_SHARE_A1 + PMS_SHARE_B1
 /// 2. EQ: Equality check of PMS_0 and PMS_1
 pub(crate) fn build_pms_circuit() -> Arc<Circuit> {
-    let builder = CircuitBuilder::new();
-    let share_a0 = builder.add_array_input::<u8, 32>();
-    let share_b0 = builder.add_array_input::<u8, 32>();
-    let share_a1 = builder.add_array_input::<u8, 32>();
-    let share_b1 = builder.add_array_input::<u8, 32>();
+    let mut builder = CircuitBuilder::new();

-    let pms_0 = nbyte_add_mod_trace(builder.state(), share_a0, share_b0, P);
-    let pms_1 = nbyte_add_mod_trace(builder.state(), share_a1, share_b1, P);
+    let share_a0 = (0..32 * 8).map(|_| builder.add_input()).collect::<Vec<_>>();
+    let share_b0 = (0..32 * 8).map(|_| builder.add_input()).collect::<Vec<_>>();
+    let share_a1 = (0..32 * 8).map(|_| builder.add_input()).collect::<Vec<_>>();
+    let share_b1 = (0..32 * 8).map(|_| builder.add_input()).collect::<Vec<_>>();

-    let eq: [_; 32] = std::array::from_fn(|i| pms_0[i] ^ pms_1[i]);
+    let modulus = (0..32 * 8).map(|_| builder.add_input()).collect::<Vec<_>>();

-    builder.add_output(pms_0);
-    builder.add_output(pms_1);
-    builder.add_output(eq);
+    /// assumes input is provided as big endian
+    fn to_little_endian(input: &[Node<Feed>]) -> Vec<Node<Feed>> {
+        let mut le_lsb0_output = vec![];
+        for node in input.chunks_exact(8).rev() {
+            for &bit in node.iter() {
+                le_lsb0_output.push(bit);
+            }
+        }
+        le_lsb0_output
+    }
+
+    let pms_0 = add_mod(
+        &mut builder,
+        &to_little_endian(&share_a0),
+        &to_little_endian(&share_b0),
+        &to_little_endian(&modulus),
+    );
+
+    // return output as big endian
+    for node in pms_0.chunks_exact(8).rev() {
+        for &bit in node.iter() {
+            builder.add_output(bit);
+        }
+    }
+
+    let pms_1 = add_mod(
+        &mut builder,
+        &to_little_endian(&share_a1),
+        &to_little_endian(&share_b1),
+        &to_little_endian(&modulus),
+    );
+
+    // return output as big endian
+    for node in pms_1.chunks_exact(8).rev() {
+        for &bit in node.iter() {
+            builder.add_output(bit);
+        }
+    }
+
+    for (a, b) in pms_0.into_iter().zip(pms_1) {
+        let out = builder.add_xor_gate(a, b);
+        builder.add_output(out);
+    }

     Arc::new(builder.build().expect("pms circuit is valid"))
 }
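The circuit computes (a + b) mod p for both sharings and exposes the bitwise XOR of the two results, so the EQ output is all-zero exactly when both reconstructions of the pre-master secret agree. A hedged plain-Rust sketch of the arithmetic being enforced, using the num-bigint crate (an assumption for illustration, not a dependency of this diff):

use num_bigint::BigUint;

// NIST P-256 prime, big-endian (same constant as in this diff).
const P: [u8; 32] = [
    0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
];

fn main() {
    let p = BigUint::from_bytes_be(&P);

    // Two additive sharings of the same secret x: x = a0 + b0 = a1 + b1 (mod p).
    let x = BigUint::from(123456789u64);
    let a0 = BigUint::from(11111u64);
    let b0 = (&x + &p - &a0) % &p;
    let a1 = BigUint::from(22222u64);
    let b1 = (&x + &p - &a1) % &p;

    // What the circuit computes for each sharing.
    let pms_0 = (&a0 + &b0) % &p;
    let pms_1 = (&a1 + &b1) % &p;

    // EQ = pms_0 XOR pms_1 is zero iff both reconstructions match.
    assert_eq!(pms_0, pms_1);
    assert_eq!(pms_0, x);
}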
@@ -9,7 +9,7 @@ use serio::{sink::SinkExt, stream::IoStreamExt};
 use tokio::sync::Mutex;
 use tracing::instrument;

-use mpz_common::{scoped_futures::ScopedFutureExt, Context, Flush};
+use mpz_common::{Context, Flush};
 use mpz_core::bitvec::BitVec;
 use mpz_fields::{p256::P256, Field};
 use mpz_memory_core::{
@@ -24,6 +24,12 @@ use crate::{
     KeyExchangeError, Pms, Role,
 };

+/// NIST P-256 prime big-endian.
+static P: [u8; 32] = [
+    0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+    0x00, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
+];
+
 #[derive(Debug)]
 enum State {
     Initialized,
@@ -164,12 +170,18 @@ where
             }
         };

+        let p_constant: Array<U8, 32> = vm.alloc().map_err(KeyExchangeError::vm)?;
+        vm.mark_public(p_constant).map_err(KeyExchangeError::vm)?;
+        vm.assign(p_constant, P).map_err(KeyExchangeError::vm)?;
+        vm.commit(p_constant).map_err(KeyExchangeError::vm)?;
+
         let pms_circuit = build_pms_circuit();
         let pms_call = CallBuilder::new(pms_circuit)
             .arg(share_a0)
             .arg(share_b0)
             .arg(share_a1)
             .arg(share_b1)
+            .arg(p_constant)
             .build()
             .map_err(KeyExchangeError::vm)?;

@@ -240,35 +252,26 @@ where

         let (follower_key, _, _) = ctx
             .try_join3(
-                move |ctx| {
-                    async move {
-                        Ok(match role {
-                            Role::Leader => ctx.io_mut().expect_next().await?,
-                            Role::Follower => {
-                                ctx.io_mut().send(public_key).await?;
-                                public_key
-                            }
-                        })
-                    }
-                    .scope_boxed()
+                async move |ctx| {
+                    Ok(match role {
+                        Role::Leader => ctx.io_mut().expect_next().await?,
+                        Role::Follower => {
+                            ctx.io_mut().send(public_key).await?;
+                            public_key
+                        }
+                    })
                 },
-                move |ctx| {
-                    async move {
-                        converter_0
-                            .flush(ctx)
-                            .await
-                            .map_err(KeyExchangeError::share_conversion)
-                    }
-                    .scope_boxed()
+                async move |ctx| {
+                    converter_0
+                        .flush(ctx)
+                        .await
+                        .map_err(KeyExchangeError::share_conversion)
                 },
-                move |ctx| {
-                    async move {
-                        converter_1
-                            .flush(ctx)
-                            .await
-                            .map_err(KeyExchangeError::share_conversion)
-                    }
-                    .scope_boxed()
+                async move |ctx| {
+                    converter_1
+                        .flush(ctx)
+                        .await
+                        .map_err(KeyExchangeError::share_conversion)
                 },
             )
             .await??;
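The refactor above drops the manually boxed scoped futures (`.scope_boxed()`) in favor of Rust's native async closures, stable since Rust 1.85. A minimal sketch of the pattern with a hypothetical combinator `run_pair` (not the mpz_common API) and tokio assumed as a dependency:

// A combinator that accepts async closures directly, in the style of the
// try_join3 call above.
async fn run_pair<F1, F2>(f1: F1, f2: F2) -> (u64, u64)
where
    F1: AsyncFnOnce(u64) -> u64,
    F2: AsyncFnOnce(u64) -> u64,
{
    (f1(1).await, f2(2).await)
}

#[tokio::main]
async fn main() {
    // Before: move |x| async move { x + 1 }.scope_boxed()
    // After: the closure itself is async.
    let (a, b) = run_pair(async move |x| x + 1, async move |x| x * 2).await;
    assert_eq!((a, b), (2, 4));
}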
@@ -297,7 +300,7 @@ where
         } = self.state.take()
         else {
             return Err(KeyExchangeError::state(
-                "can not compute shares before performing setup",
+                "cannot compute shares before performing setup",
             ));
         };

@@ -437,13 +440,11 @@ where
     let mut converter_1 = converter_1.try_lock_owned().unwrap();
     let (pms_share_0, pms_share_1) = ctx
         .try_join(
-            move |ctx| {
-                async move { derive_x_coord_share(ctx, role, &mut *converter_0, encoded_point).await }
-                    .scope_boxed()
+            async move |ctx| {
+                derive_x_coord_share(ctx, role, &mut *converter_0, encoded_point).await
             },
-            move |ctx| {
-                async move { derive_x_coord_share(ctx, role, &mut *converter_1, encoded_point).await }
-                    .scope_boxed()
+            async move |ctx| {
+                derive_x_coord_share(ctx, role, &mut *converter_1, encoded_point).await
             },
         )
         .await??;
@@ -458,7 +459,7 @@ mod tests {
     use mpz_common::context::test_st_context;
     use mpz_core::Block;
     use mpz_fields::UniformRand;
-    use mpz_garble::protocol::semihonest::{Evaluator, Generator};
+    use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
     use mpz_memory_core::correlated::Delta;
     use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender};
     use mpz_share_conversion::ideal::{
@@ -487,7 +488,10 @@ mod tests {

         let leader_private_key = SecretKey::random(&mut rng);
         let follower_private_key = SecretKey::random(&mut rng);
-        let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng));
+
+        let server_secret_key = &NonZeroScalar::random(&mut rng);
+        let server_public_key = PublicKey::from_secret_scalar(server_secret_key);
+
         let expected_client_public_key = PublicKey::from_affine(
             (leader_private_key.public_key().to_projective()
                 + follower_private_key.public_key().to_projective())
@@ -540,7 +544,12 @@ mod tests {
             }
         );

+        let expected_ecdh_x =
+            p256::ecdh::diffie_hellman(server_secret_key, client_public_key.as_affine());
+        let expected_ecdh_x = expected_ecdh_x.raw_secret_bytes().to_vec();
+
         assert_eq!(leader_pms, follower_pms);
+        assert_eq!(leader_pms.to_vec(), expected_ecdh_x);
     }

     #[tokio::test]
@@ -614,13 +623,14 @@ mod tests {
     #[case::malicious_follower(Malicious::Follower)]
     #[tokio::test]
     async fn test_malicious_key_exchange(#[case] malicious: Malicious) {
-        let mut rng = StdRng::seed_from_u64(0).compat();
+        let mut rng = StdRng::seed_from_u64(0);
         let (mut ctx_a, mut ctx_b) = test_st_context(8);
         let (mut gen, mut ev) = mock_vm();

-        let leader_private_key = SecretKey::random(&mut rng);
-        let follower_private_key = SecretKey::random(&mut rng);
-        let server_public_key = PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng));
+        let leader_private_key = SecretKey::random(&mut rng.compat_by_ref());
+        let follower_private_key = SecretKey::random(&mut rng.compat_by_ref());
+        let server_public_key =
+            PublicKey::from_secret_scalar(&NonZeroScalar::random(&mut rng.compat_by_ref()));
         let expected_client_public_key = PublicKey::from_affine(
             (leader_private_key.public_key().to_projective()
                 + follower_private_key.public_key().to_projective())
@@ -705,6 +715,12 @@ mod tests {
         let (res_gen, res_ev) = tokio::join!(
             async move {
                 let mut vm = gen;

+                let p_constant: Array<U8, 32> = vm.alloc().unwrap();
+                vm.mark_public(p_constant).unwrap();
+                vm.assign(p_constant, P).unwrap();
+                vm.commit(p_constant).unwrap();
+
                 let share_a0: Array<U8, 32> = vm.alloc().unwrap();
                 vm.mark_private(share_a0).unwrap();

@@ -723,6 +739,7 @@ mod tests {
                     .arg(share_b0)
                     .arg(share_a1)
                     .arg(share_b1)
+                    .arg(p_constant)
                     .build()
                     .unwrap();

@@ -747,6 +764,11 @@ mod tests {
             },
             async {
                 let mut vm = ev;
+                let p_constant: Array<U8, 32> = vm.alloc().unwrap();
+                vm.mark_public(p_constant).unwrap();
+                vm.assign(p_constant, P).unwrap();
+                vm.commit(p_constant).unwrap();

                 let share_a0: Array<U8, 32> = vm.alloc().unwrap();
                 vm.mark_blind(share_a0).unwrap();

@@ -765,6 +787,7 @@ mod tests {
                     .arg(share_b0)
                     .arg(share_a1)
                     .arg(share_b1)
+                    .arg(p_constant)
                     .build()
                     .unwrap();

@@ -812,13 +835,13 @@ mod tests {
         (leader, follower)
     }

-    fn mock_vm() -> (Generator<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
-        let mut rng = StdRng::seed_from_u64(0).compat();
+    fn mock_vm() -> (Garbler<IdealCOTSender>, Evaluator<IdealCOTReceiver>) {
+        let mut rng = StdRng::seed_from_u64(0);
         let delta = Delta::random(&mut rng);

         let (cot_send, cot_recv) = ideal_cot(delta.into_inner());

-        let gen = Generator::new(cot_send, [0u8; 16], delta);
+        let gen = Garbler::new(cot_send, [0u8; 16], delta);
         let ev = Evaluator::new(cot_recv);

         (gen, ev)
@@ -26,7 +26,7 @@ pub fn create_mock_key_exchange_pair() -> (MockKeyExchange, MockKeyExchange) {

 #[cfg(test)]
 mod tests {
-    use mpz_garble::protocol::semihonest::{Evaluator, Generator};
+    use mpz_garble::protocol::semihonest::{Evaluator, Garbler};
     use mpz_ot::ideal::cot::{IdealCOTReceiver, IdealCOTSender};

     use super::*;
@@ -40,7 +40,7 @@ mod tests {

     is_key_exchange::<
         MpcKeyExchange<IdealShareConvertSender<P256>, IdealShareConvertReceiver<P256>>,
-        Generator<IdealCOTSender>,
+        Garbler<IdealCOTSender>,
     >(leader);

     is_key_exchange::<
@@ -5,9 +5,12 @@ description = "Core types for TLSNotary"
|
||||
keywords = ["tls", "mpc", "2pc", "types"]
|
||||
categories = ["cryptography"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version = "0.1.0-alpha.9"
|
||||
version = "0.1.0-alpha.12"
|
||||
edition = "2021"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[features]
|
||||
default = []
|
||||
fixtures = ["dep:hex", "dep:tlsn-data-fixtures"]
|
||||
|
||||
@@ -13,9 +13,26 @@
 //! The body contains the fields of the attestation. These fields include data
 //! which can be used to verify aspects of a TLS connection, such as the
 //! server's identity, and facts about the transcript.
+//!
+//! # Extensions
+//!
+//! An attestation may be extended using [`Extension`] fields included in the
+//! body. Extensions (currently) have no canonical semantics, but may be used to
+//! implement application specific functionality.
+//!
+//! A Prover may [append
+//! extensions](crate::request::RequestConfigBuilder::extension)
+//! to their attestation request, provided that the Notary supports them
+//! (disallowed by default). A Notary may also be configured to
+//! [validate](crate::attestation::AttestationConfigBuilder::extension_validator)
+//! any extensions requested by a Prover using custom application logic.
+//! Additionally, a Notary may
+//! [include](crate::attestation::AttestationBuilder::extension)
+//! their own extensions.

 mod builder;
 mod config;
+mod extension;
 mod proof;

 use std::fmt;
@@ -26,16 +43,16 @@ use serde::{Deserialize, Serialize};
 use crate::{
     connection::{ConnectionInfo, ServerCertCommitment, ServerEphemKey},
     hash::{impl_domain_separator, Hash, HashAlgorithm, HashAlgorithmExt, TypedHash},
-    index::Index,
     merkle::MerkleTree,
     presentation::PresentationBuilder,
     signing::{Signature, VerifyingKey},
-    transcript::{encoding::EncodingCommitment, hash::PlaintextHash},
+    transcript::TranscriptCommitment,
     CryptoProvider,
 };

 pub use builder::{AttestationBuilder, AttestationBuilderError};
 pub use config::{AttestationConfig, AttestationConfigBuilder, AttestationConfigError};
+pub use extension::{Extension, InvalidExtension};
 pub use proof::{AttestationError, AttestationProof};

 /// Current version of attestations.
@@ -133,11 +150,16 @@ pub struct Body {
     connection_info: Field<ConnectionInfo>,
     server_ephemeral_key: Field<ServerEphemKey>,
     cert_commitment: Field<ServerCertCommitment>,
-    encoding_commitment: Option<Field<EncodingCommitment>>,
-    plaintext_hashes: Index<Field<PlaintextHash>>,
+    extensions: Vec<Field<Extension>>,
+    transcript_commitments: Vec<Field<TranscriptCommitment>>,
 }

 impl Body {
+    /// Returns an iterator over the extensions.
+    pub fn extensions(&self) -> impl Iterator<Item = &Extension> {
+        self.extensions.iter().map(|field| &field.data)
+    }
+
     /// Returns the attestation verifying key.
     pub fn verifying_key(&self) -> &VerifyingKey {
         &self.verifying_key.data
@@ -173,8 +195,8 @@ impl Body {
             connection_info: conn_info,
             server_ephemeral_key,
             cert_commitment,
-            encoding_commitment,
-            plaintext_hashes,
+            extensions,
+            transcript_commitments,
         } = self;

         let mut fields: Vec<(FieldId, Hash)> = vec![
@@ -190,14 +212,11 @@ impl Body {
             ),
         ];

-        if let Some(encoding_commitment) = encoding_commitment {
-            fields.push((
-                encoding_commitment.id,
-                hasher.hash_separated(&encoding_commitment.data),
-            ));
+        for field in extensions.iter() {
+            fields.push((field.id, hasher.hash_separated(&field.data)));
         }

-        for field in plaintext_hashes.iter() {
+        for field in transcript_commitments.iter() {
             fields.push((field.id, hasher.hash_separated(&field.data)));
         }

@@ -220,14 +239,9 @@ impl Body {
         &self.cert_commitment.data
     }

-    /// Returns the encoding commitment.
-    pub(crate) fn encoding_commitment(&self) -> Option<&EncodingCommitment> {
-        self.encoding_commitment.as_ref().map(|field| &field.data)
-    }
-
-    /// Returns the plaintext hash commitments.
-    pub(crate) fn plaintext_hashes(&self) -> &Index<Field<PlaintextHash>> {
-        &self.plaintext_hashes
+    /// Returns the transcript commitments.
+    pub(crate) fn transcript_commitments(&self) -> impl Iterator<Item = &TranscriptCommitment> {
+        self.transcript_commitments.iter().map(|field| &field.data)
     }
 }
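The module docs above describe the extension flow: a Prover attaches extensions to its request, and the Notary either rejects them (the default) or validates them with application logic. A self-contained sketch of that validation shape, with `Extension` and `InvalidExtension` mirrored locally so it runs without the crate (the real types live in tlsn-core, per this diff):

// Local stand-ins mirroring the types introduced in this diff.
struct Extension {
    id: Vec<u8>,
    value: Vec<u8>,
}

struct InvalidExtension(String);

impl InvalidExtension {
    fn new(msg: &str) -> Self {
        Self(msg.to_string())
    }
}

fn main() {
    // Same signature as the Notary's extension validator closure.
    let validator = |extensions: &[Extension]| -> Result<(), InvalidExtension> {
        for ext in extensions {
            if ext.id != b"example.type" || ext.value.is_empty() {
                return Err(InvalidExtension::new("unsupported extension"));
            }
        }
        Ok(())
    };

    let ok = vec![Extension {
        id: b"example.type".to_vec(),
        value: b"hello".to_vec(),
    }];
    assert!(validator(&ok).is_ok());

    let bad = vec![Extension {
        id: b"foo".to_vec(),
        value: b"bar".to_vec(),
    }];
    assert!(validator(&bad).is_err());
}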
@@ -4,32 +4,35 @@ use rand::{rng, Rng};

 use crate::{
     attestation::{
-        Attestation, AttestationConfig, Body, EncodingCommitment, FieldId, FieldKind, Header,
-        ServerCertCommitment, VERSION,
+        Attestation, AttestationConfig, Body, Extension, FieldId, Header, ServerCertCommitment,
+        VERSION,
     },
     connection::{ConnectionInfo, ServerEphemKey},
-    hash::{HashAlgId, TypedHash},
+    hash::HashAlgId,
     request::Request,
     serialize::CanonicalSerialize,
     signing::SignatureAlgId,
-    transcript::encoding::EncoderSecret,
+    transcript::TranscriptCommitment,
     CryptoProvider,
 };

 /// Attestation builder state for accepting a request.
 #[derive(Debug)]
 pub struct Accept {}

 #[derive(Debug)]
 pub struct Sign {
     signature_alg: SignatureAlgId,
     hash_alg: HashAlgId,
     connection_info: Option<ConnectionInfo>,
     server_ephemeral_key: Option<ServerEphemKey>,
     cert_commitment: ServerCertCommitment,
-    encoding_commitment_root: Option<TypedHash>,
-    encoder_secret: Option<EncoderSecret>,
+    extensions: Vec<Extension>,
+    transcript_commitments: Vec<TranscriptCommitment>,
 }

 /// An attestation builder.
 #[derive(Debug)]
 pub struct AttestationBuilder<'a, T = Accept> {
     config: &'a AttestationConfig,
     state: T,
@@ -55,7 +58,7 @@ impl<'a> AttestationBuilder<'a, Accept> {
             signature_alg,
             hash_alg,
             server_cert_commitment: cert_commitment,
-            encoding_commitment_root,
+            extensions,
         } = request;

         if !config.supported_signature_algs().contains(&signature_alg) {
@@ -72,15 +75,9 @@ impl<'a> AttestationBuilder<'a, Accept> {
             ));
         }

-        if encoding_commitment_root.is_some()
-            && !config
-                .supported_fields()
-                .contains(&FieldKind::EncodingCommitment)
-        {
-            return Err(AttestationBuilderError::new(
-                ErrorKind::Request,
-                "encoding commitment is not supported",
-            ));
+        if let Some(validator) = config.extension_validator() {
+            validator(&extensions)
+                .map_err(|err| AttestationBuilderError::new(ErrorKind::Extension, err))?;
         }

         Ok(AttestationBuilder {
@@ -91,8 +88,8 @@ impl<'a> AttestationBuilder<'a, Accept> {
                 connection_info: None,
                 server_ephemeral_key: None,
                 cert_commitment,
-                encoding_commitment_root,
-                encoder_secret: None,
+                transcript_commitments: Vec::new(),
+                extensions,
             },
         })
     }
@@ -111,9 +108,18 @@ impl AttestationBuilder<'_, Sign> {
         self
     }

-    /// Sets the encoder secret.
-    pub fn encoder_secret(&mut self, secret: EncoderSecret) -> &mut Self {
-        self.state.encoder_secret = Some(secret);
+    /// Adds an extension to the attestation.
+    pub fn extension(&mut self, extension: Extension) -> &mut Self {
+        self.state.extensions.push(extension);
+        self
+    }
+
+    /// Sets the transcript commitments.
+    pub fn transcript_commitments(
+        &mut self,
+        transcript_commitments: Vec<TranscriptCommitment>,
+    ) -> &mut Self {
+        self.state.transcript_commitments = transcript_commitments;
         self
     }

@@ -125,8 +131,8 @@ impl AttestationBuilder<'_, Sign> {
             connection_info,
             server_ephemeral_key,
             cert_commitment,
-            encoding_commitment_root,
-            encoder_secret,
+            extensions,
+            transcript_commitments,
         } = self.state;

         let hasher = provider.hash.get(&hash_alg).map_err(|_| {
@@ -144,19 +150,6 @@ impl AttestationBuilder<'_, Sign> {
             )
         })?;

-        let encoding_commitment = if let Some(root) = encoding_commitment_root {
-            let Some(secret) = encoder_secret else {
-                return Err(AttestationBuilderError::new(
-                    ErrorKind::Field,
-                    "encoding commitment requested but encoder_secret was not set",
-                ));
-            };
-
-            Some(EncodingCommitment { root, secret })
-        } else {
-            None
-        };
-
         let mut field_id = FieldId::default();

         let body = Body {
@@ -168,8 +161,14 @@ impl AttestationBuilder<'_, Sign> {
                 AttestationBuilderError::new(ErrorKind::Field, "handshake data was not set")
             })?),
             cert_commitment: field_id.next(cert_commitment),
-            encoding_commitment: encoding_commitment.map(|commitment| field_id.next(commitment)),
-            plaintext_hashes: Default::default(),
+            extensions: extensions
+                .into_iter()
+                .map(|extension| field_id.next(extension))
+                .collect(),
+            transcript_commitments: transcript_commitments
+                .into_iter()
+                .map(|commitment| field_id.next(commitment))
+                .collect(),
         };

         let header = Header {
@@ -203,6 +202,7 @@ enum ErrorKind {
     Config,
     Field,
     Signature,
+    Extension,
 }

 impl AttestationBuilderError {
@@ -229,6 +229,7 @@ impl std::fmt::Display for AttestationBuilderError {
             ErrorKind::Config => f.write_str("config error")?,
             ErrorKind::Field => f.write_str("field error")?,
             ErrorKind::Signature => f.write_str("signature error")?,
+            ErrorKind::Extension => f.write_str("extension error")?,
         }

         if let Some(source) = &self.source {
@@ -246,9 +247,7 @@ mod test {

     use crate::{
         connection::{HandshakeData, HandshakeDataV1_2},
-        fixtures::{
-            encoder_secret, encoding_provider, request_fixture, ConnectionFixture, RequestFixture,
-        },
+        fixtures::{encoding_provider, request_fixture, ConnectionFixture, RequestFixture},
         hash::Blake3,
         transcript::Transcript,
     };
@@ -282,6 +281,7 @@ mod test {
             encoding_provider(GET_WITH_HEADER, OK_JSON),
             connection,
             Blake3::default(),
+            Vec::new(),
         );

         let attestation_config = AttestationConfig::builder()
@@ -306,6 +306,7 @@ mod test {
             encoding_provider(GET_WITH_HEADER, OK_JSON),
             connection,
             Blake3::default(),
+            Vec::new(),
         );

         let attestation_config = AttestationConfig::builder()
@@ -321,35 +322,6 @@ mod test {
         assert!(err.is_request());
     }

-    #[rstest]
-    fn test_attestation_builder_accept_unsupported_encoding_commitment() {
-        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
-        let connection = ConnectionFixture::tlsnotary(transcript.length());
-
-        let RequestFixture { request, .. } = request_fixture(
-            transcript,
-            encoding_provider(GET_WITH_HEADER, OK_JSON),
-            connection,
-            Blake3::default(),
-        );
-
-        let attestation_config = AttestationConfig::builder()
-            .supported_signature_algs([SignatureAlgId::SECP256K1])
-            .supported_fields([
-                FieldKind::ConnectionInfo,
-                FieldKind::ServerEphemKey,
-                FieldKind::ServerIdentityCommitment,
-            ])
-            .build()
-            .unwrap();
-
-        let err = Attestation::builder(&attestation_config)
-            .accept_request(request)
-            .err()
-            .unwrap();
-        assert!(err.is_request());
-    }
-
     #[rstest]
     fn test_attestation_builder_sign_missing_signer(attestation_config: &AttestationConfig) {
         let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
@@ -360,6 +332,7 @@ mod test {
             encoding_provider(GET_WITH_HEADER, OK_JSON),
             connection,
             Blake3::default(),
+            Vec::new(),
         );

         let attestation_builder = Attestation::builder(attestation_config)
@@ -369,48 +342,10 @@ mod test {
         let mut provider = CryptoProvider::default();
         provider.signer.set_secp256r1(&[42u8; 32]).unwrap();

-        let err = attestation_builder.build(&provider).err().unwrap();
+        let err = attestation_builder.build(&provider).unwrap_err();
         assert!(matches!(err.kind, ErrorKind::Config));
     }

-    #[rstest]
-    fn test_attestation_builder_sign_missing_encoding_seed(
-        attestation_config: &AttestationConfig,
-        crypto_provider: &CryptoProvider,
-    ) {
-        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
-        let connection = ConnectionFixture::tlsnotary(transcript.length());
-
-        let RequestFixture { request, .. } = request_fixture(
-            transcript,
-            encoding_provider(GET_WITH_HEADER, OK_JSON),
-            connection.clone(),
-            Blake3::default(),
-        );
-
-        let mut attestation_builder = Attestation::builder(attestation_config)
-            .accept_request(request)
-            .unwrap();
-
-        let ConnectionFixture {
-            connection_info,
-            server_cert_data,
-            ..
-        } = connection;
-
-        let HandshakeData::V1_2(HandshakeDataV1_2 {
-            server_ephemeral_key,
-            ..
-        }) = server_cert_data.handshake;
-
-        attestation_builder
-            .connection_info(connection_info)
-            .server_ephemeral_key(server_ephemeral_key);
-
-        let err = attestation_builder.build(crypto_provider).err().unwrap();
-        assert!(matches!(err.kind, ErrorKind::Field));
-    }
-
     #[rstest]
     fn test_attestation_builder_sign_missing_server_ephemeral_key(
         attestation_config: &AttestationConfig,
@@ -424,6 +359,7 @@ mod test {
             encoding_provider(GET_WITH_HEADER, OK_JSON),
             connection.clone(),
             Blake3::default(),
+            Vec::new(),
         );

         let mut attestation_builder = Attestation::builder(attestation_config)
@@ -434,11 +370,9 @@ mod test {
             connection_info, ..
         } = connection;

-        attestation_builder
-            .connection_info(connection_info)
-            .encoder_secret(encoder_secret());
+        attestation_builder.connection_info(connection_info);

-        let err = attestation_builder.build(crypto_provider).err().unwrap();
+        let err = attestation_builder.build(crypto_provider).unwrap_err();
         assert!(matches!(err.kind, ErrorKind::Field));
     }

@@ -455,6 +389,7 @@ mod test {
             encoding_provider(GET_WITH_HEADER, OK_JSON),
             connection.clone(),
             Blake3::default(),
+            Vec::new(),
         );

         let mut attestation_builder = Attestation::builder(attestation_config)
@@ -470,11 +405,80 @@ mod test {
             ..
         }) = server_cert_data.handshake;

-        attestation_builder
-            .server_ephemeral_key(server_ephemeral_key)
-            .encoder_secret(encoder_secret());
+        attestation_builder.server_ephemeral_key(server_ephemeral_key);

-        let err = attestation_builder.build(crypto_provider).err().unwrap();
+        let err = attestation_builder.build(crypto_provider).unwrap_err();
         assert!(matches!(err.kind, ErrorKind::Field));
     }

+    #[rstest]
+    fn test_attestation_builder_reject_extensions_by_default(
+        attestation_config: &AttestationConfig,
+    ) {
+        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
+        let connection = ConnectionFixture::tlsnotary(transcript.length());
+
+        let RequestFixture { request, .. } = request_fixture(
+            transcript,
+            encoding_provider(GET_WITH_HEADER, OK_JSON),
+            connection.clone(),
+            Blake3::default(),
+            vec![Extension {
+                id: b"foo".to_vec(),
+                value: b"bar".to_vec(),
+            }],
+        );
+
+        let err = Attestation::builder(attestation_config)
+            .accept_request(request)
+            .unwrap_err();
+
+        assert!(matches!(err.kind, ErrorKind::Extension));
+    }
+
+    #[rstest]
+    fn test_attestation_builder_accept_extension(crypto_provider: &CryptoProvider) {
+        let attestation_config = AttestationConfig::builder()
+            .supported_signature_algs([SignatureAlgId::SECP256K1])
+            .extension_validator(|_| Ok(()))
+            .build()
+            .unwrap();
+
+        let transcript = Transcript::new(GET_WITH_HEADER, OK_JSON);
+        let connection = ConnectionFixture::tlsnotary(transcript.length());
+
+        let RequestFixture { request, .. } = request_fixture(
+            transcript,
+            encoding_provider(GET_WITH_HEADER, OK_JSON),
+            connection.clone(),
+            Blake3::default(),
+            vec![Extension {
+                id: b"foo".to_vec(),
+                value: b"bar".to_vec(),
+            }],
+        );
+
+        let mut attestation_builder = Attestation::builder(&attestation_config)
+            .accept_request(request)
+            .unwrap();
+
+        let ConnectionFixture {
+            server_cert_data,
+            connection_info,
+            ..
+        } = connection;
+
+        let HandshakeData::V1_2(HandshakeDataV1_2 {
+            server_ephemeral_key,
+            ..
+        }) = server_cert_data.handshake;
+
+        attestation_builder
+            .connection_info(connection_info)
+            .server_ephemeral_key(server_ephemeral_key);
+
+        let attestation = attestation_builder.build(crypto_provider).unwrap();
+
+        assert_eq!(attestation.body.extensions().count(), 1);
+    }
 }
@@ -1,15 +1,12 @@
 use std::{fmt::Debug, sync::Arc};

 use crate::{
-    attestation::FieldKind,
+    attestation::{Extension, InvalidExtension},
     hash::{HashAlgId, DEFAULT_SUPPORTED_HASH_ALGS},
     signing::SignatureAlgId,
 };

-const DEFAULT_SUPPORTED_FIELDS: &[FieldKind] = &[
-    FieldKind::ConnectionInfo,
-    FieldKind::ServerEphemKey,
-    FieldKind::ServerIdentityCommitment,
-    FieldKind::EncodingCommitment,
-];
+type ExtensionValidator = Arc<dyn Fn(&[Extension]) -> Result<(), InvalidExtension> + Send + Sync>;

 #[derive(Debug)]
 #[allow(dead_code)]
@@ -44,11 +41,11 @@ impl AttestationConfigError {
 }

 /// Attestation configuration.
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct AttestationConfig {
     supported_signature_algs: Vec<SignatureAlgId>,
     supported_hash_algs: Vec<HashAlgId>,
-    supported_fields: Vec<FieldKind>,
+    extension_validator: Option<ExtensionValidator>,
 }

 impl AttestationConfig {
@@ -65,17 +62,25 @@ impl AttestationConfig {
         &self.supported_hash_algs
     }

-    pub(crate) fn supported_fields(&self) -> &[FieldKind] {
-        &self.supported_fields
+    pub(crate) fn extension_validator(&self) -> Option<&ExtensionValidator> {
+        self.extension_validator.as_ref()
     }
 }

+impl Debug for AttestationConfig {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AttestationConfig")
+            .field("supported_signature_algs", &self.supported_signature_algs)
+            .field("supported_hash_algs", &self.supported_hash_algs)
+            .finish_non_exhaustive()
+    }
+}
+
 /// Builder for [`AttestationConfig`].
-#[derive(Debug)]
 pub struct AttestationConfigBuilder {
     supported_signature_algs: Vec<SignatureAlgId>,
     supported_hash_algs: Vec<HashAlgId>,
-    supported_fields: Vec<FieldKind>,
+    extension_validator: Option<ExtensionValidator>,
 }

 impl Default for AttestationConfigBuilder {
@@ -83,7 +88,15 @@ impl Default for AttestationConfigBuilder {
         Self {
             supported_signature_algs: Vec::default(),
             supported_hash_algs: DEFAULT_SUPPORTED_HASH_ALGS.to_vec(),
-            supported_fields: DEFAULT_SUPPORTED_FIELDS.to_vec(),
+            extension_validator: Some(Arc::new(|e| {
+                if !e.is_empty() {
+                    Err(InvalidExtension::new(
+                        "all extensions are disallowed by default",
+                    ))
+                } else {
+                    Ok(())
+                }
+            })),
         }
     }
 }
@@ -107,9 +120,26 @@ impl AttestationConfigBuilder {
         self
     }

-    /// Sets the supported attestation fields.
-    pub fn supported_fields(&mut self, supported_fields: impl Into<Vec<FieldKind>>) -> &mut Self {
-        self.supported_fields = supported_fields.into();
+    /// Sets the extension validator.
+    ///
+    /// # Example
+    /// ```
+    /// # use tlsn_core::attestation::{AttestationConfig, InvalidExtension};
+    /// # let mut builder = AttestationConfig::builder();
+    /// builder.extension_validator(|extensions| {
+    ///     for extension in extensions {
+    ///         if extension.id != b"example.type" {
+    ///             return Err(InvalidExtension::new("invalid extension type"));
+    ///         }
+    ///     }
+    ///     Ok(())
+    /// });
+    /// ```
+    pub fn extension_validator<F>(&mut self, f: F) -> &mut Self
+    where
+        F: Fn(&[Extension]) -> Result<(), InvalidExtension> + Send + Sync + 'static,
+    {
+        self.extension_validator = Some(Arc::new(f));
         self
     }

@@ -118,7 +148,16 @@ impl AttestationConfigBuilder {
         Ok(AttestationConfig {
             supported_signature_algs: self.supported_signature_algs.clone(),
             supported_hash_algs: self.supported_hash_algs.clone(),
-            supported_fields: self.supported_fields.clone(),
+            extension_validator: self.extension_validator.clone(),
         })
     }
 }
+
+impl Debug for AttestationConfigBuilder {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AttestationConfigBuilder")
+            .field("supported_signature_algs", &self.supported_signature_algs)
+            .field("supported_hash_algs", &self.supported_hash_algs)
+            .finish_non_exhaustive()
+    }
+}
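Putting the new config surface together: a hedged usage sketch built only from APIs visible in this diff (builder, supported_signature_algs, extension_validator, build). The import paths are assumptions based on the crate layout shown above:

use tlsn_core::attestation::{AttestationConfig, InvalidExtension};
use tlsn_core::signing::SignatureAlgId;

fn main() {
    // Default config: the built-in validator rejects all extensions.
    let _strict = AttestationConfig::builder()
        .supported_signature_algs([SignatureAlgId::SECP256K1])
        .build()
        .unwrap();

    // Permissive config: accept only the "example.type" extension id.
    let _permissive = AttestationConfig::builder()
        .supported_signature_algs([SignatureAlgId::SECP256K1])
        .extension_validator(|extensions| {
            for extension in extensions {
                if extension.id != b"example.type" {
                    return Err(InvalidExtension::new("invalid extension type"));
                }
            }
            Ok(())
        })
        .build()
        .unwrap();
}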