🎉 Open-source Scatterbrained

This commit is contained in:
Jordan Matelsky
2021-12-14 14:23:36 -05:00
commit bead942a3d
48 changed files with 3488 additions and 0 deletions

5
.flake8 Normal file
View File

@@ -0,0 +1,5 @@
[flake8]
max-line-length = 120
extend-ignore = E203,E501,W503
max-complexity = 12
select = B,C,E,F,W,B9

427
.gitignore vendored Normal file
View File

@@ -0,0 +1,427 @@
# Created by https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm
### Emacs ###
# -*- mode: gitignore; -*-
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*
# Org-mode
.org-id-locations
*_archive
ltximg/**
# flymake-mode
*_flymake.*
# eshell files
/eshell/history
/eshell/lastdir
# elpa packages
/elpa/
# reftex files
*.rel
# AUCTeX auto folder
/auto/
# cask packages
.cask/
dist/
# Flycheck
flycheck_*.el
# server auth directory
/server/
# projectiles files
.projectile
# directory configuration
.dir-locals.el
# network security
/network-security.data
### Linux ###
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr
# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/
# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml
# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/
# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$
# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
pytestdebug.log
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
doc/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pythonenv*
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# profiling data
.prof
### SublimeText ###
# Cache files for Sublime Text
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache
# Workspace files are user-specific
*.sublime-workspace
# Project files should be checked into the repository, unless a significant
# proportion of contributors will probably not be using Sublime Text
# *.sublime-project
# SFTP configuration file
sftp-config.json
# Package control specific files
Package Control.last-run
Package Control.ca-list
Package Control.ca-bundle
Package Control.system-ca-bundle
Package Control.cache/
Package Control.ca-certs/
Package Control.merged-ca-bundle
Package Control.user-ca-bundle
oscrypto-ca-bundle.crt
bh_unicode_properties.cache
# Sublime-github package stores a github token in this file
# https://packagecontrol.io/packages/sublime-github
GitHub.sublime-settings
### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]
# Session
Session.vim
Sessionx.vim
# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~
### VisualStudioCode ###
.vscode/
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm

38
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,38 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-json
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/isort
rev: 5.9.2
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 21.7b0
hooks:
- id: black
language_version: python3
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
additional_dependencies: [flake8-bugbear]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.0
hooks:
- id: bandit
args: ["-x", "*/**/*_test.py"]
- repo: local
hooks:
- id: pytest
name: pytest
entry: pytest src/scatterbrained
language: system
pass_filenames: false
types: [python]

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2021 The Johns Hopkins Applied Physics Laboratory
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

62
README.md Normal file
View File

@@ -0,0 +1,62 @@
<h1 align='center'>Scatterbrained</h1>
<p align='center'>Decentralized Federated Learning</p>
<p align='center'>
<a href="https://pypi.org/project/scatterbrained/"><img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/scatterbrained?style=for-the-badge"></a>
<a href="https://github.com/JHUAPL/scatterbrained"><img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/JHUAPL/scatterbrained?style=for-the-badge"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0"><img alt="GitHub" src="https://img.shields.io/github/license/JHUAPL/scatterbrained?style=for-the-badge"></a>
</p>
Scatterbrained makes it easy to build federated learning systems. In addition to traditional federated learning, Scatterbrained supports decentralized federated learning — a new, cooperative type of federated learning where the learning is done by a group of peers instead of by a centralized server. For more information, see our 2021 paper, [_Scatterbrained: A flexible and expandable pattern for decentralized machine learning_](#).
You can use your favorite machine learning frameworks alongside Scatterbrained, such as TensorFlow, scikit-learn, or PyTorch.
## Usage
For examples of how to get started using Scatterbrained, see the [Examples](examples/) directory.
## Installation
You can install Scatterbrained with pip:
```shell
pip install scatterbrained
```
If you would rather download and install from source, you can do so with the following:
```shell
git clone https://github.com/JHUAPL/scatterbrained.git
cd scatterbrained
```
You must first install the dependencies with:
```shell
pip3 install -r ./requirements/requirements.txt
```
And then you can install the package with:
```shell
pip3 install -e .
```
## License
The code in this repository is released under an Apache 2.0 license. For more information, see [LICENSE](LICENSE).
> Copyright 2021 The Johns Hopkins Applied Physics Laboratory
>
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
> You may obtain a copy of the License at
>
> http://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.

44
docs/Installation.md Normal file
View File

@@ -0,0 +1,44 @@
# Installation Guide
Installing Scatterbrained on your development machine can be done either through pip or by downloading the source code and installing it locally.
## Pip-based Installation
```shell
pip install scatterbrained
```
This will automatically install the latest stable version of Scatterbrained, along with its dependencies. For a complete list of dependencies, see the `requirements/` directory in this repository.
## Source-based Installation
You can also clone the git repository and install Scatterbrained from source. In this case, you will also need to manually install the dependencies.
```shell
git clone https://github.com/JHUAPL/scatterbrained.git
cd scatterbrained
# Install dependencies:
pip install -r requirements/requirements.txt
# Install the library in "editable"/development mode:
pip install -e .
```
## Container-based Installation
It is also possible to install and run Scatterbrained in a Docker container. This can be useful for running multiple instances of Scatterbrained on the same machine, for testing or benchmarking.
Docker-based instructions coming soon.
## Installation Troubleshooting
Click the caret next to each question to see the full troubleshooting guide.
<details>
<summary><b>Errors when installing on Ubuntu ≤14</b></summary>
This is likely due to an older version of 0MQ installed with `apt-get`. If you're sure you don't have software that relies upon this older version, then run the following command to remove the old version of 0MQ, and then install the latest version of 0MQ:
sudo apt-get remove libzmq-dev python-zmq
</details>

View File

@@ -0,0 +1,112 @@
# Getting Started with Scatterbrained
This is a quickstart tutorial for the Scatterbrained federated learning library. It will assume that you have already installed the library and have a working, high-level understanding of federated learning.
* For installation instructions, see [Installation](../Installation.md).
* For a refresher on federated learning, see this [Wikipedia article](https://en.wikipedia.org/wiki/Federated_learning).
## Step 1: Building a model
For this example, we're going to build a basic linear regression model using `torch`. (You can also use other machine learning frameworks, such as TensorFlow or sklearn!) We'll load a simple built-in dataset, and train a linear regression model on it.
```python
import sklearn.datasets
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
# Load the data
X, y = sklearn.datasets.load_diabetes(return_X_y=True)
# Generate train/test splits
(X_train, X_val, y_train, y_val) = train_test_split(
X, y, test_size=0.2
)
# Create the data loaders for torch
trainloader = DataLoader(
    TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    ),
    batch_size=32, shuffle=True
)
validloader = DataLoader(
    TensorDataset(
        torch.from_numpy(X_val), torch.from_numpy(y_val)
    ),
    batch_size=32, shuffle=True
)
# Create the model
model = nn.Linear(X_train.shape[1], 1)
create_optimizer = lambda model: optim.AdamW(model.parameters())
```
In the codeblock above, we've created a model just like we would in a traditional ML pipeline. We've also created a function that returns an optimizer. So far, we're not doing anything "FL-flavored": This is just plain old ML.
## Step 2: Training the model
It's now time to introduce the Scatterbrained library. Training a model is still quite simple:
```python
import scatterbrained as sb
NUM_EPOCHS = 10
# Create a new scatterbrained Node to hold all of the logic
# for a federated learning compute node:
async with sb.Node() as node:
# Create a new Namespace. This is a unique identifier
# shared by all nodes in the same FL community. Other
# nodes on the same network that know this name can join
# your FL cluster:
async with node.namespace(
"MyFirstNamespace", model, create_optimizer
) as ns:
# Finally we can ask scatterbrained to perform the
# training loop for us in a background thread:
await ns.train(NUM_EPOCHS, trainloader, validloader)
# At the same time, we can also serve our node's
# resources to other nodes on the network:
await ns.serve()
```
## Step 3: Joining the cluster
From another machine on the same network (or the same machine on a different port), you can join the cluster by running the following code.
Unlike the code blocks above, where we built up a file gradually, this is all the code you need to run from your second machine:
```python
import scatterbrained as sb
async with sb.Node() as node:
async with node.join("MyFirstNamespace") as ns:
await ns.sync_model() # This line is different!
await ns.serve()
```
Wow, that's succinct! Let's take a look at what's happening behind the scenes when we run this code.
First, we create a new scatterbrained Node, just like before. And just like in the first example, we create a new Namespace with the same name, so that the nodes know they're allowed to talk to each other. (A sb.Node can't cooperatively learn with a Node training a different model, so we use the Namespace to indicate to the Node that this networked peer is speaking the same language — i.e., using the same model.)
But here we don't specify any training data or model architecture: Instead, we use the `sync_model` method so that this node can download the model from another peer Node on the network. This means that you can use Scatterbrained to quickly transit a model from one machine to another. (In other words, your models will always be in sync, with a single node serving as the source of truth.)
## Next Steps
In this tutorial, we've covered the basics of using Scatterbrained to train a model. But there are many other features that you can use to tailor your decentralized federated learning code. For example,
* You can emulate traditional, centralized federated learning if you want a single machine to serve as an authority
* You can train a model on multiple machines in parallel with different datasets
* You can specify different optimizers for different Nodes
* You can change your network topology so that nodes can communicate with only certain peers, and will ignore others
* You can design custom network topologies so that information can only flow in certain directions

12
docs/reference/README.md Normal file
View File

@@ -0,0 +1,12 @@
# Reference documentation for the Scatterbrained Library
<!--
Documentation was generated with `docshund`:
```
$ pip install docshund
$ docshund ./src
```
-->

View File

View File

@@ -0,0 +1,74 @@
## *Class* `DiscoveryEngine`
A `DiscoveryEngine` is a manager class for the process of identifying and communicating with other peers in the Scatterbrained network.
The `DiscoveryEngine` is responsible for maintaining a list of peers, and periodically sending heartbeats to the network, which other peers can use to determine if the sending node is still alive.
Identities of peers are stored in `DiscoveryEngine.peers`.
Identities of the current node are stored in `DiscoveryEngine.identities` and can be manually added or removed using `DiscoveryEngine.add_identity` and `DiscoveryEngine.remove_identity`.
## *Function* `heartbeat(self)`
Retrieve the heartbeat interval in seconds.
## *Function* `heartbeat(self, value)`
Set the heartbeat interval in seconds.
### Arguments
> - **value** (`float`: `None`): The heartbeat interval in seconds.
### Returns
None
## *Function* `peers(self)`
Retrieve the list of peers.
### Returns
> - **Set[scatterbrained.discovery.types.Identity]** (`None`: `None`): The list of peers.
## *Function* `add_identity(self, identity: Identity) -> None`
Add an identity to the list of identities for the current node.
### Arguments
> - **identity** (`scatterbrained.discovery.types.Identity`: `None`): The identity to add.
### Returns
None
## *Function* `remove_identity(self, identity: Identity) -> None`
Remove an identity from the list of valid identities for the node.
### Arguments
> - **identity** (`scatterbrained.discovery.types.Identity`: `None`): The identity to remove.
### Returns
None
## *Function* `stop(self)`
Gracefully stop the `DiscoveryEngine`.
### Returns
None

View File

@@ -0,0 +1,46 @@
## *Class* `Publisher(Protocol)`
A Publisher is a class that can publish a message to a set of peers.
Note that this is a protocol and shouldn't be used directly!
## *Function* `publish(self, data: bytes) -> None`
Publish the given payload to a set of peers.
## *Function* `open(self) -> None`
Open the underlying connection mechanism, enabling this instance to send messages.
## *Function* `close(self) -> None`
Close the underlying connection mechanism, stopping this instance from sending messages.
## *Class* `Subscriber(Protocol)`
A Subscriber is a class that can subscribe to messages from a set of peers.
Note that this is a protocol and shouldn't be used directly!
## *Function* `open(self) -> None`
Open the underlying connection mechanism, enabling this instance to receive messages.
## *Function* `close(self) -> None`
Close the underlying connection mechanism, stopping this instance from receiving messages.

View File

View File

View File

View File

View File

View File

View File

View File

View File

@@ -0,0 +1,164 @@
## *Class* `OperatingMode(Enum)`
The mode for the Node to operate in.
> - **Leech** (`None`: `None`): Listen for broadcasts; do not share.
> - **Offline** (`None`: `None`): Do not listen or broadcast.
> - **Peer** (`None`: `None`): Listen and broadcast.
> - **Seeding** (`None`: `None`): Broadcast but do not listen.
## *Class* `Namespace`
A Namespace is a collection of Scatterbrained nodes that share a common model or parameter state space. When connecting or disconnecting from a community, a Scatterbrained node joins a named namespace, or creates a new namespace with a unique name.
Note that this class shouldn't be instantiated directly; rather, instances should be created via the `scatterbrained.Node.namespace` method.
## *Function* `operating_mode(self)`
Get the operating mode for this namespace.
### Returns
> - **scatterbrained.OperatingMode** (`None`: `None`): The operating mode for this namespace.
## *Function* `launch(self)`
Launch the Namespace.
This is usually done by entering the `Namespace` context manager. You should not need to call this method directly.
## *Function* `close(self)`
Close out the Namespace gracefully.
This is usually done by exiting the `Namespace` context manager. You should not need to call this method directly.
## *Function* `connect_to(self, peer: Identity) -> bool`
Connect to a peer in this Namespace by its Identity.
The method will not connect to the given peer if the `OperatingMode` of this Namespace is either `LEECHING` or `OFFLINE`, or the `Identity` is rejected by this instance's `peer_filter`.
### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to connect to.
### Returns
> - **bool** (`None`: `None`): `True` if the connection was successful, `False` otherwise.
## *Function* `disconnect_from(self, peer: Identity) -> bool`
Disconnect from a peer in this `Namespace` by its `Identity`.
### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to disconnect from.
### Returns
> - **bool** (`None`: `None`): `True` if the disconnection was successful, `False` otherwise.
## *Function* `send_to(self, peer: Identity, *payload: bytes) -> None`
Send a byte sequence payload to a peer by its `Identity`.
### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to send to.
> - **payload** (`bytes`: `None`): The payload to send.
### Returns
None
## *Function* `recv(self, timeout: Optional[float] = None) -> Tuple[Identity, Sequence[bytes]]`
Receive a message from the network.
This will block until a message is received, or the timeout is reached.
### Arguments
> - **timeout** (`float`: `None`): The timeout in seconds to wait for a message to be received.
### Returns
> - **Tuple[Identity, Sequence[bytes]]** (`None`: `None`): The sender's Identity and the payload.
## *Class* `Node`
Scatterbrained peer node.
Manages all networking infrastructure, providing an interface for other classes to establish connections, and send and receive data.
## *Function* `listening(self)`
Whether or not the `Node` is listening for incoming connections.
### Returns
> - **bool** (`None`: `None`): Whether or not the Node is listening for incoming connections.
## *Function* `launch(self) -> None`
Launch the `Node`.
This will start the `DiscoveryEngine` and `NetworkEngine` and begin listening for new connections.
Note that this method is usually called automatically when entering an `async with` block and is not meant to be called manually.
### Returns
None
## *Function* `close(self)`
Gracefully close the `Node`.
This will close the `DiscoveryEngine` and `NetworkEngine`.
Note that this method is usually called automatically when exiting an `async with` block and is not meant to be called manually.
### Returns
None
## *Function* `namespace(self, name: str, *args, **kwargs) -> Namespace`
Gets the namespace with the given name, or creates a new one.
Also accepts the same arguments as `scatterbrained.Namespace` if creating a new namespace.
### Arguments
> - **name** (`str`: `None`): The name of the namespace.
### Returns
> - **scatterbrained.Namespace** (`None`: `None`): The namespace with the given name.

View File

@@ -0,0 +1,10 @@
## *Class* `Identity`
Represents the identity of a scatterbrained.Node (either local or remote) for a particular namespace.
### Args
> - **id** (`str`: `None`): The id of the Node.
> - **namespace** (`str`: `None`): The namespace the Node is operating in.
> - **host** (`str`: `None`): The advertised address of this Node.
> - **port** (`int`: `None`): The advertised port of this Node.

3
examples/README.md Normal file
View File

@@ -0,0 +1,3 @@
# Scatterbrained: Usage Examples
This directory contains examples of how to use the Scatterbrained federated learning library. For a quick getting-started guide, see the [Getting Started](../docs/getting-started/README.md) tutorial.

View File

@@ -0,0 +1,63 @@
"""Example: run a DiscoveryEngine locally while faking a second, remote peer.

Two UDP broadcaster/receiver pairs are cross-wired on localhost so the "local"
engine hears the fake remote node's heartbeat and fires its callbacks.
"""
import asyncio
import dataclasses
import json

from loguru import logger

import scatterbrained as sb


async def on_appear(v):
    """Log a peer discovered by the local engine."""
    await asyncio.sleep(0.1)
    logger.info(f"Appear: {v}")


async def on_disappear(v):
    """Log a peer whose heartbeats have gone silent."""
    await asyncio.sleep(0.1)
    logger.info(f"Disappear: {v}")


async def on_error(e):
    """Report an error raised while processing local discovery traffic."""
    await asyncio.sleep(0.1)
    logger.opt(exception=e).error("local error")


async def on_remote_recv(v):
    """Log raw heartbeat payloads seen by the fake remote node."""
    logger.info(f"Remote: {v}")


async def on_remote_error(e):
    """Report an error raised while processing the fake remote's traffic."""
    logger.opt(exception=e).error("remote error")


async def main():
    # NOTE: in a real deployment you'd want everything to use the same port, but because we're running on the
    # same system here, we need to bind to different ports.
    local_pub = sb.discovery.udp.UDPBroadcaster("127.0.0.1", port=9002)
    local_sub = sb.discovery.udp.UDPReceiver("127.0.0.1", port=9001)
    # Fake a remote node.
    remote_pub = sb.discovery.udp.UDPBroadcaster("127.0.0.1", port=9001)
    remote_sub = sb.discovery.udp.UDPReceiver("127.0.0.1", port=9002)
    # FIX: asyncio.wait() no longer accepts bare coroutines (deprecated in 3.8,
    # removed in 3.11); gather() awaits them all concurrently instead.
    await asyncio.gather(local_pub.open(), local_sub.open(), remote_pub.open(), remote_sub.open())
    engine = sb.discovery.DiscoveryEngine(
        local_pub,
        local_sub,
        identities=[sb.types.Identity(id="baz", namespace="bar", host="omg", port=3223)],
        heartbeat=2,
    )
    await engine.start(on_appear=on_appear, on_disappear=on_disappear, on_error=on_error)
    peer = sb.types.Identity(id="foo", namespace="bar", host="meme", port=32233)
    remote_sub.subscribe(on_recv=on_remote_recv, on_error=on_error)
    # Announce the fake peer once; the local engine should report it appearing
    # and, after enough missed heartbeats, disappearing.
    await remote_pub.publish(json.dumps(dataclasses.asdict(peer)).encode())
    await asyncio.sleep(15)
    await engine.stop()
    await asyncio.gather(local_pub.close(), local_sub.close(), remote_pub.close(), remote_sub.close())


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,48 @@
"""Example: connect two NetworkEngine instances over ZMQ and trade messages."""
import asyncio

import scatterbrained as sb


async def send_msgs(engine, identity, peer, n=10):
    """Send `n` greeting messages from `identity` to `peer`, pausing 1s between sends."""
    for i in range(n):
        await engine.send_to(identity, peer, f"hello there x{i}".encode())
        await asyncio.sleep(1.0)


async def main():
    # Two identities in the same namespace, bound to different local ports.
    id1 = sb.types.Identity("foo", "baz", "127.0.0.1", 9001)
    id2 = sb.types.Identity("bar", "baz", "127.0.0.1", 9002)
    rx1 = sb.network.ZMQReceiver()
    rx2 = sb.network.ZMQReceiver()
    ne1 = sb.network.NetworkEngine(rx1, lambda: sb.network.ZMQTransmitter("foo"))
    ne2 = sb.network.NetworkEngine(rx2, lambda: sb.network.ZMQTransmitter("bar"))
    await ne1.bind(id1.host, id1.port)
    await ne2.bind(id2.host, id2.port)
    # Cross-connect the engines so each can transmit to the other.
    await ne1.connect_to(id2)
    await ne2.connect_to(id1)

    async def on_recv(identity, payload):
        print("RECV:", identity, payload)

    async def on_malformed(peer_id, segments):
        print("MALFORMED:", peer_id, segments)

    async def on_error(ex):
        print("ERROR", ex)

    d1 = ne1.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    d2 = ne2.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    # FIX: asyncio.wait() no longer accepts bare coroutines (deprecated in 3.8,
    # removed in 3.11); gather() runs both send loops concurrently instead.
    await asyncio.gather(send_msgs(ne1, id1, id2), send_msgs(ne2, id2, id1))
    await asyncio.sleep(1.0)
    await d1.dispose()
    await d2.dispose()


if __name__ == "__main__":
    asyncio.run(main())

32
examples/node.py Normal file
View File

@@ -0,0 +1,32 @@
"""Example: two full Nodes discovering each other and exchanging a message."""
import asyncio

import scatterbrained as sb


async def main():
    # NOTE: in a real deployment you'd want everything to use the same port, but because we're running on the
    # same system here, we need to bind to different ports.
    de1 = sb.discovery.DiscoveryEngine(
        publisher=sb.discovery.UDPBroadcaster("127.0.0.1", port=9002),
        subscriber=sb.discovery.UDPReceiver("127.0.0.1", port=9001),
        heartbeat=2,
    )
    # Second engine with the ports mirrored so the two nodes hear each other.
    de2 = sb.discovery.DiscoveryEngine(
        publisher=sb.discovery.UDPBroadcaster("127.0.0.1", port=9001),
        subscriber=sb.discovery.UDPReceiver("127.0.0.1", port=9002),
        heartbeat=2,
    )
    # Entering the Node context launches discovery + networking for each node.
    async with sb.Node(id="foo", host="127.0.0.1", discovery_engine=de1) as node1, sb.Node(
        id="bar", host="127.0.0.1", discovery_engine=de2
    ) as node2:
        async with node1.namespace(name="foobar") as ns1, node2.namespace(name="foobar") as ns2:
            # Block until each node has discovered the other within the namespace.
            await asyncio.gather(ns1.wait_for_peers(peers="bar"), ns2.wait_for_peers(peers="foo"))
            await ns1.send_to(ns2._id, b"hello")
            sender, payload = await ns2.recv(5.0)
            print(sender, payload)
            # Keep the namespace alive briefly so heartbeats keep flowing.
            await asyncio.sleep(10)
        # NOTE(review): original indentation lost in the diff view; this final
        # sleep is presumed to run after the namespaces close — confirm.
        await asyncio.sleep(15)


if __name__ == "__main__":
    asyncio.run(main())

18
pyproject.toml Normal file
View File

@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 120
target-version = ["py38"]
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120
[tool.pytest.ini_options]
markers = ["unit", "integration", "end2end"]

View File

@@ -0,0 +1,12 @@
-c requirements.txt
bandit
black
docshund
flake8
flake8-bugbear
isort
notebook
pre-commit
pytest
pytest-asyncio
safety

View File

@@ -0,0 +1,266 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile dev-requirements.in
#
appnope==0.1.2
# via
# ipykernel
# ipython
argon2-cffi==21.1.0
# via notebook
attrs==21.2.0
# via
# flake8-bugbear
# jsonschema
# pytest
backcall==0.2.0
# via ipython
backports.entry-points-selectable==1.1.0
# via virtualenv
bandit==1.7.0
# via -r dev-requirements.in
black==21.9b0
# via -r dev-requirements.in
bleach==4.1.0
# via nbconvert
certifi==2021.5.30
# via requests
cffi==1.14.6
# via argon2-cffi
cfgv==3.3.1
# via pre-commit
charset-normalizer==2.0.6
# via requests
click==8.0.1
# via
# black
# docshund
# safety
debugpy==1.4.3
# via ipykernel
decorator==5.1.0
# via ipython
defusedxml==0.7.1
# via nbconvert
distlib==0.3.3
# via virtualenv
docshund==0.1.2
# via -r dev-requirements.in
dparse==0.5.1
# via safety
entrypoints==0.3
# via
# jupyter-client
# nbconvert
filelock==3.0.12
# via virtualenv
flake8==3.9.2
# via
# -r dev-requirements.in
# flake8-bugbear
flake8-bugbear==21.9.1
# via -r dev-requirements.in
gitdb==4.0.7
# via gitpython
gitpython==3.1.24
# via bandit
identify==2.2.15
# via pre-commit
idna==3.2
# via requests
iniconfig==1.1.1
# via pytest
ipykernel==6.4.1
# via notebook
ipython==7.27.0
# via ipykernel
ipython-genutils==0.2.0
# via
# ipykernel
# nbformat
# notebook
isort==5.9.3
# via -r dev-requirements.in
jedi==0.18.0
# via ipython
jinja2==3.0.1
# via
# nbconvert
# notebook
jsonschema==3.2.0
# via nbformat
jupyter-client==7.0.3
# via
# ipykernel
# nbclient
# notebook
jupyter-core==4.8.1
# via
# jupyter-client
# nbconvert
# nbformat
# notebook
jupyterlab-pygments==0.1.2
# via nbconvert
markupsafe==2.0.1
# via jinja2
matplotlib-inline==0.1.3
# via
# ipykernel
# ipython
mccabe==0.6.1
# via flake8
mistune==0.8.4
# via nbconvert
mypy-extensions==0.4.3
# via black
nbclient==0.5.4
# via nbconvert
nbconvert==6.1.0
# via notebook
nbformat==5.1.3
# via
# nbclient
# nbconvert
# notebook
nest-asyncio==1.5.1
# via
# jupyter-client
# nbclient
nodeenv==1.6.0
# via pre-commit
notebook==6.4.4
# via -r dev-requirements.in
packaging==21.0
# via
# bleach
# dparse
# pytest
# safety
pandocfilters==1.5.0
# via nbconvert
parso==0.8.2
# via jedi
pathspec==0.9.0
# via black
pbr==5.6.0
# via stevedore
pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
platformdirs==2.3.0
# via
# black
# virtualenv
pluggy==1.0.0
# via pytest
pre-commit==2.15.0
# via -r dev-requirements.in
prometheus-client==0.11.0
# via notebook
prompt-toolkit==3.0.20
# via ipython
ptyprocess==0.7.0
# via
# pexpect
# terminado
py==1.10.0
# via pytest
pycodestyle==2.7.0
# via flake8
pycparser==2.20
# via cffi
pyflakes==2.3.1
# via flake8
pygments==2.10.0
# via
# ipython
# jupyterlab-pygments
# nbconvert
pyparsing==2.4.7
# via packaging
pyrsistent==0.18.0
# via jsonschema
pytest==6.2.5
# via
# -r dev-requirements.in
# pytest-asyncio
pytest-asyncio==0.15.1
# via -r dev-requirements.in
python-dateutil==2.8.2
# via jupyter-client
pyyaml==5.4.1
# via
# bandit
# dparse
# pre-commit
pyzmq==22.3.0
# via
# -c requirements.txt
# jupyter-client
# notebook
regex==2021.9.24
# via black
requests==2.26.0
# via safety
safety==1.10.3
# via -r dev-requirements.in
send2trash==1.8.0
# via notebook
six==1.16.0
# via
# bandit
# bleach
# jsonschema
# python-dateutil
# virtualenv
smmap==4.0.0
# via gitdb
stevedore==3.4.0
# via bandit
terminado==0.12.1
# via notebook
testpath==0.5.0
# via nbconvert
toml==0.10.2
# via
# dparse
# pre-commit
# pytest
tomli==1.2.1
# via black
tornado==6.1
# via
# ipykernel
# jupyter-client
# notebook
# terminado
traitlets==5.1.0
# via
# ipykernel
# ipython
# jupyter-client
# jupyter-core
# matplotlib-inline
# nbclient
# nbconvert
# nbformat
# notebook
typing-extensions==3.10.0.2
# via
# black
# gitpython
urllib3==1.26.7
# via requests
virtualenv==20.8.1
# via pre-commit
wcwidth==0.2.5
# via prompt-toolkit
webencodings==0.5.1
# via bleach
# The following packages are considered to be unsafe in a requirements file:
# setuptools

View File

@@ -0,0 +1,5 @@
loguru
orjson
pyzmq
rx
uvloop

View File

@@ -0,0 +1,16 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
# pip-compile requirements.in
#
loguru==0.5.3
# via -r requirements.in
orjson==3.6.3
# via -r requirements.in
pyzmq==22.3.0
# via -r requirements.in
rx==3.2.0
# via -r requirements.in
uvloop==0.16.0
# via -r requirements.in

32
setup.py Normal file
View File

@@ -0,0 +1,32 @@
from pathlib import Path

from setuptools import find_packages, setup

# Repository root; all paths below are resolved relative to this file.
HERE = Path(__file__).parent
# Long description shown on PyPI.
README = HERE.joinpath("README.md").read_text()
# Runtime dependencies are maintained in requirements/requirements.in.
REQUIREMENTS = HERE.joinpath("requirements", "requirements.in").read_text().split()


def get_version(rel_path: Path):
    """Extract `__version__` from the module at `rel_path` without importing it.

    Importing the package inside setup.py could fail before its dependencies
    are installed, so the version string is parsed textually instead.

    Raises:
        RuntimeError: If no `__version__` assignment is found in the file.
    """
    contents = HERE.joinpath(rel_path).read_text().splitlines()
    for line in contents:
        if line.startswith("__version__"):
            # Support either quote style: __version__ = "x.y.z" or 'x.y.z'.
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    else:
        raise RuntimeError("Unable to find version string.")


setup(
    name="scatterbrained",
    version=get_version(Path("src", "scatterbrained", "version.py")),
    author="Miller Wilt",
    author_email="miller.wilt@jhuapl.edu",
    description="Decentralized Federated Learning framework",
    long_description=README,
    long_description_content_type="text/markdown",
    packages=find_packages("src"),
    package_dir={"": "src"},
    install_requires=REQUIREMENTS,
    python_requires=">=3.8.0",
)

View File

@@ -0,0 +1,3 @@
from . import discovery, network, types # noqa: F401
from .node import Node # noqa: F401
from .version import __version__ # noqa: F401

View File

@@ -0,0 +1,3 @@
from . import types # noqa: F401
from .engine import DiscoveryEngine # noqa: F401
from .udp import UDPBroadcaster, UDPReceiver # noqa: F401

View File

@@ -0,0 +1,239 @@
"""
The Discovery Engine is responsible for discovering and managing nodes on the Scatterbrained network.
"""
import asyncio
import dataclasses
from datetime import datetime
from typing import Awaitable, Callable, Dict, Optional, Sequence, Set, Tuple
import orjson
import rx
from ..types import Identity
from .types import Publisher, Subscriber
class DiscoveryEngine:
    """
    A `DiscoveryEngine` is a manager class for the process of identifying and communicating with other peers in the
    Scatterbrained network.

    The `DiscoveryEngine` is responsible for maintaining a list of peers, and periodically sending heartbeats to the
    network, which other peers can use to determine if the sending node is still alive.

    Identities of peers are stored in `DiscoveryEngine.peers`.

    Identities of the current node are stored in `DiscoveryEngine.identities` and can be manually added or removed using
    `DiscoveryEngine.add_identity` and `DiscoveryEngine.remove_identity`.
    """

    # Transport used to broadcast this node's identity heartbeats.
    _publisher: Publisher
    # Transport used to receive heartbeats from other nodes.
    _subscriber: Subscriber
    # Interval between heartbeat broadcasts, in seconds (validated >= 1).
    _heartbeat: float
    # This node's own identities, keyed by (id, namespace) so that heartbeats
    # which originate locally can be recognized and ignored.
    _identities: Dict[Tuple[str, str], Identity]
    # Rx subscription feeding heartbeats into the handler; non-None iff running.
    _subscription: Optional[rx.disposable.Disposable]
    # Background task broadcasting this node's identities.
    _heartbeat_task: Optional[asyncio.Task]
    # Background task expiring peers that stop sending heartbeats.
    _lifetime_task: Optional[asyncio.Task]
    # Currently-known remote peers.
    _peers: Set[Identity]
    # Timestamp of the most recent heartbeat from each known peer.
    _last_seen: Dict[Identity, datetime]

    def __init__(
        self, publisher: Publisher, subscriber: Subscriber, identities: Optional[Sequence[Identity]] = None, heartbeat: float = 5
    ) -> None:
        """
        Create a new DiscoveryEngine with a publisher and subscriber.

        Arguments:
            publisher (scatterbrained.discovery.types.Publisher): The publisher to use for sending peer transactions.
            subscriber (scatterbrained.discovery.types.Subscriber): The subscriber to use for receiving peer
                transactions.
            identities (Optional[Sequence[scatterbrained.discovery.types.Identity]]): A list of identities to use for
                the initial peer list.
            heartbeat (float): The heartbeat interval in seconds.

        Returns:
            None
        """
        if identities is None:
            identities = []
        # Assign via the property so the >= 1 validation in the setter runs.
        self.heartbeat = heartbeat
        self._publisher = publisher
        self._subscriber = subscriber
        self._identities = {(i.id, i.namespace): i for i in identities}
        self._subscription = None
        self._heartbeat_task = None
        self._lifetime_task = None
        self._peers = set()
        self._last_seen = {}

    @property
    def heartbeat(self):
        """
        Retrieve the heartbeat interval in seconds.
        """
        return self._heartbeat

    @heartbeat.setter
    def heartbeat(self, value):
        """
        Set the heartbeat interval in seconds.

        Arguments:
            value (float): The heartbeat interval in seconds.

        Returns:
            None

        Raises:
            ValueError: If `value` is less than 1.
        """
        if value < 1:
            raise ValueError(f"value must be 1 or more (value={value})")
        self._heartbeat = value

    @property
    def peers(self):
        """
        Retrieve the set of currently-known peers.

        Returns:
            Set[scatterbrained.discovery.types.Identity]: The set of peers.
        """
        # Defensive copy.
        return set(self._peers)

    def add_identity(self, identity: Identity) -> None:
        """
        Add an identity to the list of identities for the current node.

        Arguments:
            identity (scatterbrained.discovery.types.Identity): The identity to add.

        Returns:
            None
        """
        self._identities[(identity.id, identity.namespace)] = identity

    def remove_identity(self, identity: Identity) -> None:
        """
        Remove an identity from the list of valid identities for the node.

        Arguments:
            identity (scatterbrained.discovery.types.Identity): The identity to remove.

        Returns:
            None
        """
        # pop with a default so removing an unknown identity is a no-op.
        self._identities.pop((identity.id, identity.namespace), None)

    async def start(
        self,
        on_appear: Callable[[Identity], Awaitable[None]],
        on_disappear: Callable[[Identity], Awaitable[None]],
        on_error: Callable[[Exception], Awaitable[None]],
    ) -> None:
        """
        Start the DiscoveryEngine, with awaitable callbacks per peer.

        Callbacks receive an identity of the peer that appeared or disappeared. The `on_error` callback is called
        if an error occurs, and it receives as arguments the exception. Note that all callbacks are async.

        Arguments:
            on_appear (callback): A callback to call when a peer appears on the scatterbrained network.
            on_disappear (callback): A callback to call when a peer disappears from the network.
            on_error (exception_handler): A callback to run when an exception is thrown communicating with a peer.

        Returns:
            None
        """
        # Already running; starting twice would leak tasks and subscriptions.
        if self._subscription is not None:
            return

        async def handle_peer_heartbeat(data):
            # Heartbeat payloads are JSON-serialized Identity dataclasses.
            peer_info = orjson.loads(data)
            peer = Identity(**peer_info)
            # TODO: the proper way to check if a heartbeat was generated
            # locally is to look at the source ip, and cross-check that with
            # known interfaces. This will work for now, however.
            if (peer.id, peer.namespace) in self._identities:
                return
            elif peer not in self._peers:
                # First heartbeat from this peer: record it and notify.
                self._peers.add(peer)
                self._last_seen[peer] = datetime.now()
                await on_appear(peer)
            else:
                # Update time of last seen
                self._last_seen[peer] = datetime.now()

        await self._publisher.open()
        await self._subscriber.open()
        self._subscription = self._subscriber.subscribe(on_recv=handle_peer_heartbeat, on_error=on_error)
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
        self._lifetime_task = asyncio.create_task(self._lifetime_monitor(on_disappear))

    async def stop(self):
        """
        Gracefully stop the `DiscoveryEngine`.

        Returns:
            None
        """
        # Not running; nothing to tear down.
        if self._subscription is None:
            return
        assert self._heartbeat_task is not None  # nosec
        assert self._lifetime_task is not None  # nosec
        self._subscription.dispose()
        self._subscription = None
        # Both background loops swallow CancelledError, so awaiting them after
        # cancel() simply waits for them to wind down.
        self._heartbeat_task.cancel()
        await self._heartbeat_task
        self._heartbeat_task = None
        self._lifetime_task.cancel()
        await self._lifetime_task
        self._lifetime_task = None
        await self._publisher.close()
        await self._subscriber.close()
        self._peers = set()
        self._last_seen = {}

    async def _heartbeat_loop(self):
        # Broadcast every known local identity, then sleep one heartbeat
        # interval; runs until cancelled by `stop()`.
        try:
            while True:
                for id in self._identities.values():
                    obj = dataclasses.asdict(id)
                    obj = orjson.dumps(obj)
                    await self._publisher.publish(obj)
                await asyncio.sleep(self._heartbeat)
        except asyncio.CancelledError:
            pass

    async def _lifetime_monitor(self, on_disappear: Callable[[Identity], Awaitable[None]]):
        # Expire peers that have missed ~5 heartbeat intervals and notify via
        # `on_disappear`; runs until cancelled by `stop()`.
        try:
            while True:
                now = datetime.now()
                tasks = []
                to_remove = set()
                for peer, last_seen in self._last_seen.items():
                    if (now - last_seen).total_seconds() >= self._heartbeat * 5:
                        self._peers.remove(peer)
                        to_remove.add(peer)
                        task = asyncio.create_task(on_disappear(peer))
                        tasks.append(task)
                # NOTE: not optimal, as tasks may take a long time, but will suffice for now.
                if tasks:
                    for peer in to_remove:
                        self._last_seen.pop(peer)
                    await asyncio.wait(tasks)
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            pass

View File

@@ -0,0 +1,70 @@
from typing import Awaitable, Callable, Optional, Protocol
class Publisher(Protocol):
    """
    A Publisher is a class that can publish a message to a set of peers.

    Note that this is a protocol and shouldn't be used directly!
    """

    async def publish(self, data: bytes) -> None:
        """
        Publish the given payload to a set of peers.

        Arguments:
            data (bytes): The raw payload to send.

        Returns:
            None
        """
        ...

    async def open(self) -> None:
        """
        Open the underlying connection mechanism, enabling this instance to send messages.
        """
        ...

    async def close(self) -> None:
        """
        Close the underlying connection mechanism, stopping this instance from sending messages.
        """
        ...
class Subscriber(Protocol):
    """
    A Subscriber is a class that can subscribe to messages from a set of peers.

    Note that this is a protocol and shouldn't be used directly!
    """

    def subscribe(
        self,
        on_recv: Callable[[bytes], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ) -> None:
        """
        Subscribe to messages from a set of peers, and attach async callbacks.

        Arguments:
            on_recv (Callable[[bytes], Awaitable[None]]): The callback to call when a message is received.
            on_error (Optional[Callable[[Exception], Awaitable[None]]]): The callback to run when an error occurs.

        Returns:
            None
        """
        # NOTE(review): concrete implementations (e.g. UDPReceiver.subscribe)
        # return an rx Disposable, and DiscoveryEngine stores that return
        # value — the declared `-> None` looks inaccurate; confirm and update.
        ...

    async def open(self) -> None:
        """
        Open the underlying connection mechanism, enabling this instance to receive messages.
        """
        ...

    async def close(self) -> None:
        """
        Close the underlying connection mechanism, stopping this instance from receiving messages.
        """
        ...


__all__ = ["Publisher", "Subscriber"]

View File

@@ -0,0 +1,145 @@
import asyncio
from typing import Awaitable, Callable, Optional, Tuple
import rx
from loguru import logger
from rx import operators as op
from rx.scheduler.eventloop import AsyncIOScheduler
from rx.subject import Subject
class _UDPBroadcastProtocol(asyncio.DatagramProtocol):
    """Minimal datagram protocol used by `UDPBroadcaster` to send broadcasts.

    FIX: subclass `asyncio.DatagramProtocol` so the callbacks this class does
    not implement (notably `datagram_received`) fall back to safe no-ops; the
    previous duck-typed version would raise AttributeError if any datagram
    arrived on the connected socket.
    """

    def connection_made(self, transport: asyncio.DatagramTransport):
        # Keep a handle to the transport so the owning broadcaster can send.
        self.transport = transport

    def connection_lost(self, exc: Exception):
        # TODO: determine how to inform larger program about any exceptions passed.
        self.transport = None

    def error_received(self, exc: OSError):
        # TODO: determine how to inform larger program about any exceptions passed.
        pass
class _UDPRecvProtocol:
    """Datagram protocol that forwards every received packet into an Rx subject."""

    def __init__(self, subject: Optional[Subject] = None):
        # Fall back to a fresh Subject when the caller does not supply one.
        self.subject = Subject() if subject is None else subject

    def connection_made(self, transport: asyncio.DatagramTransport):
        self.transport = transport

    def connection_lost(self, exc: Exception):
        self.transport = None
        # A clean shutdown completes the stream; anything else is an error.
        if exc is not None:
            self.subject.on_error(exc)
        else:
            self.subject.on_completed()

    def error_received(self, exc: OSError):
        self.subject.on_error(exc)

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        # Emit (payload, (host, port)) for downstream subscribers.
        self.subject.on_next((data, addr))
class UDPBroadcaster:
    """Publishes discovery payloads as UDP datagrams to a broadcast address.

    Defaults to the limited-broadcast address 255.255.255.255. Manage the
    underlying datagram transport with `open()`/`close()` or `async with`.
    """

    # FIX: the original annotation declared `remote: str`, but no such
    # attribute is ever assigned; the actual attribute is `broadcast_addr`.
    broadcast_addr: str
    port: int
    _transport: Optional[asyncio.DatagramTransport]
    _protocol: Optional[asyncio.DatagramProtocol]

    def __init__(self, broadcast_addr: str = "255.255.255.255", port: int = 9001):
        """
        Arguments:
            broadcast_addr (str): Address datagrams are sent to.
            port (int): UDP port datagrams are sent to.
        """
        self.broadcast_addr = broadcast_addr
        self.port = port
        self._transport = None
        self._protocol = None

    async def open(self) -> None:
        """Create the broadcast datagram transport; a no-op if already open."""
        if self._transport is None:
            loop = asyncio.get_running_loop()
            self._transport, self._protocol = await loop.create_datagram_endpoint(
                lambda: _UDPBroadcastProtocol(), remote_addr=(self.broadcast_addr, self.port), allow_broadcast=True
            )

    async def close(self) -> None:
        """Tear down the datagram transport; a no-op if already closed."""
        if self._transport is not None:
            self._transport.close()
            self._transport = None
            self._protocol = None

    async def __aenter__(self) -> "UDPBroadcaster":
        await self.open()
        # FIX: return self so `async with UDPBroadcaster(...) as b:` binds the
        # broadcaster (previously bound None, which no caller could use).
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def publish(self, data: bytes) -> None:
        """Send `data` as a single datagram.

        Raises:
            RuntimeError: If the broadcaster has not been opened.
        """
        if self._transport is None:
            raise RuntimeError("socket is not open")
        self._transport.sendto(data)
class UDPReceiver:
    """Receives UDP datagrams and exposes them as an Rx stream of payloads.

    Manage the underlying datagram transport with `open()`/`close()` or
    `async with`; attach async callbacks with `subscribe()`.
    """

    # FIX: the original annotation declared `source: str`, but no such
    # attribute is ever assigned; the actual bind-address attribute is
    # `local_addr`.
    local_addr: str
    port: int
    _transport: Optional[asyncio.DatagramTransport]
    _protocol: Optional[asyncio.DatagramProtocol]

    def __init__(self, local_addr: str = "", port: int = 9001, subject: Optional["Subject"] = None):
        """
        Arguments:
            local_addr (str): Address to bind to ("" binds all interfaces).
            port (int): UDP port to bind to.
            subject (Optional[Subject]): Rx subject to publish datagrams into;
                a fresh one is created when omitted.
        """
        if subject is None:
            subject = Subject()
        self.local_addr = local_addr
        self.port = port
        self.subject = subject
        self._transport = None
        self._protocol = None

    async def open(self) -> None:
        """Bind the receiving datagram transport; a no-op if already open."""
        if self._transport is None:
            loop = asyncio.get_running_loop()
            self._transport, self._protocol = await loop.create_datagram_endpoint(
                lambda: _UDPRecvProtocol(subject=self.subject), local_addr=(self.local_addr, self.port)
            )

    async def close(self) -> None:
        """Tear down the datagram transport; a no-op if already closed."""
        if self._transport is not None:
            self._transport.close()
            self._transport = None
            self._protocol = None

    async def __aenter__(self) -> "UDPReceiver":
        await self.open()
        # FIX: return self so `async with UDPReceiver(...) as r:` binds the
        # receiver (previously bound None, which no caller could use).
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def subscribe(
        self,
        on_recv: Callable[[bytes], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ):
        """Attach async callbacks to the datagram stream.

        Arguments:
            on_recv (Callable[[bytes], Awaitable[None]]): Awaited with each datagram payload.
            on_error (Optional[Callable[[Exception], Awaitable[None]]]): Awaited with any stream error.

        Returns:
            rx.disposable.Disposable: Handle to cancel the subscription.
        """

        def handle_next(obj):
            # Drop the sender address; downstream callbacks only get payloads.
            data, _ = obj
            return rx.from_future(asyncio.create_task(on_recv(data)))

        def handle_error(ex, src):
            assert on_error is not None  # nosec
            return rx.from_future(asyncio.create_task(on_error(ex)))

        workflow = self.subject.pipe(op.flat_map(handle_next))
        if on_error is not None:
            workflow = workflow.pipe(op.catch(handle_error))
        d = workflow.subscribe(
            on_next=lambda _: logger.trace("message processed"),
            on_error=lambda e: logger.opt(exception=e).error(f"{type(self).__name__} data stream error"),
            on_completed=lambda: logger.trace(f"{type(self).__name__} data stream ended"),
            scheduler=AsyncIOScheduler(loop=asyncio.get_running_loop()),
        )
        return d
__all__ = ["UDPBroadcaster", "UDPReceiver"]

View File

@@ -0,0 +1,154 @@
import asyncio
import socket
from unittest.mock import AsyncMock
import pytest
from . import udp
# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


@pytest.fixture
def addr():
    # Loopback keeps the "broadcast" traffic on-host and deterministic.
    return "127.0.0.1"


@pytest.fixture
def port():
    # Single test port shared by sender and receiver in each test.
    return 9001
async def test_UDPBroadcaster_full(event_loop, addr, port):
    """End-to-end: open -> publish -> datagram arrives at a plain socket -> close."""
    # To simplify things, just use a blocking socket
    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        recv_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        recv_socket.settimeout(3)
        await event_loop.run_in_executor(None, recv_socket.bind, ("", port))
        broadcaster = udp.UDPBroadcaster(broadcast_addr=addr, port=port)
        assert broadcaster._transport is None and broadcaster._protocol is None
        await broadcaster.open()
        assert broadcaster._transport is not None and broadcaster._protocol is not None
        expected = b"hello world"
        await broadcaster.publish(expected)
        actual, _ = await event_loop.run_in_executor(None, recv_socket.recvfrom, 255)
        assert actual == expected
        await broadcaster.close()
        assert broadcaster._transport is None and broadcaster._protocol is None
    finally:
        # FIX: the socket was previously leaked; always release the port.
        recv_socket.close()
async def test_UDPBroadcaster_full_async_with(event_loop, addr, port):
    """Same as `test_UDPBroadcaster_full`, but driving open/close via `async with`."""
    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        recv_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        recv_socket.settimeout(3)
        await event_loop.run_in_executor(None, recv_socket.bind, ("", port))
        broadcaster = udp.UDPBroadcaster(broadcast_addr=addr, port=port)
        assert broadcaster._transport is None and broadcaster._protocol is None
        async with broadcaster:
            assert broadcaster._transport is not None and broadcaster._protocol is not None
            expected = b"hello world"
            await broadcaster.publish(expected)
            actual, _ = await event_loop.run_in_executor(None, recv_socket.recvfrom, 255)
            assert actual == expected
        assert broadcaster._transport is None and broadcaster._protocol is None
    finally:
        # FIX: the socket was previously leaked; always release the port.
        recv_socket.close()
async def test_UDPReceiver_full(event_loop, addr, port):
    """End-to-end: open -> subscribe -> datagram delivered to on_recv -> close."""

    async def mock_on_recv(*args):
        # Wake the waiting test body once the callback has fired.
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        receiver = udp.UDPReceiver(local_addr=addr, port=port)
        assert receiver._transport is None and receiver._protocol is None
        await receiver.open()
        assert receiver._transport is not None and receiver._protocol is not None
        cond = asyncio.Condition()
        on_recv = AsyncMock(side_effect=mock_on_recv)
        on_error = AsyncMock()
        receiver.subscribe(on_recv=on_recv, on_error=on_error)
        expected = b"hello world"
        send_socket.sendto(expected, (addr, port))
        async with cond:
            await asyncio.wait_for(cond.wait(), 3.0)
        on_recv.assert_awaited_once_with(expected)
        on_error.assert_not_awaited()
        await receiver.close()
        assert receiver._transport is None and receiver._protocol is None
    finally:
        # FIX: the socket was previously leaked; always close it.
        send_socket.close()
async def test_UDPReceiver_async_with(event_loop, addr, port):
    """Same as `test_UDPReceiver_full`, but driving open/close via `async with`."""

    async def mock_on_recv(*args):
        # Wake the waiting test body once the callback has fired.
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        receiver = udp.UDPReceiver(local_addr=addr, port=port)
        assert receiver._transport is None and receiver._protocol is None
        async with receiver:
            assert receiver._transport is not None and receiver._protocol is not None
            cond = asyncio.Condition()
            on_recv = AsyncMock(side_effect=mock_on_recv)
            on_error = AsyncMock()
            receiver.subscribe(on_recv=on_recv, on_error=on_error)
            expected = b"hello world"
            send_socket.sendto(expected, (addr, port))
            async with cond:
                await asyncio.wait_for(cond.wait(), 3.0)
            on_recv.assert_awaited_once_with(expected)
            on_error.assert_not_awaited()
        assert receiver._transport is None and receiver._protocol is None
    finally:
        # FIX: the socket was previously leaked; always close it.
        send_socket.close()
async def test_UDPReceiver_error(event_loop, addr, port):
    """A failing on_recv callback must route its exception to on_error."""

    async def mock_on_recv(*args):
        await asyncio.sleep(0.01)
        raise RuntimeError("yeet")

    async def mock_on_error(*args):
        # Wake the waiting test body once the error handler has fired.
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        receiver = udp.UDPReceiver(local_addr=addr, port=port)
        assert receiver._transport is None and receiver._protocol is None
        async with receiver:
            assert receiver._transport is not None and receiver._protocol is not None
            cond = asyncio.Condition()
            on_recv = AsyncMock(side_effect=mock_on_recv)
            on_error = AsyncMock(side_effect=mock_on_error)
            receiver.subscribe(on_recv=on_recv, on_error=on_error)
            expected = b"hello world"
            send_socket.sendto(expected, (addr, port))
            async with cond:
                await asyncio.wait_for(cond.wait(), 3.0)
            on_recv.assert_awaited_once_with(expected)
            on_error.assert_awaited_once()
    finally:
        # FIX: the socket was previously leaked; always close it.
        send_socket.close()

View File

@@ -0,0 +1,2 @@
from .engine import NetworkEngine # noqa: F401
from .zmq_ip import ZMQReceiver, ZMQTransmitter # noqa: F401

View File

@@ -0,0 +1,165 @@
from __future__ import annotations
import asyncio
import dataclasses
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
import orjson
import rx
from loguru import logger
from rx import operators as op
from rx.scheduler.eventloop import AsyncIOScheduler
from rx.subject import Subject
from ..types import Identity
from .types import RX, TX
class NetworkEngine:
    """
    Routes messages between this node and its peers.

    A single `RX` receiver accepts all inbound traffic, while one `TX` transmitter is
    kept per remote (host, port). Inbound messages are parsed for a leading serialized
    `Identity` header and pushed through an rx `Subject` to subscribers.
    """

    class Disposable:
        """Subscription handle returned by `subscribe`; disposing the last one stops the RX loop."""

        engine: NetworkEngine
        rx_disposable: rx.disposable.Disposable
        _disposed: bool

        def __init__(self, engine: NetworkEngine, rx_disposable: rx.disposable.Disposable):
            self.engine = engine
            self.rx_disposable = rx_disposable
            self._disposed = False

        async def dispose(self) -> None:
            """Cancel this subscription (idempotent)."""
            if self._disposed:
                return
            # FIX: mark as disposed so a second dispose() cannot decrement the engine's
            # subscriber count twice. The flag was previously checked but never set.
            self._disposed = True
            self.rx_disposable.dispose()
            self.engine._sub_count -= 1
            if self.engine._sub_count == 0:
                # Need to be careful about concurrency here. To prevent race conditions, we save/clear the
                # rx_task field, cancel the saved task, and then await on it. This ensures that if someone
                # subscribes while the task is being awaited it starts the rx loop back up.
                task = self.engine._rx_task
                self.engine._rx_task = None
                # Need to check this in case the NetworkEngine instance was unbound/closed.
                if task is not None:
                    task.cancel()
                    await task

    _tx_factory: Callable[[], TX]
    _rx: RX
    _physical_tx_connections: Dict[Tuple[str, int], TX]
    _bound: bool
    _subject: Subject
    _rx_task: Optional[asyncio.Task]
    _sub_count: int

    def __init__(self, rx: RX, tx_factory: Callable[[], TX]):
        """
        Arguments:
            rx (RX): Receiver used for all inbound traffic.
            tx_factory (callable): Factory producing a fresh TX for each distinct peer address.
        """
        self._rx = rx
        self._tx_factory = tx_factory
        self._physical_tx_connections = {}
        self._bound = False
        self._subject = Subject()
        self._rx_task = None
        self._sub_count = 0

    @property
    def bound_address(self):
        """The receiver's (host, port); both elements are None while unbound."""
        return self._rx.host, self._rx.port

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        """Bind the receiver and return the actual port (useful when `port` is None)."""
        port = await self._rx.bind(host, port)
        self._bound = True
        return port

    async def unbind(self) -> None:
        """Stop the RX loop (if running) and close the receiver."""
        if self._rx_task is not None:
            self._rx_task.cancel()
            # _rx_loop handles CancelledError itself, so this await returns cleanly.
            await self._rx_task
            self._rx_task = None
        await self._rx.close()
        self._bound = False

    async def close(self) -> None:
        """Close every transmitter, then unbind the receiver."""
        keys = list(self._physical_tx_connections)
        for k in keys:
            tx = self._physical_tx_connections.pop(k)
            await tx.close()
        await self.unbind()

    async def connect_to(self, host: str, port: int) -> None:
        """Open a transmitter to (host, port); a no-op if one already exists."""
        phy_conn_info = (host, port)
        if phy_conn_info not in self._physical_tx_connections:
            tx = self._tx_factory()
            await tx.connect(host, port)
            self._physical_tx_connections[phy_conn_info] = tx

    async def disconnect_from(self, host: str, port: int) -> None:
        """Close and drop the transmitter for (host, port); a no-op if none exists."""
        phy_conn_info = (host, port)
        if phy_conn_info in self._physical_tx_connections:
            tx = self._physical_tx_connections.pop(phy_conn_info)
            await tx.close()

    async def send_to(self, id: Identity, peer: Identity, *payload: bytes) -> None:
        """
        Send `payload` to `peer`, prefixed with the serialized sender identity `id`.

        Raises:
            RuntimeError: If no physical connection to the peer's address exists.
        """
        phy_conn_info = (peer.host, peer.port)
        if phy_conn_info not in self._physical_tx_connections:
            raise RuntimeError(f"not connected to peer node '{peer.id}' (addr={peer.host}:{peer.port})")
        tx = self._physical_tx_connections[phy_conn_info]
        idbytes = orjson.dumps(dataclasses.asdict(id))
        await tx.send(idbytes, *payload)

    async def _rx_loop(self) -> None:
        # Pump inbound messages into the subject until cancelled. Each item is tagged with a
        # "malformed" flag so subscribers can route it to on_recv vs. on_malformed.
        logger.debug(f"{type(self).__name__} RX loop is starting")
        try:
            while True:
                peer_id, segments = await self._rx.recv()
                try:
                    identity = orjson.loads(segments[0])
                    identity = Identity(**identity)
                except Exception:
                    logger.exception(f"failed to extract identity from message sent by '{peer_id}'")
                    self._subject.on_next((True, (peer_id, segments)))
                else:
                    self._subject.on_next((False, (identity, segments[1:])))
        except asyncio.CancelledError:
            # Expected shutdown path; swallow the cancellation so awaiting callers don't see it.
            logger.debug(f"{type(self).__name__} RX loop is shutting down")
        except Exception as e:
            logger.exception(f"error in {type(self).__name__} RX loop")
            self._subject.on_error(e)

    def subscribe(
        self,
        on_recv: Callable[[Identity, Sequence[bytes]], Awaitable[None]],
        on_malformed: Callable[[str, Sequence[bytes]], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ):
        """
        Register async callbacks for inbound traffic.

        Arguments:
            on_recv: Awaited with (sender Identity, payload segments) for well-formed messages.
            on_malformed: Awaited with (raw peer id, all segments) when no identity header parses.
            on_error: Optional; awaited with the exception if the RX stream fails.

        Returns:
            NetworkEngine.Disposable: Handle used to cancel the subscription.

        Raises:
            RuntimeError: If the engine has not been bound yet.
        """
        if not self._bound:
            raise RuntimeError(f"{type(self).__name__} must be bound before subscribe can be called")

        def handle_next(data):
            malformed, payload = data
            if malformed:
                return rx.from_future(asyncio.create_task(on_malformed(*payload)))
            else:
                return rx.from_future(asyncio.create_task(on_recv(*payload)))

        def handle_error(ex, src):
            assert on_error is not None  # nosec
            return rx.from_future(asyncio.create_task(on_error(ex)))

        workflow = self._subject.pipe(op.flat_map(handle_next))
        if on_error is not None:
            # Only install the catch stage when the caller actually wants error callbacks.
            workflow = workflow.pipe(op.catch(handle_error))
        d = workflow.subscribe(
            on_next=lambda _: logger.trace("message processed"),
            on_error=lambda e: logger.opt(exception=e).error("network engine data stream error"),
            on_completed=lambda: logger.trace("network engine data stream ended"),
            scheduler=AsyncIOScheduler(loop=asyncio.get_running_loop()),
        )
        if self._rx_task is None:
            # The first subscriber starts the RX pump.
            self._rx_task = asyncio.create_task(self._rx_loop())
        self._sub_count += 1
        return self.Disposable(self, d)

View File

@@ -0,0 +1,229 @@
import asyncio
import dataclasses
from unittest.mock import AsyncMock
import orjson
import pytest
from ..types import Identity
from .engine import NetworkEngine
# Treat every coroutine test in this module as if marked with @pytest.mark.asyncio.
pytestmark = pytest.mark.asyncio
@pytest.fixture
def id1():
    """Peer "1" in namespace "foo"; shares its (host, port) with id2."""
    return Identity("1", "foo", "127.0.0.1", 9002)
@pytest.fixture
def id2():
    """Peer "2" in namespace "bar"; shares its (host, port) with id1."""
    return Identity("2", "bar", "127.0.0.1", 9002)
@pytest.fixture
def id3():
    """Peer "3" in namespace "baz" on a port distinct from id1/id2."""
    return Identity("3", "baz", "127.0.0.1", 9003)
@pytest.fixture
def msg(id1):
    """A well-formed wire message: (peer id, [serialized sender identity, payload])."""
    header = orjson.dumps(dataclasses.asdict(id1))
    return (id1.id, [header, b"hello there"])
@pytest.fixture
def malformed_msg():
    """A message with a single segment and no identity header, so parsing must fail."""
    return ("foo", [b"hello there"])
@pytest.mark.unit
async def test_NetworkEngine_connect_to(id1, id2, id3):
    """Connecting twice to the same (host, port) must reuse the existing TX."""
    ne = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    def assert_connected(expected_keys):
        # Every expected address has exactly one TX, connected exactly once.
        assert list(ne._physical_tx_connections) == expected_keys
        for host, port in expected_keys:
            ne._physical_tx_connections[(host, port)].connect.assert_called_once_with(host, port)

    await ne.connect_to(id1.host, id1.port)
    assert_connected([(id1.host, id1.port)])
    # id2 shares id1's address, so no new physical connection is created.
    await ne.connect_to(id2.host, id2.port)
    assert_connected([(id1.host, id1.port)])
    await ne.connect_to(id3.host, id3.port)
    assert_connected([(id1.host, id1.port), (id3.host, id3.port)])
@pytest.mark.unit
async def test_NetworkEngine_subscribe(msg):
    """A well-formed message is routed to on_recv with the parsed Identity."""

    async def fake_recv():
        await asyncio.sleep(0.01)
        return msg

    async def notify_received(*args):
        async with received:
            received.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = fake_recv
    engine = NetworkEngine(rx, lambda: AsyncMock())
    assert await engine.bind("127.0.0.1", port) == port
    received = asyncio.Condition()
    on_recv = AsyncMock(side_effect=notify_received)
    on_malformed = AsyncMock()
    on_error = AsyncMock()
    sub = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    async with received:
        await asyncio.wait_for(received.wait(), 3.0)
    await sub.dispose()
    expected_id = Identity(**orjson.loads(msg[1][0]))
    on_recv.assert_awaited_with(expected_id, msg[1][1:])
    on_malformed.assert_not_awaited()
    on_error.assert_not_awaited()
@pytest.mark.unit
async def test_NetworkEngine_subscribe_malformed(malformed_msg):
    """A message without a parsable identity header goes to on_malformed only."""

    async def fake_recv():
        await asyncio.sleep(0.01)
        return malformed_msg

    async def notify_malformed(*args):
        async with flagged:
            flagged.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = fake_recv
    engine = NetworkEngine(rx, lambda: AsyncMock())
    assert await engine.bind("127.0.0.1", port) == port
    flagged = asyncio.Condition()
    on_recv = AsyncMock()
    on_malformed = AsyncMock(side_effect=notify_malformed)
    on_error = AsyncMock()
    sub = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    async with flagged:
        await asyncio.wait_for(flagged.wait(), 3.0)
    await sub.dispose()
    on_malformed.assert_awaited_with(malformed_msg[0], malformed_msg[1])
    on_recv.assert_not_awaited()
    on_error.assert_not_awaited()
@pytest.mark.unit
async def test_NetworkEngine_subscribe_error():
    """An exception raised inside the receiver is delivered to on_error only."""

    async def fake_recv():
        await asyncio.sleep(0.01)
        raise boom

    async def notify_errored(*args):
        async with errored:
            errored.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = fake_recv
    engine = NetworkEngine(rx, lambda: AsyncMock())
    assert await engine.bind("127.0.0.1", port) == port
    # Bound before subscribe so fake_recv's closure resolves once the RX loop runs.
    boom = ValueError("yeet")
    errored = asyncio.Condition()
    on_recv = AsyncMock()
    on_malformed = AsyncMock()
    on_error = AsyncMock(side_effect=notify_errored)
    sub = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    async with errored:
        await asyncio.wait_for(errored.wait(), 3.0)
    await sub.dispose()
    on_error.assert_awaited_with(boom)
    on_recv.assert_not_awaited()
    on_malformed.assert_not_awaited()
@pytest.mark.unit
async def test_NetworkEngine_subscribe_not_bound():
    """Subscribing before bind() must raise immediately."""
    ne = NetworkEngine(AsyncMock(), lambda: AsyncMock())
    with pytest.raises(RuntimeError):
        ne.subscribe(on_recv=AsyncMock(), on_malformed=AsyncMock())
@pytest.mark.unit
async def test_NetworkEngine_send_to(id1, id2):
    """send_to prefixes the serialized sender identity to the payload."""
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())
    await engine.connect_to(id2.host, id2.port)
    await engine.send_to(id1, id2, b"hello there")
    tx = engine._physical_tx_connections[(id2.host, id2.port)]
    # FIX: `awaited_once_with` is not a Mock assertion method -- accessing it merely
    # creates a child mock, so the original check could never fail.
    tx.send.assert_awaited_once_with(orjson.dumps(dataclasses.asdict(id1)), b"hello there")
@pytest.mark.unit
async def test_NetworkEngine_send_to_not_connected(id1, id2):
    """Sending to an address without a physical connection must fail loudly."""
    ne = NetworkEngine(AsyncMock(), lambda: AsyncMock())
    with pytest.raises(RuntimeError):
        await ne.send_to(id1, id2, b"hello there")
@pytest.mark.unit
async def test_NetworkEngine_disconnect_from(id1, id2, id3):
    """Disconnecting closes and drops the TX for that address."""
    ne = NetworkEngine(AsyncMock(), lambda: AsyncMock())
    for peer in (id1, id2, id3):
        await ne.connect_to(peer.host, peer.port)
    # id1 and id2 share an address, so they share one physical connection.
    shared_tx = ne._physical_tx_connections[(id1.host, id1.port)]
    lone_tx = ne._physical_tx_connections[(id3.host, id3.port)]
    assert list(ne._physical_tx_connections) == [(id1.host, id1.port), (id3.host, id3.port)]
    await ne.disconnect_from(id2.host, id2.port)
    shared_tx.close.assert_awaited_once()
    assert list(ne._physical_tx_connections) == [(id3.host, id3.port)]
    await ne.disconnect_from(id3.host, id3.port)
    lone_tx.close.assert_awaited_once()
    assert list(ne._physical_tx_connections) == []

View File

@@ -0,0 +1,39 @@
from typing import Optional, Protocol, Sequence, Tuple
class RX(Protocol):
    """Structural interface for the receiving half of a transport."""

    @property
    def host(self) -> Optional[str]:
        """Bound host, or None when unbound."""
        ...

    @property
    def port(self) -> Optional[int]:
        """Bound port, or None when unbound."""
        ...

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        """Bind to `host` (and `port`, or an arbitrary free port); return the bound port."""
        ...

    async def recv(self) -> Tuple[str, Sequence[bytes]]:
        """Await the next message; return (peer id, payload segments)."""
        ...

    async def close(self) -> None:
        """Release the underlying resources."""
        ...
class TX(Protocol):
    """Structural interface for the transmitting half of a transport."""

    @property
    def host(self) -> Optional[str]:
        """Connected host, or None when disconnected."""
        ...

    @property
    def port(self) -> Optional[int]:
        """Connected port, or None when disconnected."""
        ...

    async def connect(self, host: str, port: int) -> None:
        """Connect to a remote receiver at (host, port)."""
        ...

    async def send(self, *segments: bytes) -> None:
        """Send one message composed of the given byte segments."""
        ...

    async def close(self) -> None:
        """Release the underlying resources."""
        ...

View File

@@ -0,0 +1,108 @@
from typing import Optional, Sequence, Tuple
import zmq
import zmq.asyncio
class ZMQReceiver:
    """
    RX implementation backed by a ZMQ ROUTER socket.

    The ROUTER socket prefixes every inbound message with the sender's identity frame and
    an empty delimiter frame; `recv` strips both and returns the decoded peer id plus the
    remaining payload segments.
    """

    _context: zmq.asyncio.Context
    _socket: Optional[zmq.asyncio.Socket]
    _host: Optional[str]
    _port: Optional[int]

    def __init__(self, context: Optional[zmq.asyncio.Context] = None):
        # Fall back to the process-wide shared context when none is supplied.
        if context is None:
            context = zmq.asyncio.Context.instance()
        self._context = context
        self._socket = None
        self._host = None
        self._port = None

    @property
    def host(self) -> Optional[str]:
        """Bound host, or None when unbound."""
        return self._host

    @property
    def port(self) -> Optional[int]:
        """Bound port, or None when unbound."""
        return self._port

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        """
        Bind the ROUTER socket; when `port` is None an arbitrary free port is chosen.

        Binding again while already bound is a no-op that returns the existing port.

        Returns:
            int: The bound port.
        """
        if self._socket is not None:
            return self._port
        self._socket = self._context.socket(zmq.ROUTER)
        if port is None:
            port = self._socket.bind_to_random_port(f"tcp://{host}")
        else:
            self._socket.bind(f"tcp://{host}:{port}")
        self._host = host
        self._port = port
        return self._port

    async def recv(self) -> Tuple[str, Sequence[bytes]]:
        """
        Await the next message.

        Returns:
            Tuple[str, Sequence[bytes]]: The sending peer's id and the payload segments
            (identity and delimiter frames removed).

        Raises:
            RuntimeError: If the socket is not bound.
        """
        if self._socket is None:
            raise RuntimeError("socket is unbound")
        # ROUTER frame layout: [peer identity, empty delimiter, *payload].
        peer_id, _, *segments = await self._socket.recv_multipart()
        return peer_id.decode(), segments

    async def close(self) -> None:
        """Close the socket (no-op when unbound); linger=0 discards unsent messages."""
        if self._socket is None:
            return
        self._socket.close(linger=0)
        self._socket = None
        self._host = None
        self._port = None
class ZMQTransmitter:
    """
    TX implementation backed by a ZMQ DEALER socket.

    The DEALER's identity is set to `id` so the remote ROUTER can attribute messages
    to this sender.
    """

    _id: str
    _context: zmq.asyncio.Context
    _socket: Optional[zmq.asyncio.Socket]
    _host: Optional[str]
    _port: Optional[int]

    def __init__(self, id: str, context: Optional[zmq.asyncio.Context] = None):
        # Fall back to the process-wide shared context when none is supplied.
        if context is None:
            context = zmq.asyncio.Context.instance()
        self._id = id
        self._context = context
        self._socket = None
        self._host = None
        self._port = None

    @property
    def host(self) -> Optional[str]:
        """Connected host, or None when disconnected."""
        return self._host

    @property
    def port(self) -> Optional[int]:
        """Connected port, or None when disconnected."""
        return self._port

    async def connect(self, host: str, port: int) -> None:
        """Connect to a remote ROUTER at (host, port); a no-op when already connected."""
        if self._socket is not None:
            return
        self._socket = self._context.socket(zmq.DEALER)
        self._socket.setsockopt_string(zmq.IDENTITY, self._id)
        self._socket.connect(f"tcp://{host}:{port}")
        self._host = host
        self._port = port

    async def send(self, *segments: bytes) -> None:
        """
        Send payload segments, preceded by an empty delimiter frame.

        Raises:
            RuntimeError: If the socket is not connected.
        """
        if self._socket is None:
            raise RuntimeError("socket is not connected")
        await self._socket.send_multipart([b"", *segments])

    async def close(self) -> None:
        """Close the socket (no-op when disconnected); linger=0 discards unsent messages."""
        if self._socket is None:
            return
        self._socket.close(linger=0)
        self._socket = None
        self._host = None
        self._port = None
__all__ = ["ZMQReceiver", "ZMQTransmitter"]

View File

@@ -0,0 +1,79 @@
import asyncio
import pytest
import zmq
import zmq.asyncio
from .zmq_ip import ZMQReceiver, ZMQTransmitter
# Treat every coroutine test in this module as if marked with @pytest.mark.asyncio.
pytestmark = pytest.mark.asyncio
@pytest.fixture
def addr():
    # Loopback address used for all socket tests.
    return "127.0.0.1"
@pytest.fixture
def port():
    # Fixed TCP port used by both the bound and connecting sides of each test.
    return 9001
@pytest.fixture
def context():
    # Fresh ZMQ context per test; destroyed with linger=0 so teardown never blocks.
    ctx = zmq.asyncio.Context()
    yield ctx
    ctx.destroy(linger=0)
@pytest.fixture
def client_id():
    # DEALER identity used by the fake client sockets below.
    return "foobar"
@pytest.fixture
async def dealer(context, client_id, addr, port):
    """A DEALER socket posing as `client_id`, pre-connected to the test address."""
    sock = context.socket(zmq.DEALER)
    sock.identity = client_id.encode()
    sock.connect(f"tcp://{addr}:{port}")
    yield sock
    sock.close(linger=0)
@pytest.fixture
async def router(context, addr, port):
    """A ROUTER socket bound to the test address, playing the server side."""
    sock = context.socket(zmq.ROUTER)
    sock.bind(f"tcp://{addr}:{port}")
    yield sock
    sock.close(linger=0)
async def test_ZMQReceiver_recv(context, dealer, client_id, addr, port):
    """ZMQReceiver strips the ROUTER envelope and decodes the peer id."""
    rx = ZMQReceiver(context=context)
    await rx.bind(addr, port)
    frames = (b"hello there", b"general kenobi")
    # A DEALER sends [empty delimiter, *payload]; the ROUTER prepends the identity frame.
    await dealer.send_multipart([b"", *frames])
    peer, segments = await asyncio.wait_for(rx.recv(), timeout=3.0)
    assert peer == client_id
    assert segments == list(frames)
    await rx.close()
async def test_ZMQTransmitter_send(context, router, client_id, addr, port):
    """ZMQTransmitter sends an empty delimiter frame followed by the payload segments."""
    tx = ZMQTransmitter(client_id, context=context)
    await tx.connect(addr, port)
    msg1, msg2 = b"hello there", b"general kenobi"
    await tx.send(msg1, msg2)
    actual_id, _, *actual_segments = await router.recv_multipart()
    assert actual_id.decode() == client_id
    assert actual_segments == [msg1, msg2]
    # FIX: close the transmitter so the test does not leak its socket
    # (mirrors test_ZMQReceiver_recv, which closes its receiver).
    await tx.close()

538
src/scatterbrained/node.py Normal file
View File

@@ -0,0 +1,538 @@
from __future__ import annotations
import asyncio
import dataclasses
from collections import defaultdict, deque
from enum import Enum
from typing import Callable, Deque, Dict, Optional, Sequence, Set, Tuple, Union, cast
from loguru import logger
from .discovery import DiscoveryEngine, UDPBroadcaster, UDPReceiver
from .network import NetworkEngine, ZMQReceiver, ZMQTransmitter
from .types import Identity
class OperatingMode(Enum):
    """
    The mode for the Node to operate in.

    Leech: Listen for broadcasts; do not share.
    Offline: Do not listen or broadcast.
    Peer: Listen and broadcast.
    Seeding: Broadcast but do not listen.
    """

    LEECHING = "leeching"  # listen for broadcasts; do not share
    OFFLINE = "offline"  # do not listen or broadcast
    PEER = "peer"  # listen and broadcast
    SEEDING = "seeding"  # broadcast but do not listen
class Namespace:
    """
    A Namespace is a collection of Scatterbrained nodes that share a common model or parameter state space. When
    connecting or disconnecting from a community, a Scatterbrained node joins a named namespace, or creates a new
    namespace with a unique name.

    Note that this class shouldn't be instantiated directly; rather, instances should be created via the
    `scatterbrained.Node.namespace` method.
    """

    node: Node
    name: str
    peer_filter: Callable[[Identity], bool]
    _operating_mode: OperatingMode
    _advertised_host: str
    _advertised_port: int
    _id: Identity
    # Bounded inbound-message queue: once `hwm` entries are reached, the oldest are evicted.
    _mq: Deque[Tuple[Identity, Sequence[bytes]]]
    _msg_cond: asyncio.Condition
    _peer_cond: asyncio.Condition

    def __init__(
        self,
        node: Node,
        name: str,
        operating_mode: Optional[OperatingMode] = None,
        peer_filter: Optional[Callable[[Identity], bool]] = None,
        advertised_host: Optional[str] = None,
        advertised_port: Optional[int] = None,
        hwm: int = 10_000,
    ):
        """
        Create a new Namespace or point to an existing one.

        Arguments:
            node (scatterbrained.Node): The `scatterbrained.Node` to which this Namespace belongs.
            name (str): The name of the `Namespace`. Must be unique if you are creating a new `Namespace`.
            operating_mode (scatterbrained.OperatingMode): The operating mode in which the `Node` should operate in
                this namespace. For more information on the different operating modes, see the docs on the
                `scatterbrained.node.OperatingMode` enum.
            peer_filter (callable): A function that takes an `Identity` and returns a boolean indicating whether the
                `Node` should connect to the peer.
            advertised_host (str, optional): The hostname or IP address to advertise to other Scatterbrained nodes.
            advertised_port (int, optional): The port to advertise to other nodes.
            hwm (int): The high water mark for the message queue. Old messages will be dropped if the queue is full
                beyond this number. Defaults to 10,000.

        Raises:
            RuntimeError: If the owning node is not listening yet (its bound address supplies the
                default advertised host/port).
        """
        if not node.listening:
            raise RuntimeError(f"{type(node).__name__} must be listening before {type(self).__name__} creation")
        if advertised_host is None:
            advertised_host = node.default_advertised_host
        if advertised_port is None:
            advertised_port = node.default_advertised_port
        assert advertised_host is not None  # nosec
        assert advertised_port is not None  # nosec
        self.node = node
        self.name = name
        self.peer_filter = peer_filter or node.default_peer_filter
        self._operating_mode = operating_mode or node.default_operating_mode
        self._advertised_host = advertised_host
        self._advertised_port = advertised_port
        self._id = Identity(
            id=self.node.id, namespace=self.name, host=self._advertised_host, port=self._advertised_port
        )
        self._mq = deque(maxlen=hwm)
        self._msg_cond = asyncio.Condition()
        self._peer_cond = asyncio.Condition()

    @property
    def operating_mode(self):
        """
        Get the operating mode for this namespace.

        Returns:
            scatterbrained.OperatingMode: The operating mode for this namespace.
        """
        return self._operating_mode

    # NOTE: This may need to be an async setter method as opposed to a setter property.
    @operating_mode.setter
    def operating_mode(self, value: OperatingMode):
        self._operating_mode = value

    async def __aenter__(self) -> Namespace:
        # Entering the context launches the namespace (announce identity + connect to known peers).
        await self.launch()
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def launch(self):
        """
        Launch the Namespace.

        This is usually done by entering the `Namespace` context manager. You should not need to call this method
        directly.
        """
        self.node.discovery_engine.add_identity(self._id)
        current_peers = self.node.discovery_engine.peers
        # TODO: explore exception handling approaches. Is returning the exceptions really the best way to do this?
        # AFAIK, gather will cancel any ops running if an exception occurs, which would mean we'd only need to clean
        # up. We should clean up regardless?
        statuses = await asyncio.gather(*[self.connect_to(p) for p in current_peers], return_exceptions=True)
        ex = False
        connected = []
        for peer, value in zip(current_peers, statuses):
            if isinstance(value, Exception):
                logger.bind(peer=peer).opt(exception=value).error(
                    f"failed to connect to {peer.id} (addr={peer.host}:{peer.port})"
                )
                ex = True
            elif value:
                # connect_to returns True only when a connection was actually established.
                connected.append(peer)
        self.node._update_connections(self, connected)
        if ex:
            raise RuntimeError(
                f"one or more exceptions occurred when attempting to launch {type(self).__name__} '{self.name}'"
            )

    async def close(self):
        """
        Close out the Namespace gracefully.

        This is usually done by exiting the `Namespace` context manager. You should not need to call this method
        directly.
        """
        # NOTE(review): this only stops advertising the identity; it does not disconnect from
        # peers or remove the namespace from the owning node -- confirm that is intentional.
        self.node.discovery_engine.remove_identity(
            Identity(id=self.node.id, namespace=self.name, host=self._advertised_host, port=self._advertised_port)
        )

    async def connect_to(self, peer: Identity) -> bool:
        """
        Connect to a peer in this Namespace by its Identity.

        The method will not connect to the given peer if the `OperatingMode` of this Namespace is either `LEECHING`
        or `OFFLINE`, or the `Identity` is rejected by this instance's `peer_filter`.

        Arguments:
            peer (scatterbrained.Identity): The peer to connect to.

        Returns:
            bool: `True` if the connection was successful, `False` otherwise.
        """
        if (
            self.operating_mode == OperatingMode.LEECHING
            or self.operating_mode == OperatingMode.OFFLINE
            or not self.peer_filter(peer)
        ):
            return False
        await self.node.network_engine.connect_to(peer.host, peer.port)
        return True

    async def disconnect_from(self, peer: Identity) -> bool:
        """
        Disconnect from a peer in this `Namespace` by its `Identity`.

        Arguments:
            peer (scatterbrained.Identity): The peer to disconnect from.

        Returns:
            bool: `True` if the disconnection was successful, `False` otherwise.
        """
        # The node owns the shared connection bookkeeping, so delegate to it.
        return await self.node._disconnect_from(peer)

    async def wait_for_peers(
        self, peers: Union[int, str, Identity, Sequence[Identity], Sequence[str]], timeout: Optional[float] = None
    ) -> None:
        """
        Wait for peers to connect to this Namespace.

        This method behaves in one of three ways depending on the type of arguments passed:

        * If you pass one or more `Identity` objects, this will wait for all peers to connect. Importantly, this will
          match on all attributes of the `Identity` objects.
        * If you pass one or more strings, this will wait for all peers with the specified ids.
        * If you pass an integer, this will wait for that many peers to connect.

        Arguments:
            peers (Union[int, str, Identity, Sequence[Identity], Sequence[str]]): The condition to await.
            timeout (float): The timeout in seconds to wait for the condition to be met.

        Returns:
            None
        """
        if isinstance(peers, int):
            return await asyncio.wait_for(self._wait_for_n_peers(peers), timeout=timeout)
        elif isinstance(peers, str):
            peers = [peers]
        elif isinstance(peers, Identity):
            peers = [peers]
        # A mixed sequence is normalized to ids, which weakens matching to id-only.
        strings_present = any(isinstance(v, str) for v in peers)
        if strings_present:
            peers = [v if isinstance(v, str) else v.id for v in peers]
            await asyncio.wait_for(self._wait_for_named_peers(peers), timeout=timeout)
        else:
            peers = cast(Sequence[Identity], peers)
            await asyncio.wait_for(self._wait_for_exact_peers(peers), timeout=timeout)

    async def _wait_for_n_peers(self, num_peers: int) -> None:
        # Block until at least `num_peers` peers are registered for this namespace.
        def check_fn():
            return len(self.node._peers_by_namespace[self.name]) >= num_peers

        if not check_fn():
            async with self._peer_cond:
                await self._peer_cond.wait_for(check_fn)

    async def _wait_for_named_peers(self, peers: Sequence[str]) -> None:
        # Block until every requested peer id is present (matching by id only).
        peer_set = set(peers)

        def check_fn():
            return len(peer_set - set(v.id for v in self.node._peers_by_namespace[self.name])) == 0

        if not check_fn():
            async with self._peer_cond:
                await self._peer_cond.wait_for(check_fn)

    async def _wait_for_exact_peers(self, peers: Sequence[Identity]) -> None:
        # Block until every requested Identity is present (matching on all fields).
        peer_set = set(peers)

        def check_fn():
            return len(peer_set - self.node._peers_by_namespace[self.name]) == 0

        if not check_fn():
            async with self._peer_cond:
                await self._peer_cond.wait_for(check_fn)

    async def send_to(self, peer: Identity, *payload: bytes) -> None:
        """
        Send a byte sequence payload to a peer by its `Identity`.

        Arguments:
            peer (scatterbrained.Identity): The peer to send to.
            payload (bytes): The payload to send.

        Returns:
            None
        """
        await self.node.network_engine.send_to(self._id, peer, *payload)

    async def recv(self, timeout: Optional[float] = None) -> Tuple[Identity, Sequence[bytes]]:
        """
        Receive a message from the network.

        This will block until a message is received, or the timeout is reached.

        Arguments:
            timeout (float): The timeout in seconds to wait for a message to be received.

        Returns:
            Tuple[scatterbrained.Identity, bytes]: The sender's Identity and the payload.
        """
        if not len(self._mq):
            # NOTE(review): a message arriving between this emptiness check and acquiring
            # `_msg_cond` below relies on `_add_message_from` notifying under the same
            # condition; confirm there is no lost-wakeup window here.
            async with self._msg_cond:
                await asyncio.wait_for(self._msg_cond.wait(), timeout=timeout)
        # NOTE(review): pop() takes the NEWEST message (LIFO); popleft() would give FIFO
        # delivery -- confirm LIFO is intended.
        return self._mq.pop()

    async def _on_appear(self, peer: Identity):
        # Wake any wait_for_peers() callers so they re-evaluate their predicates.
        async with self._peer_cond:
            self._peer_cond.notify_all()

    async def _add_message_from(self, peer: Identity, payload: Sequence[bytes]) -> None:
        # Only notify when the queue transitions from empty: recv() only waits in that state.
        empty = len(self._mq) == 0
        self._mq.append((peer, payload))
        if empty:
            async with self._msg_cond:
                self._msg_cond.notify()
class Node:
    """
    Scatterbrained peer node.

    Manages all networking infrastructure, providing an interface for other classes to establish connections, and
    send and receive data.
    """

    id: str
    discovery_engine: DiscoveryEngine
    network_engine: NetworkEngine
    default_operating_mode: OperatingMode
    default_peer_filter: Callable[[Identity], bool]
    default_advertised_host: Optional[str]
    default_advertised_port: Optional[int]
    _host: str
    _port: Optional[int]
    _namespaces: Dict[str, Namespace]
    # Identities sharing one physical (host, port) connection; the physical TX is only
    # closed once this set empties.
    _virtual_tx_connections: Dict[Tuple[str, int], Set[Identity]]
    _peers_by_namespace: Dict[str, Set[Identity]]
    _network_sub: Optional[NetworkEngine.Disposable]

    def __init__(
        self,
        id: str,
        host: str = "0.0.0.0",  # nosec
        port: Optional[int] = None,
        discovery_engine: Optional[DiscoveryEngine] = None,
        network_engine: Optional[NetworkEngine] = None,
        default_operating_mode: OperatingMode = OperatingMode.PEER,
        default_peer_filter: Optional[Callable[[Identity], bool]] = None,
        default_advertised_host: Optional[str] = None,
        default_advertised_port: Optional[int] = None,
    ) -> None:
        """
        Create a new Scatterbrained peer node.

        Arguments:
            id (str): The identity of this `Node`. Must be provided.
            host (str): The host to bind to. Defaults to bind on all IP addresses on the machine (`0.0.0.0`).
            port (int): The port to bind to. Defaults to an arbitrary open port.
            discovery_engine (scatterbrained.DiscoveryEngine): The `DiscoveryEngine` to use. Defaults to a new
                `DiscoveryEngine` with a default configuration, using UDP broadcast.
            network_engine (scatterbrained.NetworkEngine): The `NetworkEngine` to use. Defaults to a new
                NetworkEngine with a default configuration, using ZMQ.
            default_operating_mode (scatterbrained.OperatingMode): The default operating mode for new `Namespace`s.
                Defaults to `OperatingMode.PEER`.
            default_peer_filter (Callable[[scatterbrained.Identity], bool]): The default peer filter for new
                `Namespace`s. The default behavior is for all peers to be accepted.
            default_advertised_host (str): The advertised host to use for new `Namespace`s. Defaults to the host
                provided.
            default_advertised_port (int): The advertised port to use for new `Namespace`s. Defaults to the port
                provided.
        """
        if discovery_engine is None:
            discovery_engine = DiscoveryEngine(publisher=UDPBroadcaster(), subscriber=UDPReceiver())
        if network_engine is None:
            network_engine = NetworkEngine(rx=ZMQReceiver(), tx_factory=lambda: ZMQTransmitter(id=id))
        if default_peer_filter is None:
            default_peer_filter = lambda _: True  # noqa: E731
        self.id = id
        self.discovery_engine = discovery_engine
        self.network_engine = network_engine
        self.default_peer_filter = default_peer_filter
        self.default_operating_mode = default_operating_mode
        self.default_advertised_host = default_advertised_host
        self.default_advertised_port = default_advertised_port
        self._host = host
        self._port = port
        self._namespaces = {}
        self._virtual_tx_connections = defaultdict(set)
        self._peers_by_namespace = defaultdict(set)
        self._network_sub = None

    @property
    def listening(self):
        """
        Whether or not the `Node` is listening for incoming connections.

        Returns:
            bool: Whether or not the Node is listening for incoming connections.
        """
        return self._network_sub is not None

    async def __aenter__(self) -> Node:
        await self.launch()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    async def launch(self) -> None:
        """
        Launch the `Node`.

        This will start the `DiscoveryEngine` and `NetworkEngine` and begin listening for new connections.

        Note that this method is usually called automatically when entering an `async with` block and is not meant
        to be called manually.

        Returns:
            None
        """
        if self._network_sub is not None:
            # Already launched; keep launch() idempotent.
            return
        await self.network_engine.bind(host=self._host, port=self._port)
        self._network_sub = self.network_engine.subscribe(
            on_recv=self._on_recv, on_malformed=self._on_malformed, on_error=self._on_network_error
        )
        await self.discovery_engine.start(
            on_appear=self._on_appear, on_disappear=self._on_disappear, on_error=self._on_discovery_error
        )
        bound_host, bound_port = self.network_engine.bound_address
        assert bound_host is not None, "expected bound host to not be None"  # nosec
        assert bound_port is not None, "expected bound port to not be None"  # nosec
        # Fill in advertised defaults from the actual bound address if the caller left them unset.
        if self.default_advertised_host is None:
            self.default_advertised_host = bound_host
        if self.default_advertised_port is None:
            self.default_advertised_port = bound_port

    async def close(self):
        """
        Gracefully close the `Node`.

        This will close the `DiscoveryEngine` and `NetworkEngine`.

        Note that this method is usually called automatically when exiting an `async with` block and is not meant to
        be called manually.

        Returns:
            None
        """
        if self._network_sub is None:
            return
        await self.discovery_engine.stop()
        await self._network_sub.dispose()
        await self.network_engine.close()
        self._network_sub = None

    def namespace(self, name: str, *args, **kwargs) -> Namespace:
        """
        Gets the namespace with the given name, or creates a new one.

        Also accepts the same arguments as `scatterbrained.Namespace` if creating a new namespace.

        Arguments:
            name (str): The name of the namespace.

        Returns:
            scatterbrained.Namespace: The namespace with the given name.
        """
        if (ns := self._namespaces.get(name)) is not None:
            return ns
        ns = Namespace(self, name, *args, **kwargs)
        self._namespaces[ns.name] = ns
        return ns

    def _update_connections(self, ns: Namespace, peers: Sequence[Identity]) -> None:
        # Record freshly-connected peers for `ns` in the node's bookkeeping tables.
        for peer in peers:
            self._virtual_tx_connections[(peer.host, peer.port)].add(peer)
            # FIX: also track the peer per-namespace (mirroring `_on_appear`). Without this,
            # peers connected at Namespace.launch() were invisible to `wait_for_peers` and
            # `_disconnect_from` raised KeyError when removing them from the namespace set.
            self._peers_by_namespace[ns.name].add(peer)

    async def _on_appear(self, peer: Identity) -> None:
        # Discovery callback: try to connect to a newly-announced peer in a known namespace.
        cl = logger.bind(peer=dataclasses.asdict(peer))
        cl.debug(f"peer '{peer.id}' is online")
        ns = self._namespaces.get(peer.namespace)
        if ns is not None:
            connected = await ns.connect_to(peer)
            if connected:
                cl.debug(f"connected to '{peer.id}'")
                self._virtual_tx_connections[(peer.host, peer.port)].add(peer)
                self._peers_by_namespace[ns.name].add(peer)
                await ns._on_appear(peer)

    async def _on_disappear(self, peer: Identity) -> None:
        # Discovery callback: tear down our side of the connection.
        await self._disconnect_from(peer)

    async def _on_discovery_error(self, ex: Exception) -> None:
        logger.opt(exception=ex).error("discovery engine encountered an error")
        # TODO: determine how to best alert the user of this error.

    async def _on_recv(self, peer: Identity, payload: Sequence[bytes]) -> None:
        # Route an inbound message to the namespace the sender claims; drop if unknown.
        ns = self._namespaces.get(peer.namespace)
        if ns is not None:
            await ns._add_message_from(peer, payload)

    async def _on_malformed(self, peer_id: str, segments: Sequence[bytes]) -> None:
        logger.warning(f"malformed message consisting of {len(segments)} byte segments from '{peer_id}'")
        # TODO: anything else to do here? Maybe allow segments to be saved.

    async def _on_network_error(self, ex: Exception) -> None:
        logger.opt(exception=ex).error("network engine encountered an error")
        # TODO: determine how to best alert the user of this error.

    async def _disconnect_from(self, peer: Identity) -> bool:
        """
        Drop `peer` from the bookkeeping tables; close the physical connection when it was
        the last identity using that (host, port).

        Returns:
            bool: `True` if the peer was known and removed, `False` otherwise.
        """
        cl = logger.bind(peer=dataclasses.asdict(peer))
        cl.debug(f"peer '{peer.id}' is offline")
        virt_conns = self._virtual_tx_connections[(peer.host, peer.port)]
        if peer not in virt_conns:
            cl.warning(f"attempted to disconnect from peer '{peer.id}' that was never connected to originally")
            return False
        virt_conns.remove(peer)
        if not virt_conns:
            await self.network_engine.disconnect_from(peer.host, peer.port)
            cl.debug(f"disconnected from '{peer.host}:{peer.port}'")
        # FIX: `discard` instead of `remove` -- tolerate peers missing from the per-namespace
        # set rather than raising KeyError mid-disconnect.
        self._peers_by_namespace[peer.namespace].discard(peer)
        return True
__all__ = ["Node"]

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
@dataclass(frozen=True)
class Identity:
    """
    Represents the identity of a scatterbrained.Node (either local or remote) for a particular namespace.

    Frozen so instances are hashable and can be stored in peer sets and used as dict keys.

    Args:
        id (str): The id of the Node.
        namespace (str): The namespace the Node is operating in.
        host (str): The advertised address of this Node.
        port (int): The advertised port of this Node.
    """

    id: str
    namespace: str
    host: str
    port: int
__all__ = ["Identity"]

View File

@@ -0,0 +1 @@
__version__ = "0.0.1"