mirror of
https://github.com/JHUAPL/scatterbrained.git
synced 2026-01-08 21:47:57 -05:00
🎉 Open-source Scatterbrained
5
.flake8
Normal file
@@ -0,0 +1,5 @@
[flake8]
max-line-length = 120
extend-ignore = E203,E501,W503
max-complexity = 12
select = B,C,E,F,W,B9
427
.gitignore
vendored
Normal file
@@ -0,0 +1,427 @@
# Created by https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm

### Emacs ###
# -*- mode: gitignore; -*-
*~
\#*\#
/.emacs.desktop
/.emacs.desktop.lock
*.elc
auto-save-list
tramp
.\#*

# Org-mode
.org-id-locations
*_archive
ltximg/**

# flymake-mode
*_flymake.*

# eshell files
/eshell/history
/eshell/lastdir

# elpa packages
/elpa/

# reftex files
*.rel

# AUCTeX auto folder
/auto/

# cask packages
.cask/
dist/

# Flycheck
flycheck_*.el

# server auth directory
/server/

# projectiles files
.projectile

# directory configuration
.dir-locals.el

# network security
/network-security.data


### Linux ###

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### PyCharm ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm Patch ###
# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721

# *.iml
# modules.xml
# .idea/misc.xml
# *.ipr

# Sonarlint plugin
# https://plugins.jetbrains.com/plugin/7973-sonarlint
.idea/**/sonarlint/

# SonarQube Plugin
# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
.idea/**/sonarIssues.xml

# Markdown Navigator plugin
# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
.idea/**/markdown-navigator.xml
.idea/**/markdown-navigator-enh.xml
.idea/**/markdown-navigator/

# Cache file creation bug
# See https://youtrack.jetbrains.com/issue/JBR-2257
.idea/$CACHE_FILE$

# CodeStream plugin
# https://plugins.jetbrains.com/plugin/12206-codestream
.idea/codestream.xml

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
pytestdebug.log

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
doc/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pythonenv*

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# profiling data
.prof

### SublimeText ###
# Cache files for Sublime Text
*.tmlanguage.cache
*.tmPreferences.cache
*.stTheme.cache

# Workspace files are user-specific
*.sublime-workspace

# Project files should be checked into the repository, unless a significant
# proportion of contributors will probably not be using Sublime Text
# *.sublime-project

# SFTP configuration file
sftp-config.json

# Package control specific files
Package Control.last-run
Package Control.ca-list
Package Control.ca-bundle
Package Control.system-ca-bundle
Package Control.cache/
Package Control.ca-certs/
Package Control.merged-ca-bundle
Package Control.user-ca-bundle
oscrypto-ca-bundle.crt
bh_unicode_properties.cache

# Sublime-github package stores a github token in this file
# https://packagecontrol.io/packages/sublime-github
GitHub.sublime-settings

### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg  # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

### VisualStudioCode ###
.vscode/

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,vim,emacs,sublimetext,visualstudiocode,pycharm
38
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,38 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-json
      - id: check-yaml
      - id: check-added-large-files
  - repo: https://github.com/PyCQA/isort
    rev: 5.9.2
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 21.7b0
    hooks:
      - id: black
        language_version: python3
  - repo: https://gitlab.com/pycqa/flake8
    rev: 3.9.2
    hooks:
      - id: flake8
        additional_dependencies: [flake8-bugbear]
  - repo: https://github.com/PyCQA/bandit
    rev: 1.7.0
    hooks:
      - id: bandit
        args: ["-x", "*/**/*_test.py"]
  - repo: local
    hooks:
      - id: pytest
        name: pytest
        entry: pytest src/scatterbrained
        language: system
        pass_filenames: false
        types: [python]
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2021 The Johns Hopkins Applied Physics Laboratory

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
62
README.md
Normal file
@@ -0,0 +1,62 @@

<h1 align='center'>Scatterbrained</h1>
<p align='center'>Decentralized Federated Learning</p>
<p align='center'>
<a href="https://pypi.org/project/scatterbrained/"><img alt="PyPI - Python Version" src="https://img.shields.io/pypi/pyversions/scatterbrained?style=for-the-badge"></a>
<a href="https://github.com/JHUAPL/scatterbrained"><img alt="GitHub last commit" src="https://img.shields.io/github/last-commit/JHUAPL/scatterbrained?style=for-the-badge"></a>
<a href="https://www.apache.org/licenses/LICENSE-2.0"><img alt="GitHub" src="https://img.shields.io/github/license/JHUAPL/scatterbrained?style=for-the-badge"></a>
</p>


Scatterbrained makes it easy to build federated learning systems. In addition to traditional federated learning, Scatterbrained supports decentralized federated learning: a cooperative style of federated learning in which training is carried out by a group of peers rather than coordinated by a central server. For more information, see our 2021 paper, [_Scatterbrained: A flexible and expandable pattern for decentralized machine learning_](#).

You can use your favorite machine learning frameworks alongside Scatterbrained, such as TensorFlow, scikit-learn, or PyTorch.

## Usage

For examples of how to get started using Scatterbrained, see the [Examples](examples/) directory.
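
As a taste of the API, this is all the code a second machine needs to join an existing cluster and pull down the current model (taken from the getting-started guide; run it inside a coroutine):

```python
import scatterbrained as sb

# Inside a coroutine:
async with sb.Node() as node:
    async with node.join("MyFirstNamespace") as ns:
        await ns.sync_model()  # download the model from a peer
        await ns.serve()       # share our resources with the cluster
```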

## Installation

You can install Scatterbrained with pip:

```shell
pip install scatterbrained
```

If you would rather install from source, first clone the repository:

```shell
git clone https://github.com/JHUAPL/scatterbrained.git
cd scatterbrained
```

Then install the dependencies:

```shell
pip3 install -r ./requirements/requirements.txt
```

And finally install the package itself:

```shell
pip3 install -e .
```

## License

The code in this repository is released under an Apache 2.0 license. For more information, see [LICENSE](LICENSE).

> Copyright 2021 The Johns Hopkins Applied Physics Laboratory
>
> Licensed under the Apache License, Version 2.0 (the "License");
> you may not use this file except in compliance with the License.
> You may obtain a copy of the License at
>
> http://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software
> distributed under the License is distributed on an "AS IS" BASIS,
> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
> See the License for the specific language governing permissions and
> limitations under the License.
44
docs/Installation.md
Normal file
@@ -0,0 +1,44 @@
# Installation Guide

You can install Scatterbrained on your development machine either through pip or by downloading the source code and installing it locally.

## Pip-based Installation

```shell
pip install scatterbrained
```

This will automatically install the latest stable version of Scatterbrained, along with its dependencies. For a complete list of dependencies, see the `requirements/` directory in this repository.
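
To verify the installation, you can import the package and print its version (the top-level package exports `__version__`, as shown in `src/scatterbrained/__init__.py`):

```python
import scatterbrained

# Prints the installed version string, e.g. from src/scatterbrained/version.py
print(scatterbrained.__version__)
```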
## Source-based Installation

You can also clone the git repository and install Scatterbrained from source. In this case, you will also need to manually install the dependencies.

```shell
git clone https://github.com/JHUAPL/scatterbrained.git
cd scatterbrained

# Install dependencies:
pip install -r requirements/requirements.txt

# Install the library in "editable"/development mode:
pip install -e .
```

## Container-based Installation

It is also possible to install and run Scatterbrained in a Docker container. This can be useful for running multiple instances of Scatterbrained on the same machine, for testing or benchmarking.

Docker-based instructions coming soon.

## Installation Troubleshooting

Click the caret next to each question to see the full troubleshooting guide.

<details>
<summary><b>Errors when installing on Ubuntu ≤14</b></summary>

This is likely due to an older version of 0MQ installed with `apt-get`. If you're sure you don't have software that relies upon this older version, run the following command to remove it, then reinstall the latest 0MQ bindings (e.g., via `pip install pyzmq`):

    sudo apt-get remove libzmq-dev python-zmq

</details>
112
docs/getting-started/README.md
Normal file
@@ -0,0 +1,112 @@
# Getting Started with Scatterbrained

This is a quickstart tutorial for the Scatterbrained federated learning library. It assumes that you have already installed the library and have a working, high-level understanding of federated learning.

* For installation instructions, see [Installation](../installation/README.md).
* For a refresher on federated learning, see this [Wikipedia article](https://en.wikipedia.org/wiki/Federated_learning).

## Step 1: Building a model

For this example, we're going to build a basic linear regression model using `torch`. (You can also use other machine learning frameworks, such as TensorFlow or scikit-learn!) We'll load a simple built-in dataset and train a linear regression model on it.


```python
import sklearn.datasets
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, TensorDataset

# Load the data
X, y = sklearn.datasets.load_diabetes(return_X_y=True)

# Generate train/test splits
(X_train, X_val, y_train, y_val) = train_test_split(
    X, y, test_size=0.2
)

# Create the data loaders for torch. (We cast to float32 here,
# since numpy produces float64 arrays but nn.Linear's weights
# default to float32.)
trainloader = DataLoader(
    TensorDataset(
        torch.from_numpy(X_train).float(), torch.from_numpy(y_train).float()
    ),
    batch_size=32, shuffle=True
)
validloader = DataLoader(
    TensorDataset(
        torch.from_numpy(X_val).float(), torch.from_numpy(y_val).float()
    ),
    batch_size=32, shuffle=True
)


# Create the model
model = nn.Linear(X_train.shape[1], 1)
create_optimizer = lambda model: optim.AdamW(model.parameters())
```

In the code block above, we've created a model just like we would in a traditional ML pipeline. We've also created a function that returns an optimizer. So far, we're not doing anything "FL-flavored": this is just plain old ML.

## Step 2: Training the model

It's now time to introduce the Scatterbrained library. Training a model is still quite simple:

```python
import scatterbrained as sb

NUM_EPOCHS = 10

# Create a new Scatterbrained Node to hold all of the logic
# for a federated learning compute node:
async with sb.Node() as node:

    # Create a new Namespace. This is a unique identifier
    # shared by all nodes in the same FL community. Other
    # nodes on the same network that know this name can join
    # your FL cluster:
    async with node.namespace(
        "MyFirstNamespace", model, create_optimizer
    ) as ns:

        # Finally, we can ask Scatterbrained to perform the
        # training loop for us in a background thread:
        await ns.train(NUM_EPOCHS, trainloader, validloader)

        # At the same time, we can also serve our node's
        # resources to other nodes on the network:
        await ns.serve()
```
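
Note that `async with` and `await` are only valid inside a coroutine, so the snippet above needs an event loop to run on. A minimal sketch of a runnable entry point (assuming `model`, `create_optimizer`, `trainloader`, and `validloader` from Step 1 are in scope):

```python
import asyncio

import scatterbrained as sb

NUM_EPOCHS = 10


async def main():
    # Same logic as above, wrapped in a coroutine so it can
    # be handed to asyncio.run():
    async with sb.Node() as node:
        async with node.namespace(
            "MyFirstNamespace", model, create_optimizer
        ) as ns:
            await ns.train(NUM_EPOCHS, trainloader, validloader)
            await ns.serve()


if __name__ == "__main__":
    asyncio.run(main())
```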

## Step 3: Joining the cluster

From another machine on the same network (or the same machine on a different port), you can join the cluster by running the following code.

Unlike the code blocks above, where we built up a file gradually, this is all the code you need to run from your second machine:

```python
import scatterbrained as sb

async with sb.Node() as node:
    async with node.join("MyFirstNamespace") as ns:
        await ns.sync_model()  # This line is different!
        await ns.serve()
```

Wow, that's succinct! Let's take a look at what's happening behind the scenes when we run this code.

First, we create a new Scatterbrained Node, just like before. And just like in the first example, we join a Namespace with the same name, so that the nodes know they're allowed to talk to each other. (An `sb.Node` can't cooperatively learn with a Node training a different model, so we use the Namespace to indicate to the Node that this networked peer is speaking the same language, i.e., using the same model.)

But here we don't specify any training data or model architecture: instead, we use the `sync_model` method so that this node can download the model from another peer Node on the network. This means that you can use Scatterbrained to quickly transfer a model from one machine to another. (In other words, your models will always be in sync, with a single node serving as the source of truth.)

## Next Steps

In this tutorial, we've covered the basics of using Scatterbrained to train a model. But there are many other features that you can use to tailor your decentralized federated learning code. For example,

* You can emulate traditional, centralized federated learning if you want a single machine to serve as an authority
* You can train a model on multiple machines in parallel with different datasets
* You can specify different optimizers for different Nodes
* You can change your network topology so that nodes communicate with only certain peers and ignore others
* You can design custom network topologies so that information can only flow in certain directions
12
docs/reference/README.md
Normal file
@@ -0,0 +1,12 @@
# Reference documentation for the Scatterbrained Library

<!--

Documentation was generated with `docshund`:

```
$ pip install docshund
$ docshund ./src
```

-->
0
docs/reference/discovery/discovery.md
Normal file
74
docs/reference/discovery/engine.py.md
Normal file
@@ -0,0 +1,74 @@
## *Class* `DiscoveryEngine`


A `DiscoveryEngine` is a manager class for the process of identifying and communicating with other peers in the Scatterbrained network.

The `DiscoveryEngine` is responsible for maintaining a list of peers, and for periodically sending heartbeats to the network, which other peers can use to determine whether the sending node is still alive.

Identities of peers are stored in `DiscoveryEngine.peers`.

Identities of the current node are stored in `DiscoveryEngine.identities` and can be manually added or removed using `DiscoveryEngine.add_identity` and `DiscoveryEngine.remove_identity`.
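
A minimal usage sketch, adapted from `examples/discovery_engine.py` in this repository (hosts, ports, and identity fields below are placeholder values):

```python
import asyncio

import scatterbrained as sb


async def on_appear(peer):
    print("appeared:", peer)


async def on_disappear(peer):
    print("disappeared:", peer)


async def on_error(e):
    print("error:", e)


async def main():
    # Open a UDP broadcaster/receiver pair for sending and hearing heartbeats.
    pub = sb.discovery.udp.UDPBroadcaster("127.0.0.1", port=9001)
    sub = sb.discovery.udp.UDPReceiver("127.0.0.1", port=9001)
    await asyncio.wait([pub.open(), sub.open()])

    engine = sb.discovery.DiscoveryEngine(
        pub,
        sub,
        identities=[sb.types.Identity(id="node-a", namespace="demo", host="127.0.0.1", port=9100)],
        heartbeat=2,  # heartbeat interval, in seconds
    )
    await engine.start(on_appear=on_appear, on_disappear=on_disappear, on_error=on_error)
    await asyncio.sleep(15)  # listen for peers for a while
    await engine.stop()
    await asyncio.wait([pub.close(), sub.close()])


if __name__ == "__main__":
    asyncio.run(main())
```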
## *Function* `heartbeat(self)`


Retrieve the heartbeat interval in seconds.


## *Function* `heartbeat(self, value)`


Set the heartbeat interval in seconds.

### Arguments
> - **value** (`float`: `None`): The heartbeat interval in seconds.

### Returns
None



## *Function* `peers(self)`


Retrieve the set of known peers.

### Returns
> - **Set[scatterbrained.discovery.types.Identity]** (`None`: `None`): The set of known peers.



## *Function* `add_identity(self, identity: Identity) -> None`


Add an identity to the list of identities for the current node.

### Arguments
> - **identity** (`scatterbrained.discovery.types.Identity`: `None`): The identity to add.

### Returns
None



## *Function* `remove_identity(self, identity: Identity) -> None`


Remove an identity from the list of valid identities for the node.

### Arguments
> - **identity** (`scatterbrained.discovery.types.Identity`: `None`): The identity to remove.

### Returns
None



## *Function* `stop(self)`


Gracefully stop the `DiscoveryEngine`.

### Returns
None
46
docs/reference/discovery/types.py.md
Normal file
@@ -0,0 +1,46 @@
## *Class* `Publisher(Protocol)`


A Publisher is a class that can publish a message to a set of peers.

Note that this is a protocol and shouldn't be used directly!



## *Function* `publish(self, data: bytes) -> None`


Publish the given payload to a set of peers.


## *Function* `open(self) -> None`


Open the underlying connection mechanism, enabling this instance to send messages.


## *Function* `close(self) -> None`


Close the underlying connection mechanism, stopping this instance from sending messages.
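
For illustration, here is a minimal in-memory class that satisfies the `Publisher` protocol. This is a hypothetical test double, not part of the library; real implementations such as `UDPBroadcaster` publish over the network, and the `async` signatures follow the usage in `examples/discovery_engine.py`:

```python
from typing import List


class InMemoryPublisher:
    """Hypothetical Publisher that 'publishes' by appending to a local list."""

    def __init__(self) -> None:
        self.sent: List[bytes] = []
        self._open = False

    async def open(self) -> None:
        # Open the underlying connection mechanism (here, just a flag).
        self._open = True

    async def close(self) -> None:
        # Close the underlying connection mechanism.
        self._open = False

    async def publish(self, data: bytes) -> None:
        # Publish the given payload to the set of peers (here, a list).
        if not self._open:
            raise RuntimeError("publisher is not open")
        self.sent.append(data)
```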
## *Class* `Subscriber(Protocol)`


A Subscriber is a class that can subscribe to messages from a set of peers.

Note that this is a protocol and shouldn't be used directly!



## *Function* `open(self) -> None`


Open the underlying connection mechanism, enabling this instance to receive messages.


## *Function* `close(self) -> None`


Close the underlying connection mechanism, stopping this instance from receiving messages.
0
docs/reference/discovery/udp.py.md
Normal file
0
docs/reference/discovery/udp_test.py.md
Normal file
0
docs/reference/network/engine.py.md
Normal file
0
docs/reference/network/engine_test.py.md
Normal file
0
docs/reference/network/network.md
Normal file
0
docs/reference/network/types.py.md
Normal file
0
docs/reference/network/zmq_ip.py.md
Normal file
0
docs/reference/network/zmq_ip_test.py.md
Normal file
164
docs/reference/scatterbrained/node.py.md
Normal file
@@ -0,0 +1,164 @@
## *Class* `OperatingMode(Enum)`


The mode for the Node to operate in.

> - **Leech** (`None`: `None`): Listen for broadcasts; do not share.
> - **Offline** (`None`: `None`): Do not listen or broadcast.
> - **Peer** (`None`: `None`): Listen and broadcast.
> - **Seeding** (`None`: `None`): Broadcast but do not listen.



## *Class* `Namespace`


A Namespace is a collection of Scatterbrained nodes that share a common model or parameter state space. When connecting to or disconnecting from a community, a Scatterbrained node joins a named namespace, or creates a new namespace with a unique name.

Note that this class shouldn't be instantiated directly; rather, instances should be created via the `scatterbrained.Node.namespace` method.
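
A sketch of the typical lifecycle, following the getting-started guide and `examples/node.py` (the namespace name, peer identity, and payload below are illustrative):

```python
import scatterbrained as sb

# Illustrative peer identity; in practice this arrives via discovery.
peer = sb.types.Identity(id="bar", namespace="MyFirstNamespace", host="127.0.0.1", port=9002)


async def demo():
    async with sb.Node() as node:
        # Entering the context manager launches the Namespace;
        # exiting it closes the Namespace gracefully.
        async with node.namespace("MyFirstNamespace") as ns:
            await ns.connect_to(peer)
            await ns.send_to(peer, b"hello")
            sender, payload = await ns.recv(5.0)  # block for up to 5 seconds
```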
## *Function* `operating_mode(self)`


Get the operating mode for this namespace.

### Returns
> - **scatterbrained.OperatingMode** (`None`: `None`): The operating mode for this namespace.



## *Function* `launch(self)`


Launch the Namespace.

This is usually done by entering the `Namespace` context manager. You should not need to call this method directly.



## *Function* `close(self)`


Close out the Namespace gracefully.

This is usually done by exiting the `Namespace` context manager. You should not need to call this method directly.



## *Function* `connect_to(self, peer: Identity) -> bool`


Connect to a peer in this Namespace by its Identity.

The method will not connect to the given peer if the `OperatingMode` of this Namespace is either `LEECHING` or `OFFLINE`, or if the `Identity` is rejected by this instance's `peer_filter`.

### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to connect to.

### Returns
> - **bool** (`None`: `None`): `True` if the connection was successful, `False` otherwise.



## *Function* `disconnect_from(self, peer: Identity) -> bool`


Disconnect from a peer in this `Namespace` by its `Identity`.

### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to disconnect from.

### Returns
> - **bool** (`None`: `None`): `True` if the disconnection was successful, `False` otherwise.



## *Function* `send_to(self, peer: Identity, *payload: bytes) -> None`


Send a byte sequence payload to a peer by its `Identity`.

### Arguments
> - **peer** (`scatterbrained.Identity`: `None`): The peer to send to.
> - **payload** (`bytes`: `None`): The payload to send.

### Returns
None



## *Function* `recv(self, timeout: Optional[float] = None) -> Tuple[Identity, Sequence[bytes]]`


Receive a message from the network.

This will block until a message is received, or the timeout is reached.

### Arguments
> - **timeout** (`float`: `None`): The timeout in seconds to wait for a message to be received.

### Returns
> - **Tuple[Identity, Sequence[bytes]]** (`None`: `None`): The sender's Identity and the payload.



## *Class* `Node`


Scatterbrained peer node.

Manages all networking infrastructure, providing an interface for other classes to establish connections, and send and receive data.


## *Function* `listening(self)`


Whether or not the `Node` is listening for incoming connections.

### Returns
> - **bool** (`None`: `None`): Whether or not the Node is listening for incoming connections.



## *Function* `launch(self) -> None`


Launch the `Node`.

This will start the `DiscoveryEngine` and `NetworkEngine` and begin listening for new connections.

Note that this method is usually called automatically when entering an `async with` block and is not meant to be called manually.

### Returns
None



## *Function* `close(self)`


Gracefully close the `Node`.

This will close the `DiscoveryEngine` and `NetworkEngine`.

Note that this method is usually called automatically when exiting an `async with` block and is not meant to be called manually.

### Returns
None



## *Function* `namespace(self, name: str, *args, **kwargs) -> Namespace`


Gets the namespace with the given name, or creates a new one.

Also accepts the same arguments as `scatterbrained.Namespace` if creating a new namespace.

### Arguments
> - **name** (`str`: `None`): The name of the namespace.

### Returns
> - **scatterbrained.Namespace** (`None`: `None`): The namespace with the given name.
0
docs/reference/scatterbrained/scatterbrained.md
Normal file
10
docs/reference/scatterbrained/types.py.md
Normal file
@@ -0,0 +1,10 @@
## *Class* `Identity`


Represents the identity of a `scatterbrained.Node` (either local or remote) for a particular namespace.

### Args
> - **id** (`str`: `None`): The id of the Node.
> - **namespace** (`str`: `None`): The namespace the Node is operating in.
> - **host** (`str`: `None`): The advertised address of this Node.
> - **port** (`int`: `None`): The advertised port of this Node.
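
For example, identities can be constructed positionally or with keyword arguments, as in `examples/network_engine.py` and `examples/discovery_engine.py`:

```python
import scatterbrained as sb

# Positional: (id, namespace, host, port)
id1 = sb.types.Identity("foo", "baz", "127.0.0.1", 9001)

# Keyword arguments, equivalent form:
id2 = sb.types.Identity(id="bar", namespace="baz", host="127.0.0.1", port=9002)
```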
0
docs/reference/scatterbrained/version.py.md
Normal file
3
examples/README.md
Normal file
@@ -0,0 +1,3 @@
# Scatterbrained: Usage Examples

This directory contains examples of how to use the Scatterbrained federated learning library. For a quick getting-started guide, see the [Getting Started](../docs/getting-started/README.md) tutorial.
63
examples/discovery_engine.py
Normal file
@@ -0,0 +1,63 @@
import asyncio
import dataclasses
import json

from loguru import logger

import scatterbrained as sb


async def on_appear(v):
    await asyncio.sleep(0.1)
    logger.info(f"Appear: {v}")


async def on_disappear(v):
    await asyncio.sleep(0.1)
    logger.info(f"Disappear: {v}")


async def on_error(e):
    await asyncio.sleep(0.1)
    logger.opt(exception=e).error("local error")


async def on_remote_recv(v):
    logger.info(f"Remote: {v}")


async def on_remote_error(e):
    logger.opt(exception=e).error("remote error")


async def main():
    # NOTE: in a real deployment you'd want everything to use the same port, but because we're running on the
    # same system here, we need to bind to different ports.
    local_pub = sb.discovery.udp.UDPBroadcaster("127.0.0.1", port=9002)
    local_sub = sb.discovery.udp.UDPReceiver("127.0.0.1", port=9001)
    # Fake a remote node.
    remote_pub = sb.discovery.udp.UDPBroadcaster("127.0.0.1", port=9001)
    remote_sub = sb.discovery.udp.UDPReceiver("127.0.0.1", port=9002)

    await asyncio.wait([local_pub.open(), local_sub.open(), remote_pub.open(), remote_sub.open()])

    engine = sb.discovery.DiscoveryEngine(
        local_pub,
        local_sub,
        identities=[sb.types.Identity(id="baz", namespace="bar", host="omg", port=3223)],
        heartbeat=2,
    )
    await engine.start(on_appear=on_appear, on_disappear=on_disappear, on_error=on_error)

    peer = sb.types.Identity(id="foo", namespace="bar", host="meme", port=32233)

    remote_sub.subscribe(on_recv=on_remote_recv, on_error=on_error)
    await remote_pub.publish(json.dumps(dataclasses.asdict(peer)).encode())
    await asyncio.sleep(15)

    await engine.stop()
    await asyncio.wait([local_pub.close(), local_sub.close(), remote_pub.close(), remote_sub.close()])


if __name__ == "__main__":
    asyncio.run(main())
48
examples/network_engine.py
Normal file
@@ -0,0 +1,48 @@
import asyncio

import scatterbrained as sb


async def send_msgs(engine, identity, peer, n=10):
    for i in range(n):
        await engine.send_to(identity, peer, f"hello there x{i}".encode())
        await asyncio.sleep(1.0)


async def main():
    id1 = sb.types.Identity("foo", "baz", "127.0.0.1", 9001)
    id2 = sb.types.Identity("bar", "baz", "127.0.0.1", 9002)

    rx1 = sb.network.ZMQReceiver()
    rx2 = sb.network.ZMQReceiver()

    ne1 = sb.network.NetworkEngine(rx1, lambda: sb.network.ZMQTransmitter("foo"))
    ne2 = sb.network.NetworkEngine(rx2, lambda: sb.network.ZMQTransmitter("bar"))

    await ne1.bind(id1.host, id1.port)
    await ne2.bind(id2.host, id2.port)

    await ne1.connect_to(id2)
    await ne2.connect_to(id1)

    async def on_recv(identity, payload):
        print("RECV:", identity, payload)

    async def on_malformed(peer_id, segments):
        print("MALFORMED:", peer_id, segments)

    async def on_error(ex):
        print("ERROR", ex)

    d1 = ne1.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)
    d2 = ne2.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)

    await asyncio.wait([send_msgs(ne1, id1, id2), send_msgs(ne2, id2, id1)])

    await asyncio.sleep(1.0)
    await d1.dispose()
    await d2.dispose()


if __name__ == "__main__":
    asyncio.run(main())
32
examples/node.py
Normal file
@@ -0,0 +1,32 @@
import asyncio

import scatterbrained as sb


async def main():
    # NOTE: in a real deployment you'd want everything to use the same port, but because we're running on the
    # same system here, we need to bind to different ports.
    de1 = sb.discovery.DiscoveryEngine(
        publisher=sb.discovery.UDPBroadcaster("127.0.0.1", port=9002),
        subscriber=sb.discovery.UDPReceiver("127.0.0.1", port=9001),
        heartbeat=2,
    )
    de2 = sb.discovery.DiscoveryEngine(
        publisher=sb.discovery.UDPBroadcaster("127.0.0.1", port=9001),
        subscriber=sb.discovery.UDPReceiver("127.0.0.1", port=9002),
        heartbeat=2,
    )
    async with sb.Node(id="foo", host="127.0.0.1", discovery_engine=de1) as node1, sb.Node(
        id="bar", host="127.0.0.1", discovery_engine=de2
    ) as node2:
        async with node1.namespace(name="foobar") as ns1, node2.namespace(name="foobar") as ns2:
            await asyncio.gather(ns1.wait_for_peers(peers="bar"), ns2.wait_for_peers(peers="foo"))
            await ns1.send_to(ns2._id, b"hello")
            sender, payload = await ns2.recv(5.0)
            print(sender, payload)
            await asyncio.sleep(10)
        await asyncio.sleep(15)


if __name__ == "__main__":
    asyncio.run(main())
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 120
target-version = ["py38"]

[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
line_length = 120

[tool.pytest.ini_options]
markers = ["unit", "integration", "end2end"]
12
requirements/dev-requirements.in
Normal file
@@ -0,0 +1,12 @@
-c requirements.txt
bandit
black
docshund
flake8
flake8-bugbear
isort
notebook
pre-commit
pytest
pytest-asyncio
safety
266
requirements/dev-requirements.txt
Normal file
@@ -0,0 +1,266 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
#    pip-compile dev-requirements.in
#
appnope==0.1.2
    # via
    #   ipykernel
    #   ipython
argon2-cffi==21.1.0
    # via notebook
attrs==21.2.0
    # via
    #   flake8-bugbear
    #   jsonschema
    #   pytest
backcall==0.2.0
    # via ipython
backports.entry-points-selectable==1.1.0
    # via virtualenv
bandit==1.7.0
    # via -r dev-requirements.in
black==21.9b0
    # via -r dev-requirements.in
bleach==4.1.0
    # via nbconvert
certifi==2021.5.30
    # via requests
cffi==1.14.6
    # via argon2-cffi
cfgv==3.3.1
    # via pre-commit
charset-normalizer==2.0.6
    # via requests
click==8.0.1
    # via
    #   black
    #   docshund
    #   safety
debugpy==1.4.3
    # via ipykernel
decorator==5.1.0
    # via ipython
defusedxml==0.7.1
    # via nbconvert
distlib==0.3.3
    # via virtualenv
docshund==0.1.2
    # via -r dev-requirements.in
dparse==0.5.1
    # via safety
entrypoints==0.3
    # via
    #   jupyter-client
    #   nbconvert
filelock==3.0.12
    # via virtualenv
flake8==3.9.2
    # via
    #   -r dev-requirements.in
    #   flake8-bugbear
flake8-bugbear==21.9.1
    # via -r dev-requirements.in
gitdb==4.0.7
    # via gitpython
gitpython==3.1.24
    # via bandit
identify==2.2.15
    # via pre-commit
idna==3.2
    # via requests
iniconfig==1.1.1
    # via pytest
ipykernel==6.4.1
    # via notebook
ipython==7.27.0
    # via ipykernel
ipython-genutils==0.2.0
    # via
    #   ipykernel
    #   nbformat
    #   notebook
isort==5.9.3
    # via -r dev-requirements.in
jedi==0.18.0
    # via ipython
jinja2==3.0.1
    # via
    #   nbconvert
    #   notebook
jsonschema==3.2.0
    # via nbformat
jupyter-client==7.0.3
    # via
    #   ipykernel
    #   nbclient
    #   notebook
jupyter-core==4.8.1
    # via
    #   jupyter-client
    #   nbconvert
    #   nbformat
    #   notebook
jupyterlab-pygments==0.1.2
    # via nbconvert
markupsafe==2.0.1
    # via jinja2
matplotlib-inline==0.1.3
    # via
    #   ipykernel
    #   ipython
mccabe==0.6.1
    # via flake8
mistune==0.8.4
    # via nbconvert
mypy-extensions==0.4.3
    # via black
nbclient==0.5.4
    # via nbconvert
nbconvert==6.1.0
    # via notebook
nbformat==5.1.3
    # via
    #   nbclient
    #   nbconvert
    #   notebook
nest-asyncio==1.5.1
    # via
    #   jupyter-client
    #   nbclient
nodeenv==1.6.0
    # via pre-commit
notebook==6.4.4
    # via -r dev-requirements.in
packaging==21.0
    # via
    #   bleach
    #   dparse
    #   pytest
    #   safety
pandocfilters==1.5.0
    # via nbconvert
parso==0.8.2
    # via jedi
pathspec==0.9.0
    # via black
pbr==5.6.0
    # via stevedore
pexpect==4.8.0
    # via ipython
pickleshare==0.7.5
    # via ipython
platformdirs==2.3.0
    # via
    #   black
    #   virtualenv
pluggy==1.0.0
    # via pytest
pre-commit==2.15.0
    # via -r dev-requirements.in
prometheus-client==0.11.0
    # via notebook
prompt-toolkit==3.0.20
    # via ipython
ptyprocess==0.7.0
    # via
    #   pexpect
    #   terminado
py==1.10.0
    # via pytest
pycodestyle==2.7.0
    # via flake8
pycparser==2.20
    # via cffi
pyflakes==2.3.1
    # via flake8
pygments==2.10.0
    # via
    #   ipython
    #   jupyterlab-pygments
    #   nbconvert
pyparsing==2.4.7
    # via packaging
pyrsistent==0.18.0
    # via jsonschema
pytest==6.2.5
    # via
    #   -r dev-requirements.in
    #   pytest-asyncio
pytest-asyncio==0.15.1
    # via -r dev-requirements.in
python-dateutil==2.8.2
    # via jupyter-client
pyyaml==5.4.1
    # via
    #   bandit
    #   dparse
    #   pre-commit
pyzmq==22.3.0
    # via
    #   -c requirements.txt
    #   jupyter-client
    #   notebook
regex==2021.9.24
    # via black
requests==2.26.0
    # via safety
safety==1.10.3
    # via -r dev-requirements.in
send2trash==1.8.0
    # via notebook
six==1.16.0
    # via
    #   bandit
    #   bleach
    #   jsonschema
    #   python-dateutil
    #   virtualenv
smmap==4.0.0
    # via gitdb
stevedore==3.4.0
    # via bandit
terminado==0.12.1
    # via notebook
testpath==0.5.0
    # via nbconvert
toml==0.10.2
    # via
    #   dparse
    #   pre-commit
    #   pytest
tomli==1.2.1
    # via black
tornado==6.1
    # via
    #   ipykernel
    #   jupyter-client
    #   notebook
    #   terminado
traitlets==5.1.0
    # via
    #   ipykernel
    #   ipython
    #   jupyter-client
    #   jupyter-core
    #   matplotlib-inline
    #   nbclient
    #   nbconvert
    #   nbformat
    #   notebook
typing-extensions==3.10.0.2
    # via
    #   black
    #   gitpython
urllib3==1.26.7
    # via requests
virtualenv==20.8.1
    # via pre-commit
wcwidth==0.2.5
    # via prompt-toolkit
webencodings==0.5.1
    # via bleach

# The following packages are considered to be unsafe in a requirements file:
# setuptools
5
requirements/requirements.in
Normal file
@@ -0,0 +1,5 @@
loguru
orjson
pyzmq
rx
uvloop
16
requirements/requirements.txt
Normal file
@@ -0,0 +1,16 @@
#
# This file is autogenerated by pip-compile with python 3.8
# To update, run:
#
#    pip-compile requirements.in
#
loguru==0.5.3
    # via -r requirements.in
orjson==3.6.3
    # via -r requirements.in
pyzmq==22.3.0
    # via -r requirements.in
rx==3.2.0
    # via -r requirements.in
uvloop==0.16.0
    # via -r requirements.in
32
setup.py
Normal file
@@ -0,0 +1,32 @@
from pathlib import Path

from setuptools import find_packages, setup

HERE = Path(__file__).parent
README = HERE.joinpath("README.md").read_text()
REQUIREMENTS = HERE.joinpath("requirements", "requirements.in").read_text().split()


def get_version(rel_path: Path):
    contents = HERE.joinpath(rel_path).read_text().splitlines()
    for line in contents:
        if line.startswith("__version__"):
            delim = '"' if '"' in line else "'"
            return line.split(delim)[1]
    else:
        raise RuntimeError("Unable to find version string.")


setup(
    name="scatterbrained",
    version=get_version(Path("src", "scatterbrained", "version.py")),
    author="Miller Wilt",
    author_email="miller.wilt@jhuapl.edu",
    description="Decentralized Federated Learning framework",
    long_description=README,
    long_description_content_type="text/markdown",
    packages=find_packages("src"),
    package_dir={"": "src"},
    install_requires=REQUIREMENTS,
    python_requires=">=3.8.0",
)
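Note that get_version parses src/scatterbrained/version.py as text rather than importing the package (which would fail before its dependencies are installed). That file is not shown in the portion of the diff above; a minimal sketch of contents that would satisfy the parser (an assumption — only the parsing logic above is authoritative, and the actual version number is hypothetical):

# src/scatterbrained/version.py — get_version only requires a line that
# starts with "__version__" and quotes the value.
__version__ = "0.1.0"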
3
src/scatterbrained/__init__.py
Normal file
@@ -0,0 +1,3 @@
from . import discovery, network, types  # noqa: F401
from .node import Node  # noqa: F401
from .version import __version__  # noqa: F401
3
src/scatterbrained/discovery/__init__.py
Normal file
@@ -0,0 +1,3 @@
from . import types  # noqa: F401
from .engine import DiscoveryEngine  # noqa: F401
from .udp import UDPBroadcaster, UDPReceiver  # noqa: F401
239
src/scatterbrained/discovery/engine.py
Normal file
@@ -0,0 +1,239 @@
"""
|
||||
The Discovery Engine is responsible for discovering and managing nodes on the Scatterbrained network.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Dict, Optional, Sequence, Set, Tuple
|
||||
|
||||
import orjson
|
||||
import rx
|
||||
|
||||
from ..types import Identity
|
||||
from .types import Publisher, Subscriber
|
||||
|
||||
|
||||
class DiscoveryEngine:
|
||||
"""
|
||||
A `DiscoveryEngine` is a manager class for the process of identifying and communicating with other peers in the
|
||||
Scatterbrained network.
|
||||
|
||||
The `DiscoveryEngine` is responsible for maintaining a list of peers, and periodically sending heartbeats to the
|
||||
network, which other peers can use to determine if the sending node is still alive.
|
||||
|
||||
Identities of peers are stored in `DiscoveryEngine.peers`.
|
||||
|
||||
Identities of the current node are stored in `DiscoveryEngine.identities` and can be manually added or removed using
|
||||
`DiscoveryEngine.add_identity` and `DiscoveryEngine.remove_identity`.
|
||||
"""
|
||||
|
||||
_publisher: Publisher
|
||||
_subscriber: Subscriber
|
||||
_heartbeat: float
|
||||
_identities: Dict[Tuple[str, str], Identity]
|
||||
_subscription: Optional[rx.disposable.Disposable]
|
||||
_heartbeat_task: Optional[asyncio.Task]
|
||||
_lifetime_task: Optional[asyncio.Task]
|
||||
_peers: Set[Identity]
|
||||
_last_seen: Dict[Identity, datetime]
|
||||
|
||||
def __init__(
|
||||
self, publisher: Publisher, subscriber: Subscriber, identities: Optional[Sequence] = None, heartbeat: float = 5
|
||||
) -> None:
|
||||
"""
|
||||
Create a new DiscoveryEngine with a publisher and subscriber.
|
||||
|
||||
Arguments:
|
||||
publisher (scatterbrained.discovery.types.Publisher): The publisher to use for sending peer transactions.
|
||||
subscriber (scatterbrained.discovery.types.Subscriber): The subscriber to use for receiving peer
|
||||
transactions.
|
||||
identities (Optional[Sequence[scatterbrained.discovery.types.Identity]]): A list of identities to use for
|
||||
the initial peer list.
|
||||
heartbeat (float): The heartbeat interval in seconds.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if identities is None:
|
||||
identities = []
|
||||
self.heartbeat = heartbeat
|
||||
self._publisher = publisher
|
||||
self._subscriber = subscriber
|
||||
self._identities = {(i.id, i.namespace): i for i in identities}
|
||||
self._subscription = None
|
||||
self._heartbeat_task = None
|
||||
self._lifetime_task = None
|
||||
self._peers = set()
|
||||
self._last_seen = {}
|
||||
|
||||
@property
|
||||
def heartbeat(self):
|
||||
"""
|
||||
Retrieve the heartbeat interval in seconds.
|
||||
"""
|
||||
return self._heartbeat
|
||||
|
||||
@heartbeat.setter
|
||||
def heartbeat(self, value):
|
||||
"""
|
||||
Set the heartbeat interval in seconds.
|
||||
|
||||
Arguments:
|
||||
value (float): The heartbeat interval in seconds.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if value < 1:
|
||||
raise ValueError(f"value must be 1 or more (value={value})")
|
||||
self._heartbeat = value
|
||||
|
||||
@property
|
||||
def peers(self):
|
||||
"""
|
||||
Retrieve the list of peers.
|
||||
|
||||
Returns:
|
||||
Set[scatterbrained.discovery.types.Identity]: The list of peers.
|
||||
|
||||
"""
|
||||
# Defensive copy.
|
||||
return set(self._peers)
|
||||
|
||||
def add_identity(self, identity: Identity) -> None:
|
||||
"""
|
||||
Add an identity to the list of identities for the current node.
|
||||
|
||||
Arguments:
|
||||
identity (scatterbrained.discovery.types.Identity): The identity to add.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
self._identities[(identity.id, identity.namespace)] = identity
|
||||
|
||||
def remove_identity(self, identity: Identity) -> None:
|
||||
"""
|
||||
Remove an identity from the list of valid identities for the node.
|
||||
|
||||
Arguments:
|
||||
identity (scatterbrained.discovery.types.Identity): The identity to remove.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
self._identities.pop((identity.id, identity.namespace), None)
|
||||
|
||||
async def start(
|
||||
self,
|
||||
on_appear: Callable[[Identity], Awaitable[None]],
|
||||
on_disappear: Callable[[Identity], Awaitable[None]],
|
||||
on_error: Callable[[Exception], Awaitable[None]],
|
||||
) -> None:
|
||||
"""
|
||||
Start the DiscoveryEngine, with awaitable callbacks per peer.
|
||||
|
||||
Callbacks receive an identity of the peer that appeared or disappeared. The `on_error` callback is called
|
||||
if an error occurs, and it receives as arguments the exception. Note that all callbacks are async.
|
||||
|
||||
Arguments:
|
||||
on_appear (callback): A callback to call when a peer appears on the scatterbrained network.
|
||||
on_disappear (callback): A callback to call when a peer disappears from the network.
|
||||
on_error (exception_handler): A callback to run when an exception is thrown communicating with a peer.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if self._subscription is not None:
|
||||
return
|
||||
|
||||
async def handle_peer_heartbeat(data):
|
||||
peer_info = orjson.loads(data)
|
||||
peer = Identity(**peer_info)
|
||||
# TODO: the proper way to check if a heartbeat was generated
|
||||
# locally is to look at the source ip, and cross-check that with
|
||||
# known interfaces. This will work for now, however.
|
||||
if (peer.id, peer.namespace) in self._identities:
|
||||
return
|
||||
elif peer not in self._peers:
|
||||
self._peers.add(peer)
|
||||
self._last_seen[peer] = datetime.now()
|
||||
await on_appear(peer)
|
||||
else:
|
||||
# Update time of last seen
|
||||
self._last_seen[peer] = datetime.now()
|
||||
|
||||
await self._publisher.open()
|
||||
await self._subscriber.open()
|
||||
self._subscription = self._subscriber.subscribe(on_recv=handle_peer_heartbeat, on_error=on_error)
|
||||
self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
|
||||
self._lifetime_task = asyncio.create_task(self._lifetime_monitor(on_disappear))
|
||||
|
||||
async def stop(self):
|
||||
"""
|
||||
Gracefully stop the `DiscoveryEngine`.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if self._subscription is None:
|
||||
return
|
||||
|
||||
assert self._heartbeat_task is not None # nosec
|
||||
assert self._lifetime_task is not None # nosec
|
||||
|
||||
self._subscription.dispose()
|
||||
self._subscription = None
|
||||
|
||||
self._heartbeat_task.cancel()
|
||||
await self._heartbeat_task
|
||||
self._heartbeat_task = None
|
||||
|
||||
self._lifetime_task.cancel()
|
||||
await self._lifetime_task
|
||||
self._lifetime_task = None
|
||||
|
||||
await self._publisher.close()
|
||||
await self._subscriber.close()
|
||||
|
||||
self._peers = set()
|
||||
self._last_seen = {}
|
||||
|
||||
async def _heartbeat_loop(self):
|
||||
try:
|
||||
while True:
|
||||
for id in self._identities.values():
|
||||
obj = dataclasses.asdict(id)
|
||||
obj = orjson.dumps(obj)
|
||||
await self._publisher.publish(obj)
|
||||
await asyncio.sleep(self._heartbeat)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _lifetime_monitor(self, on_disappear: Callable[[Identity], Awaitable[None]]):
|
||||
try:
|
||||
while True:
|
||||
now = datetime.now()
|
||||
tasks = []
|
||||
to_remove = set()
|
||||
for peer, last_seen in self._last_seen.items():
|
||||
if (now - last_seen).total_seconds() >= self._heartbeat * 5:
|
||||
self._peers.remove(peer)
|
||||
to_remove.add(peer)
|
||||
task = asyncio.create_task(on_disappear(peer))
|
||||
tasks.append(task)
|
||||
# NOTE: not optimal, as tasks may take a long time, but will suffice for now.
|
||||
if tasks:
|
||||
for peer in to_remove:
|
||||
self._last_seen.pop(peer)
|
||||
await asyncio.wait(tasks)
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
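To see how the pieces above fit together: a minimal sketch (not itself part of this commit) that runs a DiscoveryEngine over the UDP transports defined later in this commit, using only the constructor and start/stop API shown above. The node id, namespace, and addresses are arbitrary illustration values; a peer is declared dead after five missed heartbeat intervals, per _lifetime_monitor.

import asyncio

from scatterbrained.discovery import DiscoveryEngine, UDPBroadcaster, UDPReceiver
from scatterbrained.types import Identity


async def main():
    # Broadcast our own heartbeats and listen for everyone else's.
    engine = DiscoveryEngine(
        publisher=UDPBroadcaster(),
        subscriber=UDPReceiver(),
        identities=[Identity(id="node-a", namespace="demo", host="127.0.0.1", port=9002)],
        heartbeat=5,
    )

    async def on_appear(peer):
        print(f"peer appeared: {peer}")

    async def on_disappear(peer):
        print(f"peer disappeared: {peer}")

    async def on_error(exc):
        print(f"discovery error: {exc!r}")

    await engine.start(on_appear=on_appear, on_disappear=on_disappear, on_error=on_error)
    try:
        await asyncio.sleep(30)  # let discovery run for a while
    finally:
        await engine.stop()


asyncio.run(main())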
70
src/scatterbrained/discovery/types.py
Normal file
@@ -0,0 +1,70 @@
from typing import Awaitable, Callable, Optional, Protocol


class Publisher(Protocol):
    """
    A Publisher is a class that can publish a message to a set of peers.

    Note that this is a protocol and shouldn't be used directly!

    """

    async def publish(self, data: bytes) -> None:
        """
        Publish the given payload to a set of peers.
        """
        ...

    async def open(self) -> None:
        """
        Open the underlying connection mechanism, enabling this instance to send messages.
        """
        ...

    async def close(self) -> None:
        """
        Close the underlying connection mechanism, stopping this instance from sending messages.
        """
        ...


class Subscriber(Protocol):
    """
    A Subscriber is a class that can subscribe to messages from a set of peers.

    Note that this is a protocol and shouldn't be used directly!

    """

    def subscribe(
        self,
        on_recv: Callable[[bytes], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ) -> None:
        """
        Subscribe to messages from a set of peers, and attach async callbacks.

        Arguments:
            on_recv (Callable[[bytes], Awaitable[None]]): The callback to call when a message is received.
            on_error (Optional[Callable[[Exception], Awaitable[None]]]): The callback to run when an error occurs.

        Returns:
            None

        """
        ...

    async def open(self) -> None:
        """
        Open the underlying connection mechanism, enabling this instance to receive messages.
        """
        ...

    async def close(self) -> None:
        """
        Close the underlying connection mechanism, stopping this instance from receiving messages.
        """
        ...


__all__ = ["Publisher", "Subscriber"]
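Because Publisher and Subscriber are typing.Protocol classes, anything with matching method signatures satisfies them structurally; no inheritance is required. A toy in-memory Publisher (an illustration only, not part of the library) that would type-check against the protocol above and could stand in for UDPBroadcaster in tests:

from typing import List


class InMemoryPublisher:
    """Collects published payloads in a list instead of hitting the network."""

    def __init__(self) -> None:
        self.sent: List[bytes] = []
        self._open = False

    async def open(self) -> None:
        self._open = True

    async def close(self) -> None:
        self._open = False

    async def publish(self, data: bytes) -> None:
        if not self._open:
            raise RuntimeError("publisher is not open")
        self.sent.append(data)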
145
src/scatterbrained/discovery/udp.py
Normal file
@@ -0,0 +1,145 @@
import asyncio
from typing import Awaitable, Callable, Optional, Tuple

import rx
from loguru import logger
from rx import operators as op
from rx.scheduler.eventloop import AsyncIOScheduler
from rx.subject import Subject


class _UDPBroadcastProtocol:
    def connection_made(self, transport: asyncio.DatagramTransport):
        self.transport = transport

    def connection_lost(self, exc: Exception):
        # TODO: determine how to inform larger program about any exceptions passed.
        self.transport = None

    def error_received(self, exc: OSError):
        # TODO: determine how to inform larger program about any exceptions passed.
        pass


class _UDPRecvProtocol:
    def __init__(self, subject: Optional[Subject] = None):
        if subject is None:
            subject = Subject()
        self.subject = subject

    def connection_made(self, transport: asyncio.DatagramTransport):
        self.transport = transport

    def connection_lost(self, exc: Exception):
        self.transport = None
        if exc is None:
            self.subject.on_completed()
        else:
            self.subject.on_error(exc)

    def error_received(self, exc: OSError):
        self.subject.on_error(exc)

    def datagram_received(self, data: bytes, addr: Tuple[str, int]):
        self.subject.on_next((data, addr))


class UDPBroadcaster:
    remote: str
    port: int
    _transport: Optional[asyncio.DatagramTransport]
    _protocol: Optional[asyncio.DatagramProtocol]

    def __init__(self, broadcast_addr: str = "255.255.255.255", port: int = 9001):
        self.broadcast_addr = broadcast_addr
        self.port = port
        self._transport = None
        self._protocol = None

    async def open(self) -> None:
        if self._transport is None:
            loop = asyncio.get_running_loop()
            self._transport, self._protocol = await loop.create_datagram_endpoint(
                lambda: _UDPBroadcastProtocol(), remote_addr=(self.broadcast_addr, self.port), allow_broadcast=True
            )

    async def close(self) -> None:
        if self._transport is not None:
            self._transport.close()
            self._transport = None
            self._protocol = None

    async def __aenter__(self) -> None:
        await self.open()

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    async def publish(self, data: bytes) -> None:
        if self._transport is None:
            raise RuntimeError("socket is not open")
        self._transport.sendto(data)


class UDPReceiver:
    source: str
    port: int
    _transport: Optional[asyncio.DatagramTransport]
    _protocol: Optional[asyncio.DatagramProtocol]

    def __init__(self, local_addr: str = "", port: int = 9001, subject: Optional[Subject] = None):
        if subject is None:
            subject = Subject()
        self.local_addr = local_addr
        self.port = port
        self.subject = subject
        self._transport = None
        self._protocol = None

    async def open(self) -> None:
        if self._transport is None:
            loop = asyncio.get_running_loop()
            self._transport, self._protocol = await loop.create_datagram_endpoint(
                lambda: _UDPRecvProtocol(subject=self.subject), local_addr=(self.local_addr, self.port)
            )

    async def close(self) -> None:
        if self._transport is not None:
            self._transport.close()
            self._transport = None
            self._protocol = None

    async def __aenter__(self) -> None:
        await self.open()

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    def subscribe(
        self,
        on_recv: Callable[[bytes], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ):
        def handle_next(obj):
            data, _ = obj
            return rx.from_future(asyncio.create_task(on_recv(data)))

        def handle_error(ex, src):
            assert on_error is not None  # nosec
            return rx.from_future(asyncio.create_task(on_error(ex)))

        workflow = self.subject.pipe(op.flat_map(handle_next))
        if on_error is not None:
            workflow = workflow.pipe(op.catch(handle_error))

        d = workflow.subscribe(
            on_next=lambda _: logger.trace("message processed"),
            on_error=lambda e: logger.opt(exception=e).error(f"{type(self).__name__} data stream error"),
            on_completed=lambda: logger.trace(f"{type(self).__name__} data stream ended"),
            scheduler=AsyncIOScheduler(loop=asyncio.get_running_loop()),
        )

        return d


__all__ = ["UDPBroadcaster", "UDPReceiver"]
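A small end-to-end sketch of the two classes above (a hypothetical example, not from this commit; loopback address and port are arbitrary). The receiver's subscribe pumps each datagram through the rx pipeline into an async callback, while the broadcaster pushes bytes out of the datagram transport; sending to the loopback address works because allow_broadcast permits unicast destinations too, as the tests below also rely on.

import asyncio

from scatterbrained.discovery import UDPBroadcaster, UDPReceiver


async def main():
    got = asyncio.Event()

    async def on_recv(data: bytes) -> None:
        print(f"received: {data!r}")
        got.set()

    receiver = UDPReceiver(local_addr="127.0.0.1", port=9001)
    broadcaster = UDPBroadcaster(broadcast_addr="127.0.0.1", port=9001)

    async with receiver:
        receiver.subscribe(on_recv=on_recv)
        async with broadcaster:
            await broadcaster.publish(b"hello world")
            await asyncio.wait_for(got.wait(), timeout=3.0)


asyncio.run(main())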
154
src/scatterbrained/discovery/udp_test.py
Normal file
@@ -0,0 +1,154 @@
import asyncio
import socket
from unittest.mock import AsyncMock

import pytest

from . import udp

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


@pytest.fixture
def addr():
    return "127.0.0.1"


@pytest.fixture
def port():
    return 9001


async def test_UDPBroadcaster_full(event_loop, addr, port):
    # To simplify things, just use a blocking socket
    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    recv_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    recv_socket.settimeout(3)
    await event_loop.run_in_executor(None, recv_socket.bind, ("", port))

    broadcaster = udp.UDPBroadcaster(broadcast_addr=addr, port=port)
    assert broadcaster._transport is None and broadcaster._protocol is None

    await broadcaster.open()
    assert broadcaster._transport is not None and broadcaster._protocol is not None

    expected = b"hello world"
    await broadcaster.publish(expected)
    actual, _ = await event_loop.run_in_executor(None, recv_socket.recvfrom, 255)
    assert actual == expected

    await broadcaster.close()
    assert broadcaster._transport is None and broadcaster._protocol is None


async def test_UDPBroadcaster_full_async_with(event_loop, addr, port):
    recv_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    recv_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
    recv_socket.settimeout(3)
    await event_loop.run_in_executor(None, recv_socket.bind, ("", port))

    broadcaster = udp.UDPBroadcaster(broadcast_addr=addr, port=port)
    assert broadcaster._transport is None and broadcaster._protocol is None

    async with broadcaster:
        assert broadcaster._transport is not None and broadcaster._protocol is not None

        expected = b"hello world"
        await broadcaster.publish(expected)
        actual, _ = await event_loop.run_in_executor(None, recv_socket.recvfrom, 255)
        assert actual == expected

    assert broadcaster._transport is None and broadcaster._protocol is None


async def test_UDPReceiver_full(event_loop, addr, port):
    async def mock_on_recv(*args):
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)

    receiver = udp.UDPReceiver(local_addr=addr, port=port)
    assert receiver._transport is None and receiver._protocol is None

    await receiver.open()
    assert receiver._transport is not None and receiver._protocol is not None

    cond = asyncio.Condition()
    on_recv = AsyncMock(side_effect=mock_on_recv)
    on_error = AsyncMock()
    receiver.subscribe(on_recv=on_recv, on_error=on_error)

    expected = b"hello world"
    send_socket.sendto(expected, (addr, port))
    async with cond:
        await asyncio.wait_for(cond.wait(), 3.0)

    on_recv.assert_awaited_once_with(expected)
    on_error.assert_not_awaited()

    await receiver.close()
    assert receiver._transport is None and receiver._protocol is None


async def test_UDPReceiver_async_with(event_loop, addr, port):
    async def mock_on_recv(*args):
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)

    receiver = udp.UDPReceiver(local_addr=addr, port=port)
    assert receiver._transport is None and receiver._protocol is None

    async with receiver:
        assert receiver._transport is not None and receiver._protocol is not None

        cond = asyncio.Condition()
        on_recv = AsyncMock(side_effect=mock_on_recv)
        on_error = AsyncMock()
        receiver.subscribe(on_recv=on_recv, on_error=on_error)

        expected = b"hello world"
        send_socket.sendto(expected, (addr, port))
        async with cond:
            await asyncio.wait_for(cond.wait(), 3.0)

        on_recv.assert_awaited_once_with(expected)
        on_error.assert_not_awaited()
    assert receiver._transport is None and receiver._protocol is None


async def test_UDPReceiver_error(event_loop, addr, port):
    async def mock_on_recv(*args):
        await asyncio.sleep(0.01)
        raise RuntimeError("yeet")

    async def mock_on_error(*args):
        async with cond:
            cond.notify()

    # To simplify things, just use a blocking socket
    send_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)

    receiver = udp.UDPReceiver(local_addr=addr, port=port)
    assert receiver._transport is None and receiver._protocol is None

    async with receiver:
        assert receiver._transport is not None and receiver._protocol is not None

        cond = asyncio.Condition()
        on_recv = AsyncMock(side_effect=mock_on_recv)
        on_error = AsyncMock(side_effect=mock_on_error)
        receiver.subscribe(on_recv=on_recv, on_error=on_error)

        expected = b"hello world"
        send_socket.sendto(expected, (addr, port))
        async with cond:
            await asyncio.wait_for(cond.wait(), 3.0)

    on_recv.assert_awaited_once_with(expected)
    on_error.assert_awaited_once()
2
src/scatterbrained/network/__init__.py
Normal file
@@ -0,0 +1,2 @@
from .engine import NetworkEngine  # noqa: F401
from .zmq_ip import ZMQReceiver, ZMQTransmitter  # noqa: F401
165
src/scatterbrained/network/engine.py
Normal file
@@ -0,0 +1,165 @@
from __future__ import annotations

import asyncio
import dataclasses
from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple

import orjson
import rx
from loguru import logger
from rx import operators as op
from rx.scheduler.eventloop import AsyncIOScheduler
from rx.subject import Subject

from ..types import Identity
from .types import RX, TX


class NetworkEngine:
    class Disposable:
        engine: NetworkEngine
        rx_disposable: rx.disposable.Disposable
        _disposed: bool

        def __init__(self, engine: NetworkEngine, rx_disposable: rx.disposable.Disposable):
            self.engine = engine
            self.rx_disposable = rx_disposable
            self._disposed = False

        async def dispose(self) -> None:
            if self._disposed:
                return

            self.rx_disposable.dispose()
            self.engine._sub_count -= 1
            if self.engine._sub_count == 0:
                # Need to be careful about concurrency here. To prevent race conditions, we save/clear the rx_task
                # field, cancel the saved task, and then await on it. This ensures that if someone subscribes while
                # the task is being awaited, it starts the rx loop back up.
                task = self.engine._rx_task
                self.engine._rx_task = None
                # Need to check this in case the NetworkEngine instance was unbound/closed.
                if task is not None:
                    task.cancel()
                    await task

    _tx_factory: Callable[[], TX]
    _rx: RX
    _physical_tx_connections: Dict[Tuple[str, int], TX]
    _bound: bool
    _subject: Subject
    _rx_task: Optional[asyncio.Task]
    _sub_count: int

    def __init__(self, rx: RX, tx_factory: Callable[[], TX]):
        self._rx = rx
        self._tx_factory = tx_factory
        self._physical_tx_connections = {}
        self._bound = False
        self._subject = Subject()
        self._rx_task = None
        self._sub_count = 0

    @property
    def bound_address(self):
        return self._rx.host, self._rx.port

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        port = await self._rx.bind(host, port)
        self._bound = True
        return port

    async def unbind(self) -> None:
        if self._rx_task is not None:
            self._rx_task.cancel()
            await self._rx_task
            self._rx_task = None
        await self._rx.close()
        self._bound = False

    async def close(self) -> None:
        keys = list(self._physical_tx_connections)
        for k in keys:
            tx = self._physical_tx_connections.pop(k)
            await tx.close()
        await self.unbind()

    async def connect_to(self, host: str, port: int) -> None:
        phy_conn_info = (host, port)
        if phy_conn_info not in self._physical_tx_connections:
            tx = self._tx_factory()
            await tx.connect(host, port)

            self._physical_tx_connections[phy_conn_info] = tx

    async def disconnect_from(self, host: str, port: int) -> None:
        phy_conn_info = (host, port)
        if phy_conn_info in self._physical_tx_connections:
            tx = self._physical_tx_connections.pop(phy_conn_info)
            await tx.close()

    async def send_to(self, id: Identity, peer: Identity, *payload: bytes) -> None:
        phy_conn_info = (peer.host, peer.port)
        if phy_conn_info not in self._physical_tx_connections:
            raise RuntimeError(f"not connected to peer node '{peer.id}' (addr={peer.host}:{peer.port})")

        tx = self._physical_tx_connections[phy_conn_info]
        idbytes = orjson.dumps(dataclasses.asdict(id))
        await tx.send(idbytes, *payload)

    async def _rx_loop(self) -> None:
        logger.debug(f"{type(self).__name__} RX loop is starting")
        try:
            while True:
                peer_id, segments = await self._rx.recv()
                try:
                    identity = orjson.loads(segments[0])
                    identity = Identity(**identity)
                except Exception:
                    logger.exception(f"failed to extract identity from message sent by '{peer_id}'")
                    self._subject.on_next((True, (peer_id, segments)))
                else:
                    self._subject.on_next((False, (identity, segments[1:])))
        except asyncio.CancelledError:
            logger.debug(f"{type(self).__name__} RX loop is shutting down")
        except Exception as e:
            logger.exception(f"error in {type(self).__name__} RX loop")
            self._subject.on_error(e)

    def subscribe(
        self,
        on_recv: Callable[[Identity, Sequence[bytes]], Awaitable[None]],
        on_malformed: Callable[[str, Sequence[bytes]], Awaitable[None]],
        on_error: Optional[Callable[[Exception], Awaitable[None]]] = None,
    ):
        if not self._bound:
            raise RuntimeError(f"{type(self).__name__} must be bound before subscribe can be called")

        def handle_next(data):
            malformed, payload = data
            if malformed:
                return rx.from_future(asyncio.create_task(on_malformed(*payload)))
            else:
                return rx.from_future(asyncio.create_task(on_recv(*payload)))

        def handle_error(ex, src):
            assert on_error is not None  # nosec
            return rx.from_future(asyncio.create_task(on_error(ex)))

        workflow = self._subject.pipe(op.flat_map(handle_next))
        if on_error is not None:
            workflow = workflow.pipe(op.catch(handle_error))

        d = workflow.subscribe(
            on_next=lambda _: logger.trace("message processed"),
            on_error=lambda e: logger.opt(exception=e).error("network engine data stream error"),
            on_completed=lambda: logger.trace("network engine data stream ended"),
            scheduler=AsyncIOScheduler(loop=asyncio.get_running_loop()),
        )

        if self._rx_task is None:
            self._rx_task = asyncio.create_task(self._rx_loop())
        self._sub_count += 1

        return self.Disposable(self, d)
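For orientation, a sketch of driving NetworkEngine directly with the ZMQ transports defined later in this commit (in normal use, Node wires this up for you; the ids, addresses, and ports here are arbitrary illustration values). Sending requires an explicit physical connection first, and the first message segment always carries the sender's serialized Identity, as send_to and _rx_loop above show.

import asyncio

from scatterbrained.network import NetworkEngine, ZMQReceiver, ZMQTransmitter
from scatterbrained.types import Identity


async def main():
    engine = NetworkEngine(rx=ZMQReceiver(), tx_factory=lambda: ZMQTransmitter(id="node-a"))

    # Bind the receive side; passing port=None would pick an arbitrary open port.
    port = await engine.bind("127.0.0.1", 9001)

    async def on_recv(sender: Identity, segments) -> None:
        print(f"{sender.id} sent {segments}")

    async def on_malformed(peer_id: str, segments) -> None:
        print(f"malformed message from {peer_id}")

    sub = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed)

    me = Identity(id="node-a", namespace="demo", host="127.0.0.1", port=port)
    peer = Identity(id="node-b", namespace="demo", host="127.0.0.1", port=9002)
    await engine.connect_to(peer.host, peer.port)
    await engine.send_to(me, peer, b"hello there")

    await sub.dispose()
    await engine.close()


asyncio.run(main())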
229
src/scatterbrained/network/engine_test.py
Normal file
@@ -0,0 +1,229 @@
import asyncio
import dataclasses
from unittest.mock import AsyncMock

import orjson
import pytest

from ..types import Identity
from .engine import NetworkEngine

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


@pytest.fixture
def id1():
    return Identity(id="1", namespace="foo", host="127.0.0.1", port=9002)


@pytest.fixture
def id2():
    return Identity(id="2", namespace="bar", host="127.0.0.1", port=9002)


@pytest.fixture
def id3():
    return Identity(id="3", namespace="baz", host="127.0.0.1", port=9003)


@pytest.fixture
def msg(id1):
    identity = orjson.dumps(dataclasses.asdict(id1))
    payload = b"hello there"

    return (id1.id, [identity, payload])


@pytest.fixture
def malformed_msg():
    return ("foo", [b"hello there"])


@pytest.mark.unit
async def test_NetworkEngine_connect_to(id1, id2, id3):
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    await engine.connect_to(id1.host, id1.port)

    assert list(engine._physical_tx_connections) == [(id1.host, id1.port)]
    tx = engine._physical_tx_connections[(id1.host, id1.port)]
    tx.connect.assert_called_once_with(id1.host, id1.port)

    # id2 shares id1's host/port, so no new physical connection is created.
    await engine.connect_to(id2.host, id2.port)

    assert list(engine._physical_tx_connections) == [(id1.host, id1.port)]
    tx = engine._physical_tx_connections[(id1.host, id1.port)]
    tx.connect.assert_called_once_with(id1.host, id1.port)

    await engine.connect_to(id3.host, id3.port)

    assert list(engine._physical_tx_connections) == [(id1.host, id1.port), (id3.host, id3.port)]
    tx1 = engine._physical_tx_connections[(id1.host, id1.port)]
    tx1.connect.assert_called_once_with(id1.host, id1.port)
    tx2 = engine._physical_tx_connections[(id3.host, id3.port)]
    tx2.connect.assert_called_once_with(id3.host, id3.port)


@pytest.mark.unit
async def test_NetworkEngine_subscribe(msg):
    async def mock_recv():
        await asyncio.sleep(0.01)
        return msg

    async def mock_on_recv(*args):
        async with cond:
            cond.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = mock_recv

    engine = NetworkEngine(rx, lambda: AsyncMock())

    actual_port = await engine.bind("127.0.0.1", port)
    assert actual_port == port

    cond = asyncio.Condition()
    on_recv = AsyncMock(side_effect=mock_on_recv)
    on_malformed = AsyncMock()
    on_error = AsyncMock()
    d = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)

    async with cond:
        await asyncio.wait_for(cond.wait(), 3.0)

    await d.dispose()

    expected_id = Identity(**orjson.loads(msg[1][0]))
    expected_segments = msg[1][1:]
    on_recv.assert_awaited_with(expected_id, expected_segments)
    on_malformed.assert_not_awaited()
    on_error.assert_not_awaited()


@pytest.mark.unit
async def test_NetworkEngine_subscribe_malformed(malformed_msg):
    async def mock_recv():
        await asyncio.sleep(0.01)
        return malformed_msg

    async def mock_on_malformed(*args):
        async with cond:
            cond.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = mock_recv

    engine = NetworkEngine(rx, lambda: AsyncMock())

    actual_port = await engine.bind("127.0.0.1", port)
    assert actual_port == port

    cond = asyncio.Condition()
    on_recv = AsyncMock()
    on_malformed = AsyncMock(side_effect=mock_on_malformed)
    on_error = AsyncMock()
    d = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)

    async with cond:
        await asyncio.wait_for(cond.wait(), 3.0)

    await d.dispose()

    expected_id = malformed_msg[0]
    expected_segments = malformed_msg[1]
    on_malformed.assert_awaited_with(expected_id, expected_segments)
    on_recv.assert_not_awaited()
    on_error.assert_not_awaited()


@pytest.mark.unit
async def test_NetworkEngine_subscribe_error():
    async def mock_recv():
        await asyncio.sleep(0.01)
        raise ex

    async def mock_on_error(*args):
        async with cond:
            cond.notify()

    port = 9001
    rx = AsyncMock()
    rx.bind.return_value = port
    rx.recv.side_effect = mock_recv

    engine = NetworkEngine(rx, lambda: AsyncMock())

    actual_port = await engine.bind("127.0.0.1", port)
    assert actual_port == port

    ex = ValueError("yeet")
    cond = asyncio.Condition()
    on_recv = AsyncMock()
    on_malformed = AsyncMock()
    on_error = AsyncMock(side_effect=mock_on_error)
    d = engine.subscribe(on_recv=on_recv, on_malformed=on_malformed, on_error=on_error)

    async with cond:
        await asyncio.wait_for(cond.wait(), 3.0)

    await d.dispose()

    on_error.assert_awaited_with(ex)
    on_recv.assert_not_awaited()
    on_malformed.assert_not_awaited()


@pytest.mark.unit
async def test_NetworkEngine_subscribe_not_bound():
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    with pytest.raises(RuntimeError):
        engine.subscribe(on_recv=AsyncMock(), on_malformed=AsyncMock())


@pytest.mark.unit
async def test_NetworkEngine_send_to(id1, id2):
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    await engine.connect_to(id2.host, id2.port)
    await engine.send_to(id1, id2, b"hello there")

    tx = engine._physical_tx_connections[(id2.host, id2.port)]
    tx.send.assert_awaited_once_with(orjson.dumps(dataclasses.asdict(id1)), b"hello there")


@pytest.mark.unit
async def test_NetworkEngine_send_to_not_connected(id1, id2):
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    with pytest.raises(RuntimeError):
        await engine.send_to(id1, id2, b"hello there")


@pytest.mark.unit
async def test_NetworkEngine_disconnect_from(id1, id2, id3):
    engine = NetworkEngine(AsyncMock(), lambda: AsyncMock())

    await engine.connect_to(id1.host, id1.port)
    await engine.connect_to(id2.host, id2.port)
    await engine.connect_to(id3.host, id3.port)

    tx12 = engine._physical_tx_connections[(id1.host, id1.port)]
    tx3 = engine._physical_tx_connections[(id3.host, id3.port)]

    assert list(engine._physical_tx_connections) == [(id1.host, id1.port), (id3.host, id3.port)]

    await engine.disconnect_from(id2.host, id2.port)

    tx12.close.assert_awaited_once()
    assert list(engine._physical_tx_connections) == [(id3.host, id3.port)]

    await engine.disconnect_from(id3.host, id3.port)

    tx3.close.assert_awaited_once()
    assert list(engine._physical_tx_connections) == []
39
src/scatterbrained/network/types.py
Normal file
@@ -0,0 +1,39 @@
from typing import Optional, Protocol, Sequence, Tuple


class RX(Protocol):
    @property
    def host(self) -> Optional[str]:
        ...

    @property
    def port(self) -> Optional[int]:
        ...

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        ...

    async def recv(self) -> Tuple[str, Sequence[bytes]]:
        ...

    async def close(self) -> None:
        ...


class TX(Protocol):
    @property
    def host(self) -> Optional[str]:
        ...

    @property
    def port(self) -> Optional[int]:
        ...

    async def connect(self, host: str, port: int) -> None:
        ...

    async def send(self, *segments: bytes) -> None:
        ...

    async def close(self) -> None:
        ...
108
src/scatterbrained/network/zmq_ip.py
Normal file
@@ -0,0 +1,108 @@
from typing import Optional, Sequence, Tuple

import zmq
import zmq.asyncio


class ZMQReceiver:
    _context: zmq.asyncio.Context
    _socket: Optional[zmq.asyncio.Socket]
    _host: Optional[str]
    _port: Optional[int]

    def __init__(self, context: Optional[zmq.asyncio.Context] = None):
        if context is None:
            context = zmq.asyncio.Context.instance()
        self._context = context
        self._socket = None
        self._host = None
        self._port = None

    @property
    def host(self):
        return self._host

    @property
    def port(self):
        return self._port

    async def bind(self, host: str, port: Optional[int] = None) -> int:
        if self._socket is not None:
            return self._port

        self._socket = self._context.socket(zmq.ROUTER)
        if port is None:
            port = self._socket.bind_to_random_port(f"tcp://{host}")
        else:
            self._socket.bind(f"tcp://{host}:{port}")
        self._host = host
        self._port = port
        return self._port

    async def recv(self) -> Tuple[str, Sequence[bytes]]:
        if self._socket is None:
            raise RuntimeError("socket is unbound")

        peer_id, _, *segments = await self._socket.recv_multipart()
        return peer_id.decode(), segments

    async def close(self) -> None:
        if self._socket is None:
            return

        self._socket.close(linger=0)
        self._socket = None
        self._host = None
        self._port = None


class ZMQTransmitter:
    _id: str
    _context: zmq.asyncio.Context
    _socket: Optional[zmq.asyncio.Socket]
    _host: Optional[str]
    _port: Optional[int]

    def __init__(self, id: str, context: Optional[zmq.asyncio.Context] = None):
        if context is None:
            context = zmq.asyncio.Context.instance()
        self._id = id
        self._context = context
        self._socket = None
        self._host = None
        self._port = None

    @property
    def host(self) -> Optional[str]:
        return self._host

    @property
    def port(self) -> Optional[int]:
        return self._port

    async def connect(self, host: str, port: int) -> None:
        if self._socket is not None:
            return

        self._socket = self._context.socket(zmq.DEALER)
        self._socket.setsockopt_string(zmq.IDENTITY, self._id)
        self._socket.connect(f"tcp://{host}:{port}")
        self._host = host
        self._port = port

    async def send(self, *segments: bytes) -> None:
        if self._socket is None:
            raise RuntimeError("socket is not connected")
        await self._socket.send_multipart([b"", *segments])

    async def close(self) -> None:
        if self._socket is None:
            return

        self._socket.close(linger=0)
        self._socket = None
        self._host = None
        self._port = None


__all__ = ["ZMQReceiver", "ZMQTransmitter"]
79
src/scatterbrained/network/zmq_ip_test.py
Normal file
@@ -0,0 +1,79 @@
import asyncio

import pytest
import zmq
import zmq.asyncio

from .zmq_ip import ZMQReceiver, ZMQTransmitter

# All test coroutines will be treated as marked.
pytestmark = pytest.mark.asyncio


@pytest.fixture
def addr():
    return "127.0.0.1"


@pytest.fixture
def port():
    return 9001


@pytest.fixture
def context():
    ctx = zmq.asyncio.Context()
    yield ctx
    ctx.destroy(linger=0)


@pytest.fixture
def client_id():
    return "foobar"


@pytest.fixture
async def dealer(context, client_id, addr, port):
    socket = context.socket(zmq.DEALER)
    socket.identity = client_id.encode()
    socket.connect(f"tcp://{addr}:{port}")
    yield socket
    socket.close(linger=0)


@pytest.fixture
async def router(context, addr, port):
    socket = context.socket(zmq.ROUTER)
    socket.bind(f"tcp://{addr}:{port}")
    yield socket
    socket.close(linger=0)


async def test_ZMQReceiver_recv(context, dealer, client_id, addr, port):
    rx = ZMQReceiver(context=context)

    await rx.bind(addr, port)

    msg1, msg2 = b"hello there", b"general kenobi"
    await dealer.send_multipart([b"", msg1, msg2])

    actual_id, actual_msg = await asyncio.wait_for(rx.recv(), timeout=3.0)

    assert actual_id == client_id
    assert actual_msg == [msg1, msg2]

    await rx.close()


async def test_ZMQTransmitter_send(context, router, client_id, addr, port):
    tx = ZMQTransmitter(client_id, context=context)

    await tx.connect(addr, port)

    msg1, msg2 = b"hello there", b"general kenobi"
    await tx.send(msg1, msg2)

    actual_id, _, *actual_segments = await router.recv_multipart()

    assert actual_id.decode() == client_id
    assert actual_segments == [msg1, msg2]
538
src/scatterbrained/node.py
Normal file
@@ -0,0 +1,538 @@
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
from collections import defaultdict, deque
|
||||
from enum import Enum
|
||||
from typing import Callable, Deque, Dict, Optional, Sequence, Set, Tuple, Union, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from .discovery import DiscoveryEngine, UDPBroadcaster, UDPReceiver
|
||||
from .network import NetworkEngine, ZMQReceiver, ZMQTransmitter
|
||||
from .types import Identity
|
||||
|
||||
|
||||
class OperatingMode(Enum):
|
||||
"""
|
||||
The mode for the Node to operate in.
|
||||
|
||||
Leech: Listen for broadcasts; do not share.
|
||||
Offline: Do not listen or broadcast.
|
||||
Peer: Listen and broadcast.
|
||||
Seeding: Broadcast but do not listen.
|
||||
|
||||
"""
|
||||
|
||||
LEECHING = "leeching"
|
||||
OFFLINE = "offline"
|
||||
PEER = "peer"
|
||||
SEEDING = "seeding"
|
||||
|
||||
|
||||
class Namespace:
|
||||
"""
|
||||
A Namespace is a collection of Scatterbrained nodes that share a common model or parameter state space. When
|
||||
connecting or disconnecting from a community, a Scatterbrained node joins a named namespace, or creates a new
|
||||
namespace with a unique name.
|
||||
|
||||
Note that this class shouldn't be instantiated directly, rather instances should created via the
|
||||
`scatterbrained.Node.namespace` method.
|
||||
|
||||
"""
|
||||
|
||||
node: Node
|
||||
name: str
|
||||
peer_filter: Callable[[Identity], bool]
|
||||
_operating_mode: OperatingMode
|
||||
_advertised_host: str
|
||||
_advertised_port: int
|
||||
_id: Identity
|
||||
_mq: Deque[Tuple[Identity, Sequence[bytes]]]
|
||||
_msg_cond: asyncio.Condition
|
||||
_peer_cond: asyncio.Condition
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node: Node,
|
||||
name: str,
|
||||
operating_mode: Optional[OperatingMode] = None,
|
||||
peer_filter: Optional[Callable[[Identity], bool]] = None,
|
||||
advertised_host: Optional[str] = None,
|
||||
advertised_port: Optional[int] = None,
|
||||
hwm: int = 10_000,
|
||||
):
|
||||
"""
|
||||
Create a new Namespace or point to an existing one.
|
||||
|
||||
Arguments:
|
||||
node (scatterbrained.Node): The `scatterbrained.Node` to which this Namespace belongs.
|
||||
name (str): The name of the `Namespace`. Must be unique if you are creating a new `Namespace`.
|
||||
operating_mode (scatterbrained.OperatingMode): The operating mode in which the `Node` should operate in this
|
||||
namespace. For more information on the different operating modes, see the docs on the
|
||||
`scatterbrained.node.OperatingMode` enum.
|
||||
peer_filter (callable): A function that takes an `Identity` and returns a boolean indicating whether the
|
||||
`Node` should connect to the peer.
|
||||
advertised_host (str, optional): The hostname or IP address to advertise to other Scatterbrained nodes.
|
||||
advertised_port (int, optional): The port to advertise to other nodes.
|
||||
hwm (int): The high water mark for the message queue. Old messages will be dropped if the queue is full
|
||||
beyond this number. Defaults to 10,000.
|
||||
|
||||
"""
|
||||
if not node.listening:
|
||||
raise RuntimeError(f"{type(node).__name__} must be listening before {type(self).__name__} creation")
|
||||
if advertised_host is None:
|
||||
advertised_host = node.default_advertised_host
|
||||
if advertised_port is None:
|
||||
advertised_port = node.default_advertised_port
|
||||
|
||||
assert advertised_host is not None # nosec
|
||||
assert advertised_port is not None # nosec
|
||||
|
||||
self.node = node
|
||||
self.name = name
|
||||
self.peer_filter = peer_filter or node.default_peer_filter
|
||||
self._operating_mode = operating_mode or node.default_operating_mode
|
||||
self._advertised_host = advertised_host
|
||||
self._advertised_port = advertised_port
|
||||
self._id = Identity(
|
||||
id=self.node.id, namespace=self.name, host=self._advertised_host, port=self._advertised_port
|
||||
)
|
||||
self._mq = deque(maxlen=hwm)
|
||||
self._msg_cond = asyncio.Condition()
|
||||
self._peer_cond = asyncio.Condition()
|
||||
|
||||
@property
|
||||
def operating_mode(self):
|
||||
"""
|
||||
Get the operating mode for this namespace.
|
||||
|
||||
Returns:
|
||||
scatterbrained.OperatingMode: The operating mode for this namespace.
|
||||
|
||||
"""
|
||||
return self._operating_mode
|
||||
|
||||
# NOTE: This may need to be an async setter method as opposed to a setter property.
|
||||
@operating_mode.setter
|
||||
def operating_mode(self, value: OperatingMode):
|
||||
self._operating_mode = value
|
||||
|
||||
async def __aenter__(self) -> Namespace:
|
||||
await self.launch()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
|
||||
async def launch(self):
|
||||
"""
|
||||
Launch the Namespace.
|
||||
|
||||
This is usually done by entering the `Namespace` context manager. You should not need to call this method
|
||||
directly.
|
||||
|
||||
"""
|
||||
self.node.discovery_engine.add_identity(self._id)
|
||||
current_peers = self.node.discovery_engine.peers
|
||||
# TODO: explore exception handling approaches. Is returning the exceptions really the best way to do this?
|
||||
# AFAIK, gather will cancel any ops running if an exception occurs, which would mean we'd only need to clean up.
|
||||
# We should clean up regardless?
|
||||
statuses = await asyncio.gather(*[self.connect_to(p) for p in current_peers], return_exceptions=True)
|
||||
|
||||
ex = False
|
||||
connected = []
|
||||
for peer, value in zip(current_peers, statuses):
|
||||
if isinstance(value, Exception):
|
||||
logger.bind(peer=peer).opt(exception=value).error(
|
||||
f"failed to connect to {peer.id} (addr={peer.host}:{peer.port})"
|
||||
)
|
||||
ex = True
|
||||
elif value:
|
||||
connected.append(peer)
|
||||
self.node._update_connections(self, connected)
|
||||
if ex:
|
||||
raise RuntimeError(
|
||||
f"one or more exceptions occurred when attempting to launch {type(self).__name__} '{self.name}'"
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Close out the Namespace gracefully.
|
||||
|
||||
This is usually done by exiting the `Namespace` context manager. You should not need to call this method
|
||||
directly.
|
||||
|
||||
"""
|
||||
self.node.discovery_engine.remove_identity(
|
||||
Identity(id=self.node.id, namespace=self.name, host=self._advertised_host, port=self._advertised_port)
|
||||
)
|
||||
|
||||
async def connect_to(self, peer: Identity) -> bool:
|
||||
"""
|
||||
Connect to a peer in this Namespace by its Identity.
|
||||
|
||||
The method will not connect to the given peer if the `OperatingMode` of this Namespace is either `LEECHING` or
|
||||
`OFFLINE`, or the `Identity` is rejected by this instance's `peer_filter`.
|
||||
|
||||
Arguments:
|
||||
peer (scatterbrained.Identity): The peer to connect to.
|
||||
|
||||
Returns:
|
||||
bool: `True` if the connection was successful, `False` otherwise.
|
||||
|
||||
"""
|
||||
if (
|
||||
self.operating_mode == OperatingMode.LEECHING
|
||||
or self.operating_mode == OperatingMode.OFFLINE
|
||||
or not self.peer_filter(peer)
|
||||
):
|
||||
return False
|
||||
|
||||
await self.node.network_engine.connect_to(peer.host, peer.port)
|
||||
return True
|
||||
|
||||
async def disconnect_from(self, peer: Identity) -> bool:
|
||||
"""
|
||||
Disconnect from a peer in this `Namespace` by its `Identity`.
|
||||
|
||||
Arguments:
|
||||
peer (scatterbrained.Identity): The peer to disconnect from.
|
||||
|
||||
Returns:
|
||||
bool: `True` if the disconnection was successful, `False` otherwise.
|
||||
|
||||
"""
|
||||
return await self.node._disconnect_from(peer)
|
||||
|
||||
async def wait_for_peers(
|
||||
self, peers: Union[int, str, Identity, Sequence[Identity], Sequence[str]], timeout: Optional[float] = None
|
||||
) -> None:
|
||||
"""
|
||||
Wait for peers to connect to this Namespace.
|
||||
|
||||
This method behaves in one of three ways depending on the type of arguments passed:
|
||||
|
||||
* If you pass one or more `Identity` objects, this will wait for all peers to connect. Importantly, this will
|
||||
match on all attributes of the `Identity` objects.
|
||||
* If you pass one or more strings, this will wait for all peers with the specified ids.
|
||||
* If you pass an integer, this will wait for that many peers to connect.
|
||||
|
||||
|
||||
Arguments:
|
||||
peers (Union[int, str, Identity, Sequence[Identity], Sequence[str]]): The condition to await.
|
||||
timeout (float): The timeout in seconds to wait for the condition to be met.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
if isinstance(peers, int):
|
||||
return await asyncio.wait_for(self._wait_for_n_peers(peers), timeout=timeout)
|
||||
elif isinstance(peers, str):
|
||||
peers = [peers]
|
||||
elif isinstance(peers, Identity):
|
||||
peers = [peers]
|
||||
strings_present = any(isinstance(v, str) for v in peers)
|
||||
if strings_present:
|
||||
peers = [v if isinstance(v, str) else v.id for v in peers]
|
||||
await asyncio.wait_for(self._wait_for_named_peers(peers), timeout=timeout)
|
||||
else:
|
||||
peers = cast(Sequence[Identity], peers)
|
||||
await asyncio.wait_for(self._wait_for_exact_peers(peers), timeout=timeout)
|
||||
|
||||
async def _wait_for_n_peers(self, num_peers: int) -> None:
|
||||
def check_fn():
|
||||
return len(self.node._peers_by_namespace[self.name]) >= num_peers
|
||||
|
||||
if not check_fn():
|
||||
async with self._peer_cond:
|
||||
await self._peer_cond.wait_for(check_fn)
|
||||
|
||||
async def _wait_for_named_peers(self, peers: Sequence[str]) -> None:
|
||||
peer_set = set(peers)
|
||||
|
||||
def check_fn():
|
||||
return len(peer_set - set(v.id for v in self.node._peers_by_namespace[self.name])) == 0
|
||||
|
||||
if not check_fn():
|
||||
async with self._peer_cond:
|
||||
await self._peer_cond.wait_for(check_fn)
|
||||
|
||||
async def _wait_for_exact_peers(self, peers: Sequence[Identity]) -> None:
|
||||
peer_set = set(peers)
|
||||
|
||||
def check_fn():
|
||||
return len(peer_set - self.node._peers_by_namespace[self.name]) == 0
|
||||
|
||||
if not check_fn():
|
||||
async with self._peer_cond:
|
||||
await self._peer_cond.wait_for(check_fn)
|
||||
|
||||
async def send_to(self, peer: Identity, *payload: bytes) -> None:
|
||||
"""
|
||||
Send a byte sequence payload to a peer by its `Identity`.
|
||||
|
||||
Arguments:
|
||||
peer (scatterbrained.Identity): The peer to send to.
|
||||
payload (bytes): The payload to send.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
await self.node.network_engine.send_to(self._id, peer, *payload)
|
||||
|
||||
async def recv(self, timeout: Optional[float] = None) -> Tuple[Identity, Sequence[bytes]]:
|
||||
"""
|
||||
Receive a message from the network.
|
||||
|
||||
This will block until a message is received, or the timeout is reached.
|
||||
|
||||
Arguments:
|
||||
timeout (float): The timeout in seconds to wait for a message to be received.
|
||||
|
||||
Returns:
|
||||
Tuple[scatterbrained.Identity, bytes]: The sender's Identity and the payload.
|
||||
|
||||
"""
|
||||
if not len(self._mq):
|
||||
async with self._msg_cond:
|
||||
await asyncio.wait_for(self._msg_cond.wait(), timeout=timeout)
|
||||
return self._mq.pop()
|
||||
|
||||
async def _on_appear(self, peer: Identity):
|
||||
async with self._peer_cond:
|
||||
self._peer_cond.notify_all()
|
||||
|
||||
async def _add_message_from(self, peer: Identity, payload: Sequence[bytes]) -> None:
|
||||
empty = len(self._mq) == 0
|
||||
self._mq.append((peer, payload))
|
||||
|
||||
if empty:
|
||||
async with self._msg_cond:
|
||||
self._msg_cond.notify()
|
||||
|
||||
|
||||

class Node:
    """
    Scatterbrained peer node.

    Manages all networking infrastructure, providing an interface for other classes to establish connections and to
    send and receive data.
    """

    id: str
    discovery_engine: DiscoveryEngine
    network_engine: NetworkEngine
    default_operating_mode: OperatingMode
    default_peer_filter: Callable[[Identity], bool]
    default_advertised_host: Optional[str]
    default_advertised_port: Optional[int]
    _host: str
    _port: Optional[int]
    _namespaces: Dict[str, Namespace]
    _virtual_tx_connections: Dict[Tuple[str, int], Set[Identity]]
    _peers_by_namespace: Dict[str, Set[Identity]]
    _network_sub: Optional[NetworkEngine.Disposable]
    def __init__(
        self,
        id: str,
        host: str = "0.0.0.0",  # nosec
        port: Optional[int] = None,
        discovery_engine: Optional[DiscoveryEngine] = None,
        network_engine: Optional[NetworkEngine] = None,
        default_operating_mode: OperatingMode = OperatingMode.PEER,
        default_peer_filter: Optional[Callable[[Identity], bool]] = None,
        default_advertised_host: Optional[str] = None,
        default_advertised_port: Optional[int] = None,
    ) -> None:
        """
        Create a new Scatterbrained peer node.

        Arguments:
            id (str): The identity of this `Node`. Must be provided.
            host (str): The host to bind to. Defaults to binding on all IP addresses on the machine (`0.0.0.0`).
            port (int): The port to bind to. Defaults to an arbitrary open port.
            discovery_engine (scatterbrained.DiscoveryEngine): The `DiscoveryEngine` to use. Defaults to a new
                `DiscoveryEngine` with a default configuration, using UDP broadcast.
            network_engine (scatterbrained.NetworkEngine): The `NetworkEngine` to use. Defaults to a new
                `NetworkEngine` with a default configuration, using ZMQ.
            default_operating_mode (scatterbrained.OperatingMode): The default operating mode for new `Namespace`s.
                Defaults to `OperatingMode.PEER`.
            default_peer_filter (Callable[[scatterbrained.Identity], bool]): The default peer filter for new
                `Namespace`s. By default, all peers are accepted.
            default_advertised_host (str): The advertised host to use for new `Namespace`s. Defaults to the host
                provided.
            default_advertised_port (int): The advertised port to use for new `Namespace`s. Defaults to the port
                provided.
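
        Example:
            A minimal sketch of standing up a node with the default engines;
            the id "node-a" and namespace name "my.namespace" are hypothetical:

                async with Node(id="node-a") as node:
                    ns = node.namespace("my.namespace")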
        """
        if discovery_engine is None:
            discovery_engine = DiscoveryEngine(publisher=UDPBroadcaster(), subscriber=UDPReceiver())
        if network_engine is None:
            network_engine = NetworkEngine(rx=ZMQReceiver(), tx_factory=lambda: ZMQTransmitter(id=id))
        if default_peer_filter is None:
            default_peer_filter = lambda _: True  # noqa: E731

        self.id = id
        self.discovery_engine = discovery_engine
        self.network_engine = network_engine
        self.default_peer_filter = default_peer_filter
        self.default_operating_mode = default_operating_mode
        self.default_advertised_host = default_advertised_host
        self.default_advertised_port = default_advertised_port
        self._host = host
        self._port = port
        self._namespaces = {}
        self._virtual_tx_connections = defaultdict(set)
        self._peers_by_namespace = defaultdict(set)
        self._network_sub = None
    @property
    def listening(self) -> bool:
        """
        Whether or not the `Node` is listening for incoming connections.

        Returns:
            bool: True if the `Node` is currently listening for incoming connections.
        """
        return self._network_sub is not None
    async def __aenter__(self) -> Node:
        await self.launch()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()
    async def launch(self) -> None:
        """
        Launch the `Node`.

        This will start the `DiscoveryEngine` and `NetworkEngine` and begin listening for new connections.

        Note that this method is usually called automatically when entering an `async with` block and rarely needs to
        be called manually.

        Returns:
            None
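
        Example:
            A minimal sketch of manual lifecycle management, as an alternative to
            the `async with` form ("node-a" is a hypothetical id):

                node = Node(id="node-a")
                await node.launch()
                try:
                    ...
                finally:
                    await node.close()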
        """
        if self._network_sub is not None:
            return
        await self.network_engine.bind(host=self._host, port=self._port)
        self._network_sub = self.network_engine.subscribe(
            on_recv=self._on_recv, on_malformed=self._on_malformed, on_error=self._on_network_error
        )
        await self.discovery_engine.start(
            on_appear=self._on_appear, on_disappear=self._on_disappear, on_error=self._on_discovery_error
        )

        bound_host, bound_port = self.network_engine.bound_address
        assert bound_host is not None, "expected bound host to not be None"  # nosec
        assert bound_port is not None, "expected bound port to not be None"  # nosec

        # Fall back to the actual bound address for advertisement if none was provided.
        if self.default_advertised_host is None:
            self.default_advertised_host = bound_host
        if self.default_advertised_port is None:
            self.default_advertised_port = bound_port
    async def close(self):
        """
        Gracefully close the `Node`.

        This will stop the `DiscoveryEngine` and close the `NetworkEngine`.

        Note that this method is usually called automatically when exiting an `async with` block and rarely needs to
        be called manually.

        Returns:
            None

        """
        if self._network_sub is None:
            return
        await self.discovery_engine.stop()
        await self._network_sub.dispose()
        await self.network_engine.close()
        self._network_sub = None
    def namespace(self, name: str, *args, **kwargs) -> Namespace:
        """
        Get the namespace with the given name, creating it if it does not already exist.

        Any additional arguments are forwarded to `scatterbrained.Namespace` when a new namespace is created.

        Arguments:
            name (str): The name of the namespace.

        Returns:
            scatterbrained.Namespace: The namespace with the given name.
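
        Example:
            A minimal sketch; repeated calls with the same name return the same
            `Namespace` instance ("my.namespace" is a hypothetical name):

                ns = node.namespace("my.namespace")
                assert node.namespace("my.namespace") is ns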
        """
        if (ns := self._namespaces.get(name)) is not None:
            return ns

        ns = Namespace(self, name, *args, **kwargs)
        self._namespaces[ns.name] = ns
        return ns
    def _update_connections(self, ns: Namespace, peers: Sequence[Identity]) -> None:
        for peer in peers:
            self._virtual_tx_connections[(peer.host, peer.port)].add(peer)

    async def _on_appear(self, peer: Identity) -> None:
        cl = logger.bind(peer=dataclasses.asdict(peer))
        cl.debug(f"peer '{peer.id}' is online")

        ns = self._namespaces.get(peer.namespace)
        if ns is not None:
            connected = await ns.connect_to(peer)
            if connected:
                cl.debug(f"connected to '{peer.id}'")

                self._virtual_tx_connections[(peer.host, peer.port)].add(peer)
                self._peers_by_namespace[ns.name].add(peer)
                await ns._on_appear(peer)

    async def _on_disappear(self, peer: Identity) -> None:
        await self._disconnect_from(peer)

    async def _on_discovery_error(self, ex: Exception) -> None:
        logger.opt(exception=ex).error("discovery engine encountered an error")
        # TODO: determine how to best alert the user of this error.

    async def _on_recv(self, peer: Identity, payload: Sequence[bytes]) -> None:
        ns = self._namespaces.get(peer.namespace)
        if ns is not None:
            await ns._add_message_from(peer, payload)

    async def _on_malformed(self, peer_id: str, segments: Sequence[bytes]) -> None:
        logger.warning(f"malformed message consisting of {len(segments)} byte segments from '{peer_id}'")
        # TODO: anything else to do here? Maybe allow segments to be saved.

    async def _on_network_error(self, ex: Exception) -> None:
        logger.opt(exception=ex).error("network engine encountered an error")
        # TODO: determine how to best alert the user of this error.

    async def _disconnect_from(self, peer: Identity) -> bool:
        cl = logger.bind(peer=dataclasses.asdict(peer))
        cl.debug(f"peer '{peer.id}' is offline")

        virt_conns = self._virtual_tx_connections[(peer.host, peer.port)]
        if peer not in virt_conns:
            cl.warning(f"attempted to disconnect from peer '{peer.id}' that was never connected to")
            return False
        virt_conns.remove(peer)
        # Several namespaces may share one physical connection; only tear down the
        # transport once the last virtual connection to this address is gone.
        if not virt_conns:
            await self.network_engine.disconnect_from(peer.host, peer.port)

        cl.debug(f"disconnected from '{peer.host}:{peer.port}'")
        self._peers_by_namespace[peer.namespace].remove(peer)
        return True


__all__ = ["Node"]
23
src/scatterbrained/types.py
Normal file
23
src/scatterbrained/types.py
Normal file
@@ -0,0 +1,23 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class Identity:
    """
    Represents the identity of a scatterbrained.Node (either local or remote) for a particular namespace.

    Args:
        id (str): The id of the Node.
        namespace (str): The namespace the Node is operating in.
        host (str): The advertised address of this Node.
        port (int): The advertised port of this Node.
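
    Example:
        A minimal construction sketch; all values are hypothetical:

            peer = Identity(id="node-b", namespace="my.namespace", host="192.168.1.20", port=5555)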

    """

    id: str
    namespace: str
    host: str
    port: int


__all__ = ["Identity"]
1
src/scatterbrained/version.py
Normal file
1
src/scatterbrained/version.py
Normal file
@@ -0,0 +1 @@
__version__ = "0.0.1"