mirror of
https://github.com/rembo10/headphones.git
synced 2026-01-08 22:38:08 -05:00
Merge branch 'develop'
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@@ -1,7 +1,17 @@
|
||||
# Changelog
|
||||
|
||||
## v0.6.2
|
||||
Released 26 May 2024
|
||||
|
||||
Highlights:
|
||||
* Added soulseek support
|
||||
* Added bandcamp support
|
||||
* Changes and dependency updates to work with Python >= 3.12
|
||||
|
||||
The full list of commits can be found [here](https://github.com/rembo10/headphones/compare/v0.6.1...v0.6.2).
|
||||
|
||||
## v0.6.1
|
||||
Released 26 November 2023
|
||||
R eleased 26 November 2023
|
||||
|
||||
Highlights:
|
||||
* Dependency updates to work with > Python 3.11
|
||||
|
||||
@@ -310,6 +310,16 @@
|
||||
<input type="text" name="usenet_retention" value="${config['usenet_retention']}" size="5">
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset title="Method for downloading Bandcamp.com files.">
|
||||
<legend>Bandcamp</legend>
|
||||
<div class="row">
|
||||
<label title="Path to folder where Headphones can store raw downloads from Bandcamp.com.">
|
||||
Bandcamp Directory
|
||||
</label>
|
||||
<input type="text" name="bandcamp_dir" value="${config['bandcamp_dir']}" size="50">
|
||||
<small>Full path where raw MP3s will be stored, e.g. /Users/name/Downloads/bandcamp</small>
|
||||
</div>
|
||||
</fieldset>
|
||||
</td>
|
||||
<td>
|
||||
<fieldset title="Method for downloading torrent files.">
|
||||
@@ -317,7 +327,7 @@
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_blackhole" value="0" ${config['torrent_downloader_blackhole']}> Black Hole
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_transmission" value="1" ${config['torrent_downloader_transmission']}> Transmission
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_utorrent" value="2" ${config['torrent_downloader_utorrent']}> uTorrent (Beta)
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_deluge" value="3" ${config['torrent_downloader_deluge']}> Deluge (Beta)
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_deluge" value="3" ${config['torrent_downloader_deluge']}> Deluge
|
||||
<input type="radio" name="torrent_downloader" id="torrent_downloader_qbittorrent" value="4" ${config['torrent_downloader_qbittorrent']}> QBitTorrent
|
||||
</fieldset>
|
||||
<fieldset id="torrent_blackhole_options">
|
||||
@@ -438,6 +448,11 @@
|
||||
<input type="text" name="deluge_label" value="${config['deluge_label']}" size="30">
|
||||
<small>Labels shouldn't contain spaces (requires Label plugin)</small>
|
||||
</div>
|
||||
<div class="row">
|
||||
<label>Download Directory</label>
|
||||
<input type="text" name="deluge_download_directory" value="${config['deluge_download_directory']}" size="30">
|
||||
<small>Directory where Deluge should download to</small>
|
||||
</div>
|
||||
<div class="row">
|
||||
<label>Move When Completed</label>
|
||||
<input type="text" name="deluge_done_directory" value="${config['deluge_done_directory']}" size="30">
|
||||
@@ -467,7 +482,33 @@
|
||||
<label>Prefer</label>
|
||||
<input type="radio" name="prefer_torrents" id="prefer_torrents_0" value="0" ${config['prefer_torrents_0']}>NZBs
|
||||
<input type="radio" name="prefer_torrents" id="prefer_torrents_1" value="1" ${config['prefer_torrents_1']}>Torrents
|
||||
<input type="radio" name="prefer_torrents" id="prefer_torrents_2" value="2" ${config['prefer_torrents_2']}>No Preference
|
||||
<input type="radio" name="prefer_torrents" id="prefer_torrents_2" value="2" ${config['prefer_torrents_2']}>Soulseek
|
||||
<input type="radio" name="prefer_torrents" id="prefer_torrents_3" value="3" ${config['prefer_torrents_3']}>No Preference
|
||||
</div>
|
||||
</fieldset>
|
||||
</td>
|
||||
<td>
|
||||
<fieldset>
|
||||
<legend>Soulseek</legend>
|
||||
<div class="row">
|
||||
<label>Soulseek API URL</label>
|
||||
<input type="text" name="soulseek_api_url" value="${config['soulseek_api_url']}" size="50">
|
||||
</div>
|
||||
<div class="row">
|
||||
<label>Soulseek API KEY</label>
|
||||
<input type="text" name="soulseek_api_key" value="${config['soulseek_api_key']}" size="20">
|
||||
</div>
|
||||
<div class="row">
|
||||
<label title="Path to folder where Headphones can find the downloads.">
|
||||
Soulseek Download Dir:
|
||||
</label>
|
||||
<input type="text" name="soulseek_download_dir" value="${config['soulseek_download_dir']}" size="50">
|
||||
</div>
|
||||
<div class="row">
|
||||
<label title="Path to folder where Headphones can find the downloads.">
|
||||
Soulseek Incomplete Download Dir:
|
||||
</label>
|
||||
<input type="text" name="soulseek_incomplete_download_dir" value="${config['soulseek_incomplete_download_dir']}" size="50">
|
||||
</div>
|
||||
</fieldset>
|
||||
</td>
|
||||
@@ -579,6 +620,19 @@
|
||||
</div>
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<legend>Other</legend>
|
||||
<fieldset>
|
||||
<div class="row checkbox left">
|
||||
<input id="use_bandcamp" type="checkbox" class="bigcheck" name="use_bandcamp" value="1" ${config['use_bandcamp']} /><label for="use_bandcamp"><span class="option">Bandcamp</span></label>
|
||||
</div>
|
||||
</fieldset>
|
||||
<fieldset>
|
||||
<div class="row checkbox left">
|
||||
<input id="use_soulseek" type="checkbox" class="bigcheck" name="use_soulseek" value="1" ${config['use_soulseek']} /><label for="use_soulseek"><span class="option">Soulseek</span></label>
|
||||
</div>
|
||||
</fieldset>
|
||||
</fieldset>
|
||||
</td>
|
||||
<td>
|
||||
<fieldset>
|
||||
|
||||
@@ -56,6 +56,8 @@
|
||||
fileid = 'torrent'
|
||||
if item['URL'].find('codeshy') != -1:
|
||||
fileid = 'nzb'
|
||||
if item['URL'].find('bandcamp') != -1:
|
||||
fileid = 'bandcamp'
|
||||
|
||||
folder = 'Folder: ' + item['FolderName']
|
||||
|
||||
|
||||
166
headphones/bandcamp.py
Normal file
166
headphones/bandcamp.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# This file is part of Headphones.
|
||||
#
|
||||
# Headphones is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# Headphones is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with Headphones. If not, see <http://www.gnu.org/licenses/>
|
||||
|
||||
import headphones
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from headphones import logger, helpers, metadata, request
|
||||
from headphones.common import USER_AGENT
|
||||
from headphones.types import Result
|
||||
|
||||
from mediafile import MediaFile, UnreadableFileError
|
||||
from bs4 import BeautifulSoup
|
||||
from bs4 import FeatureNotFound
|
||||
|
||||
|
||||
def search(album, albumlength=None, page=1, resultlist=None):
|
||||
dic = {'...': '', ' & ': ' ', ' = ': ' ', '?': '', '$': 's', ' + ': ' ',
|
||||
'"': '', ',': '', '*': '', '.': '', ':': ''}
|
||||
if resultlist is None:
|
||||
resultlist = []
|
||||
|
||||
cleanalbum = helpers.latinToAscii(
|
||||
helpers.replace_all(album['AlbumTitle'], dic)
|
||||
).strip()
|
||||
cleanartist = helpers.latinToAscii(
|
||||
helpers.replace_all(album['ArtistName'], dic)
|
||||
).strip()
|
||||
|
||||
headers = {'User-Agent': USER_AGENT}
|
||||
params = {
|
||||
"page": page,
|
||||
"q": cleanalbum,
|
||||
}
|
||||
logger.info("Looking up https://bandcamp.com/search with {}".format(
|
||||
params))
|
||||
content = request.request_content(
|
||||
url='https://bandcamp.com/search',
|
||||
params=params,
|
||||
headers=headers
|
||||
).decode('utf8')
|
||||
try:
|
||||
soup = BeautifulSoup(content, "html5lib")
|
||||
except FeatureNotFound:
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
for item in soup.find_all("li", class_="searchresult"):
|
||||
type = item.find('div', class_='itemtype').text.strip().lower()
|
||||
if type == "album":
|
||||
data = parse_album(item)
|
||||
|
||||
cleanartist_found = helpers.latinToAscii(data['artist'])
|
||||
cleanalbum_found = helpers.latinToAscii(data['album'])
|
||||
|
||||
logger.debug(u"{} - {}".format(data['album'], cleanalbum_found))
|
||||
|
||||
logger.debug("Comparing {} to {}".format(
|
||||
cleanalbum, cleanalbum_found))
|
||||
if (cleanartist.lower() == cleanartist_found.lower() and
|
||||
cleanalbum.lower() == cleanalbum_found.lower()):
|
||||
resultlist.append(Result(
|
||||
data['title'], data['size'], data['url'],
|
||||
'bandcamp', 'bandcamp', True))
|
||||
else:
|
||||
continue
|
||||
|
||||
if(soup.find('a', class_='next')):
|
||||
page += 1
|
||||
logger.debug("Calling next page ({})".format(page))
|
||||
search(album, albumlength=albumlength,
|
||||
page=page, resultlist=resultlist)
|
||||
|
||||
return resultlist
|
||||
|
||||
|
||||
def download(album, bestqual):
|
||||
html = request.request_content(url=bestqual.url).decode('utf-8')
|
||||
trackinfo = []
|
||||
try:
|
||||
trackinfo = json.loads(
|
||||
re.search(r"trackinfo":(\[.*?\]),", html)
|
||||
.group(1)
|
||||
.replace('"', '"'))
|
||||
except ValueError as e:
|
||||
logger.warn("Couldn't load json: {}".format(e))
|
||||
|
||||
directory = os.path.join(
|
||||
headphones.CONFIG.BANDCAMP_DIR,
|
||||
u'{} - {}'.format(
|
||||
album['ArtistName'].replace('/', '_'),
|
||||
album['AlbumTitle'].replace('/', '_')))
|
||||
directory = helpers.latinToAscii(directory)
|
||||
|
||||
if not os.path.exists(directory):
|
||||
try:
|
||||
os.makedirs(directory)
|
||||
except Exception as e:
|
||||
logger.warn("Could not create directory ({})".format(e))
|
||||
|
||||
index = 1
|
||||
for track in trackinfo:
|
||||
filename = helpers.replace_illegal_chars(
|
||||
u'{:02d} - {}.mp3'.format(index, track['title']))
|
||||
fullname = os.path.join(directory.encode('utf-8'),
|
||||
filename.encode('utf-8'))
|
||||
logger.debug("Downloading to {}".format(fullname))
|
||||
|
||||
if 'file' in track and track['file'] != None and 'mp3-128' in track['file']:
|
||||
content = request.request_content(track['file']['mp3-128'])
|
||||
open(fullname, 'wb').write(content)
|
||||
try:
|
||||
f = MediaFile(fullname)
|
||||
date, year = metadata._date_year(album)
|
||||
f.update({
|
||||
'artist': album['ArtistName'].encode('utf-8'),
|
||||
'album': album['AlbumTitle'].encode('utf-8'),
|
||||
'title': track['title'].encode('utf-8'),
|
||||
'track': track['track_num'],
|
||||
'tracktotal': len(trackinfo),
|
||||
'year': year,
|
||||
})
|
||||
f.save()
|
||||
except UnreadableFileError as ex:
|
||||
logger.warn("MediaFile couldn't parse: %s (%s)",
|
||||
fullname,
|
||||
str(ex))
|
||||
|
||||
index += 1
|
||||
|
||||
return directory
|
||||
|
||||
|
||||
def parse_album(item):
|
||||
album = item.find('div', class_='heading').text.strip()
|
||||
artist = item.find('div', class_='subhead').text.strip().replace("by ", "")
|
||||
released = item.find('div', class_='released').text.strip().replace(
|
||||
"released ", "")
|
||||
year = re.search(r"(\d{4})", released).group(1)
|
||||
|
||||
url = item.find('div', class_='heading').find('a')['href'].split("?")[0]
|
||||
|
||||
length = item.find('div', class_='length').text.strip()
|
||||
tracks, minutes = length.split(",")
|
||||
tracks = tracks.replace(" tracks", "").replace(" track", "").strip()
|
||||
minutes = minutes.replace(" minutes", "").strip()
|
||||
# bandcamp offers mp3 128b with should be 960KB/minute
|
||||
size = int(minutes) * 983040
|
||||
|
||||
data = {"title": u'{} - {} [{}]'.format(artist, album, year),
|
||||
"artist": artist, "album": album,
|
||||
"url": url, "size": size}
|
||||
|
||||
return data
|
||||
@@ -102,36 +102,6 @@ class Quality:
|
||||
|
||||
return (anyQualities, bestQualities)
|
||||
|
||||
@staticmethod
|
||||
def nameQuality(name):
|
||||
|
||||
def checkName(list, func):
|
||||
return func([re.search(x, name, re.I) for x in list])
|
||||
|
||||
name = os.path.basename(name)
|
||||
|
||||
# if we have our exact text then assume we put it there
|
||||
for x in Quality.qualityStrings:
|
||||
if x == Quality.UNKNOWN:
|
||||
continue
|
||||
|
||||
regex = '\W' + Quality.qualityStrings[x].replace(' ', '\W') + '\W'
|
||||
regex_match = re.search(regex, name, re.I)
|
||||
if regex_match:
|
||||
return x
|
||||
|
||||
# TODO: fix quality checking here
|
||||
if checkName(["mp3", "192"], any) and not checkName(["flac"], all):
|
||||
return Quality.B192
|
||||
elif checkName(["mp3", "256"], any) and not checkName(["flac"], all):
|
||||
return Quality.B256
|
||||
elif checkName(["mp3", "vbr"], any) and not checkName(["flac"], all):
|
||||
return Quality.VBR
|
||||
elif checkName(["mp3", "320"], any) and not checkName(["flac"], all):
|
||||
return Quality.B320
|
||||
else:
|
||||
return Quality.UNKNOWN
|
||||
|
||||
@staticmethod
|
||||
def assumeQuality(name):
|
||||
if name.lower().endswith(".mp3"):
|
||||
@@ -158,13 +128,6 @@ class Quality:
|
||||
|
||||
return (Quality.NONE, status)
|
||||
|
||||
@staticmethod
|
||||
def statusFromName(name, assume=True):
|
||||
quality = Quality.nameQuality(name)
|
||||
if assume and quality == Quality.UNKNOWN:
|
||||
quality = Quality.assumeQuality(name)
|
||||
return Quality.compositeStatus(DOWNLOADED, quality)
|
||||
|
||||
DOWNLOADED = None
|
||||
SNATCHED = None
|
||||
SNATCHED_PROPER = None
|
||||
|
||||
@@ -80,6 +80,7 @@ _CONFIG_DEFINITIONS = {
|
||||
'DELUGE_PASSWORD': (str, 'Deluge', ''),
|
||||
'DELUGE_LABEL': (str, 'Deluge', ''),
|
||||
'DELUGE_DONE_DIRECTORY': (str, 'Deluge', ''),
|
||||
'DELUGE_DOWNLOAD_DIRECTORY': (str, 'Deluge', ''),
|
||||
'DELUGE_PAUSED': (int, 'Deluge', 0),
|
||||
'DESTINATION_DIR': (str, 'General', ''),
|
||||
'DETECT_BITRATE': (int, 'General', 0),
|
||||
@@ -269,6 +270,11 @@ _CONFIG_DEFINITIONS = {
|
||||
'SONGKICK_ENABLED': (int, 'Songkick', 1),
|
||||
'SONGKICK_FILTER_ENABLED': (int, 'Songkick', 0),
|
||||
'SONGKICK_LOCATION': (str, 'Songkick', ''),
|
||||
'SOULSEEK_API_URL': (str, 'Soulseek', ''),
|
||||
'SOULSEEK_API_KEY': (str, 'Soulseek', ''),
|
||||
'SOULSEEK_DOWNLOAD_DIR': (str, 'Soulseek', ''),
|
||||
'SOULSEEK_INCOMPLETE_DOWNLOAD_DIR': (str, 'Soulseek', ''),
|
||||
'SOULSEEK': (int, 'Soulseek', 0),
|
||||
'SUBSONIC_ENABLED': (int, 'Subsonic', 0),
|
||||
'SUBSONIC_HOST': (str, 'Subsonic', ''),
|
||||
'SUBSONIC_PASSWORD': (str, 'Subsonic', ''),
|
||||
@@ -317,7 +323,9 @@ _CONFIG_DEFINITIONS = {
|
||||
'XBMC_PASSWORD': (str, 'XBMC', ''),
|
||||
'XBMC_UPDATE': (int, 'XBMC', 0),
|
||||
'XBMC_USERNAME': (str, 'XBMC', ''),
|
||||
'XLDPROFILE': (str, 'General', '')
|
||||
'XLDPROFILE': (str, 'General', ''),
|
||||
'BANDCAMP': (int, 'General', 0),
|
||||
'BANDCAMP_DIR': (path, 'General', '')
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -58,12 +58,12 @@ def _scrubber(text):
|
||||
if scrub_logs:
|
||||
try:
|
||||
# URL parameter values
|
||||
text = re.sub('=[0-9a-zA-Z]*', '=REMOVED', text)
|
||||
text = re.sub(r'=[0-9a-zA-Z]*', r'=REMOVED', text)
|
||||
# Local host with port
|
||||
# text = re.sub('\:\/\/.*\:', '://REMOVED:', text) # just host
|
||||
text = re.sub('\:\/\/.*\:[0-9]*', '://REMOVED:', text)
|
||||
text = re.sub(r'\:\/\/.*\:[0-9]*', r'://REMOVED:', text)
|
||||
# Session cookie
|
||||
text = re.sub("_session_id'\: '.*'", "_session_id': 'REMOVED'", text)
|
||||
text = re.sub(r"_session_id'\: '.*'", r"_session_id': 'REMOVED'", text)
|
||||
# Local Windows user path
|
||||
if text.lower().startswith('c:\\users\\'):
|
||||
k = text.split('\\')
|
||||
@@ -128,9 +128,9 @@ def addTorrent(link, data=None, name=None):
|
||||
# Extract torrent name from .torrent
|
||||
try:
|
||||
logger.debug('Deluge: Getting torrent name length')
|
||||
name_length = int(re.findall('name([0-9]*)\:.*?\:', str(torrentfile))[0])
|
||||
name_length = int(re.findall(r'name([0-9]*)\:.*?\:', str(torrentfile))[0])
|
||||
logger.debug('Deluge: Getting torrent name')
|
||||
name = re.findall('name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
|
||||
name = re.findall(r'name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
|
||||
except Exception as e:
|
||||
logger.debug('Deluge: Could not get torrent name, getting file name')
|
||||
# get last part of link/path (name only)
|
||||
@@ -160,9 +160,9 @@ def addTorrent(link, data=None, name=None):
|
||||
# Extract torrent name from .torrent
|
||||
try:
|
||||
logger.debug('Deluge: Getting torrent name length')
|
||||
name_length = int(re.findall('name([0-9]*)\:.*?\:', str(torrentfile))[0])
|
||||
name_length = int(re.findall(r'name([0-9]*)\:.*?\:', str(torrentfile))[0])
|
||||
logger.debug('Deluge: Getting torrent name')
|
||||
name = re.findall('name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
|
||||
name = re.findall(r'name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
|
||||
except Exception as e:
|
||||
logger.debug('Deluge: Could not get torrent name, getting file name')
|
||||
# get last part of link/path (name only)
|
||||
@@ -466,19 +466,56 @@ def _add_torrent_url(result):
|
||||
|
||||
def _add_torrent_file(result):
|
||||
logger.debug('Deluge: Adding file')
|
||||
|
||||
options = {}
|
||||
|
||||
if headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY:
|
||||
options['download_location'] = headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY
|
||||
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
|
||||
options['move_completed'] = 1
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY:
|
||||
options['move_completed_path'] = headphones.CONFIG.DELUGE_DONE_DIRECTORY
|
||||
else:
|
||||
options['move_completed_path'] = headphones.CONFIG.DOWNLOAD_TORRENT_DIR
|
||||
|
||||
if headphones.CONFIG.DELUGE_PAUSED:
|
||||
options['add_paused'] = headphones.CONFIG.DELUGE_PAUSED
|
||||
|
||||
if not any(delugeweb_auth):
|
||||
_get_auth()
|
||||
try:
|
||||
# content is torrent file contents that needs to be encoded to base64
|
||||
post_data = json.dumps({"method": "core.add_torrent_file",
|
||||
"params": [result['name'] + '.torrent',
|
||||
b64encode(result['content']).decode(), {}],
|
||||
b64encode(result['content'].encode('utf8')),
|
||||
options],
|
||||
"id": 2})
|
||||
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
|
||||
verify=deluge_verify_cert, headers=headers)
|
||||
result['hash'] = json.loads(response.text)['result']
|
||||
logger.debug('Deluge: Response was %s' % str(json.loads(response.text)))
|
||||
return json.loads(response.text)['result']
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
# content is torrent file contents that needs to be encoded to base64
|
||||
# this time let's try leaving the encoding as is
|
||||
logger.debug('Deluge: There was a decoding issue, let\'s try again')
|
||||
post_data = json.dumps({"method": "core.add_torrent_file",
|
||||
"params": [result['name'].decode('utf8') + '.torrent',
|
||||
b64encode(result['content']),
|
||||
options],
|
||||
"id": 22})
|
||||
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
|
||||
verify=deluge_verify_cert, headers=headers)
|
||||
result['hash'] = json.loads(response.text)['result']
|
||||
logger.debug('Deluge: Response was %s' % str(json.loads(response.text)))
|
||||
return json.loads(response.text)['result']
|
||||
except Exception as e:
|
||||
logger.error('Deluge: Adding torrent file failed after decode: %s' % str(e))
|
||||
formatted_lines = traceback.format_exc().splitlines()
|
||||
logger.error('; '.join(formatted_lines))
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error('Deluge: Adding torrent file failed: %s' % str(e))
|
||||
formatted_lines = traceback.format_exc().splitlines()
|
||||
@@ -566,61 +603,3 @@ def setSeedRatio(result):
|
||||
return None
|
||||
|
||||
|
||||
def setTorrentPath(result):
|
||||
logger.debug('Deluge: Setting download path')
|
||||
if not any(delugeweb_auth):
|
||||
_get_auth()
|
||||
|
||||
try:
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
|
||||
post_data = json.dumps({"method": "core.set_torrent_move_completed",
|
||||
"params": [result['hash'], True],
|
||||
"id": 7})
|
||||
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
|
||||
verify=deluge_verify_cert, headers=headers)
|
||||
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY:
|
||||
move_to = headphones.CONFIG.DELUGE_DONE_DIRECTORY
|
||||
else:
|
||||
move_to = headphones.CONFIG.DOWNLOAD_TORRENT_DIR
|
||||
|
||||
if not os.path.exists(move_to):
|
||||
logger.debug('Deluge: %s directory doesn\'t exist, let\'s create it' % move_to)
|
||||
os.makedirs(move_to)
|
||||
post_data = json.dumps({"method": "core.set_torrent_move_completed_path",
|
||||
"params": [result['hash'], move_to],
|
||||
"id": 8})
|
||||
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
|
||||
verify=deluge_verify_cert, headers=headers)
|
||||
|
||||
return not json.loads(response.text)['error']
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error('Deluge: Setting torrent move-to directory failed: %s' % str(e))
|
||||
formatted_lines = traceback.format_exc().splitlines()
|
||||
logger.error('; '.join(formatted_lines))
|
||||
return None
|
||||
|
||||
|
||||
def setTorrentPause(result):
|
||||
logger.debug('Deluge: Pausing torrent')
|
||||
if not any(delugeweb_auth):
|
||||
_get_auth()
|
||||
|
||||
try:
|
||||
if headphones.CONFIG.DELUGE_PAUSED:
|
||||
post_data = json.dumps({"method": "core.pause_torrent",
|
||||
"params": [[result['hash']]],
|
||||
"id": 9})
|
||||
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
|
||||
verify=deluge_verify_cert, headers=headers)
|
||||
|
||||
return not json.loads(response.text)['error']
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error('Deluge: Setting torrent paused failed: %s' % str(e))
|
||||
formatted_lines = traceback.format_exc().splitlines()
|
||||
logger.error('; '.join(formatted_lines))
|
||||
return None
|
||||
|
||||
@@ -184,7 +184,7 @@ def bytes_to_mb(bytes):
|
||||
|
||||
|
||||
def mb_to_bytes(mb_str):
|
||||
result = re.search('^(\d+(?:\.\d+)?)\s?(?:mb)?', mb_str, flags=re.I)
|
||||
result = re.search(r"^(\d+(?:\.\d+)?)\s?(?:mb)?", mb_str, flags=re.I)
|
||||
if result:
|
||||
return int(float(result.group(1)) * 1048576)
|
||||
|
||||
@@ -253,9 +253,9 @@ def replace_all(text, dic):
|
||||
|
||||
def replace_illegal_chars(string, type="file"):
|
||||
if type == "file":
|
||||
string = re.sub('[\?"*:|<>/]', '_', string)
|
||||
string = re.sub(r"[\?\"*:|<>/]", "_", string)
|
||||
if type == "folder":
|
||||
string = re.sub('[:\?<>"|*]', '_', string)
|
||||
string = re.sub(r"[:\?<>\"|*]", "_", string)
|
||||
return string
|
||||
|
||||
|
||||
@@ -386,7 +386,7 @@ def clean_musicbrainz_name(s, return_as_string=True):
|
||||
|
||||
|
||||
def cleanTitle(title):
|
||||
title = re.sub('[\.\-\/\_]', ' ', title).lower()
|
||||
title = re.sub(r"[\.\-\/\_]", " ", title).lower()
|
||||
|
||||
# Strip out extra whitespace
|
||||
title = ' '.join(title.split())
|
||||
@@ -1050,3 +1050,10 @@ def have_pct_have_total(db_artist):
|
||||
have_pct = have_tracks / total_tracks if total_tracks else 0
|
||||
return (have_pct, total_tracks)
|
||||
|
||||
|
||||
def has_token(title, token):
|
||||
return bool(
|
||||
re.search(rf'(?:\W|^)+{token}(?:\W|$)+',
|
||||
title,
|
||||
re.IGNORECASE | re.UNICODE)
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .unittestcompat import TestCase
|
||||
from headphones.helpers import clean_name, is_valid_date, age
|
||||
from headphones.helpers import clean_name, is_valid_date, age, has_token
|
||||
|
||||
|
||||
class HelpersTest(TestCase):
|
||||
@@ -56,3 +56,18 @@ class HelpersTest(TestCase):
|
||||
]
|
||||
for input, expected, desc in test_cases:
|
||||
self.assertEqual(is_valid_date(input), expected, desc)
|
||||
|
||||
def test_has_token(self):
|
||||
"""helpers: has_token()"""
|
||||
self.assertEqual(
|
||||
has_token("a cat ran", "cat"),
|
||||
True,
|
||||
"return True if token is in string"
|
||||
)
|
||||
self.assertEqual(
|
||||
has_token("acatran", "cat"),
|
||||
False,
|
||||
"return False if token is part of another word"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from beets import config as beetsconfig
|
||||
from beets import logging as beetslogging
|
||||
from mediafile import MediaFile, FileTypeError, UnreadableFileError
|
||||
from beetsplug import lyrics as beetslyrics
|
||||
from headphones import notifiers, utorrent, transmission, deluge, qbittorrent
|
||||
from headphones import notifiers, utorrent, transmission, deluge, qbittorrent, soulseek
|
||||
from headphones import db, albumart, librarysync
|
||||
from headphones import logger, helpers, mb, music_encoder
|
||||
from headphones import metadata
|
||||
@@ -36,18 +36,44 @@ postprocessor_lock = threading.Lock()
|
||||
|
||||
|
||||
def checkFolder():
|
||||
logger.debug("Checking download folder for completed downloads (only snatched ones).")
|
||||
logger.info("Checking download folder for completed downloads (only snatched ones).")
|
||||
|
||||
with postprocessor_lock:
|
||||
myDB = db.DBConnection()
|
||||
snatched = myDB.select('SELECT * from snatched WHERE Status="Snatched"')
|
||||
|
||||
# If soulseek is used, this part will get the status from the soulseek api and return completed and errored albums
|
||||
completed_albums, errored_albums = set(), set()
|
||||
if any(album['Kind'] == 'soulseek' for album in snatched):
|
||||
completed_albums, errored_albums = soulseek.download_completed()
|
||||
|
||||
for album in snatched:
|
||||
if album['FolderName']:
|
||||
folder_name = album['FolderName']
|
||||
single = False
|
||||
if album['Kind'] == 'nzb':
|
||||
download_dir = headphones.CONFIG.DOWNLOAD_DIR
|
||||
if album['Kind'] == 'soulseek':
|
||||
if folder_name in errored_albums:
|
||||
# If the album had any tracks with errors in it, the whole download is considered faulty. Status will be reset to wanted.
|
||||
logger.info(f"Album with folder '{folder_name}' had errors during download. Setting status to 'Wanted'.")
|
||||
myDB.action('UPDATE albums SET Status="Wanted" WHERE AlbumID=? AND Status="Snatched"', (album['AlbumID'],))
|
||||
|
||||
# Folder will be removed from configured complete and Incomplete directory
|
||||
complete_path = os.path.join(headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR, folder_name)
|
||||
incomplete_path = os.path.join(headphones.CONFIG.SOULSEEK_INCOMPLETE_DOWNLOAD_DIR, folder_name)
|
||||
for path in [complete_path, incomplete_path]:
|
||||
try:
|
||||
shutil.rmtree(path)
|
||||
except Exception as e:
|
||||
pass
|
||||
continue
|
||||
elif folder_name in completed_albums:
|
||||
download_dir = headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR
|
||||
else:
|
||||
continue
|
||||
elif album['Kind'] == 'nzb':
|
||||
download_dir = headphones.CONFIG.DOWNLOAD_DIR
|
||||
elif album['Kind'] == 'bandcamp':
|
||||
download_dir = headphones.CONFIG.BANDCAMP_DIR
|
||||
else:
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY and headphones.CONFIG.TORRENT_DOWNLOADER == 3:
|
||||
download_dir = headphones.CONFIG.DELUGE_DONE_DIRECTORY
|
||||
@@ -289,7 +315,7 @@ def verify(albumid, albumpath, Kind=None, forced=False, keep_original_folder=Fal
|
||||
logger.debug('Metadata check failed. Verifying filenames...')
|
||||
for downloaded_track in downloaded_track_list:
|
||||
track_name = os.path.splitext(downloaded_track)[0]
|
||||
split_track_name = re.sub('[\.\-\_]', ' ', track_name).lower()
|
||||
split_track_name = re.sub(r'[\.\-\_]', r' ', track_name).lower()
|
||||
for track in tracks:
|
||||
|
||||
if not track['TrackTitle']:
|
||||
@@ -1170,8 +1196,14 @@ def forcePostProcess(dir=None, expand_subfolders=True, album_dir=None, keep_orig
|
||||
download_dirs.append(dir)
|
||||
if headphones.CONFIG.DOWNLOAD_DIR and not dir:
|
||||
download_dirs.append(headphones.CONFIG.DOWNLOAD_DIR)
|
||||
if headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR and not dir:
|
||||
download_dirs.append(headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR)
|
||||
if headphones.CONFIG.DOWNLOAD_TORRENT_DIR and not dir:
|
||||
download_dirs.append(headphones.CONFIG.DOWNLOAD_TORRENT_DIR)
|
||||
download_dirs.append(
|
||||
headphones.CONFIG.DOWNLOAD_TORRENT_DIR.encode(headphones.SYS_ENCODING, 'replace'))
|
||||
if headphones.CONFIG.BANDCAMP and not dir:
|
||||
download_dirs.append(
|
||||
headphones.CONFIG.BANDCAMP_DIR.encode(headphones.SYS_ENCODING, 'replace'))
|
||||
|
||||
# If DOWNLOAD_DIR and DOWNLOAD_TORRENT_DIR are the same, remove the duplicate to prevent us from trying to process the same folder twice.
|
||||
download_dirs = list(set(download_dirs))
|
||||
|
||||
@@ -42,19 +42,22 @@ class Rutracker(object):
|
||||
'login_password': headphones.CONFIG.RUTRACKER_PASSWORD,
|
||||
'login': b'\xc2\xf5\xee\xe4' # '%C2%F5%EE%E4'
|
||||
}
|
||||
headers = {
|
||||
'User-Agent' : 'Headphones'
|
||||
}
|
||||
|
||||
logger.info("Attempting to log in to rutracker...")
|
||||
|
||||
try:
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False)
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers)
|
||||
# try again
|
||||
if not self.has_bb_session_cookie(r):
|
||||
time.sleep(10)
|
||||
if headphones.CONFIG.RUTRACKER_COOKIE:
|
||||
logger.info("Attempting to log in using predefined cookie...")
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, cookies={'bb_session': headphones.CONFIG.RUTRACKER_COOKIE})
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers, cookies={'bb_session': headphones.CONFIG.RUTRACKER_COOKIE})
|
||||
else:
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False)
|
||||
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers)
|
||||
if self.has_bb_session_cookie(r):
|
||||
self.loggedin = True
|
||||
logger.info("Successfully logged in to rutracker")
|
||||
@@ -113,7 +116,10 @@ class Rutracker(object):
|
||||
Parse the search results and return valid torrent list
|
||||
"""
|
||||
try:
|
||||
headers = {'Referer': self.search_referer}
|
||||
headers = {
|
||||
'Referer': self.search_referer,
|
||||
'User-Agent' : 'Headphones'
|
||||
}
|
||||
r = self.session.get(url=searchurl, headers=headers, timeout=self.timeout)
|
||||
soup = BeautifulSoup(r.content, 'html.parser')
|
||||
|
||||
@@ -183,7 +189,10 @@ class Rutracker(object):
|
||||
downloadurl = 'https://rutracker.org/forum/dl.php?t=' + torrent_id
|
||||
cookie = {'bb_dl': torrent_id}
|
||||
try:
|
||||
headers = {'Referer': url}
|
||||
headers = {
|
||||
'Referer': url,
|
||||
'User-Agent' : 'Headphones'
|
||||
}
|
||||
r = self.session.post(url=downloadurl, cookies=cookie, headers=headers,
|
||||
timeout=self.timeout)
|
||||
return r.content
|
||||
|
||||
@@ -37,10 +37,29 @@ from unidecode import unidecode
|
||||
|
||||
import headphones
|
||||
from headphones.common import USER_AGENT
|
||||
from headphones.helpers import (
|
||||
bytes_to_mb,
|
||||
has_token,
|
||||
piratesize,
|
||||
replace_all,
|
||||
replace_illegal_chars,
|
||||
sab_replace_dots,
|
||||
sab_replace_spaces,
|
||||
sab_sanitize_foldername,
|
||||
split_string
|
||||
)
|
||||
from headphones.types import Result
|
||||
from headphones import logger, db, helpers, classes, sab, nzbget, request
|
||||
from headphones import utorrent, transmission, notifiers, rutracker, deluge, qbittorrent
|
||||
|
||||
from headphones import logger, db, classes, sab, nzbget, request
|
||||
from headphones import (
|
||||
bandcamp,
|
||||
deluge,
|
||||
notifiers,
|
||||
qbittorrent,
|
||||
rutracker,
|
||||
soulseek,
|
||||
transmission,
|
||||
utorrent,
|
||||
)
|
||||
|
||||
# Magnet to torrent services, for Black hole. Stolen from CouchPotato.
|
||||
TORRENT_TO_MAGNET_SERVICES = [
|
||||
@@ -137,7 +156,7 @@ def calculate_torrent_hash(link, data=None):
|
||||
"""
|
||||
|
||||
if link.startswith("magnet:"):
|
||||
torrent_hash = re.findall("urn:btih:([\w]{32,40})", link)[0]
|
||||
torrent_hash = re.findall(r"urn:btih:([\w]{32,40})", link)[0]
|
||||
if len(torrent_hash) == 32:
|
||||
torrent_hash = b16encode(b32decode(torrent_hash)).lower()
|
||||
elif data:
|
||||
@@ -261,6 +280,8 @@ def strptime_musicbrainz(date_str):
|
||||
|
||||
|
||||
def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
|
||||
|
||||
|
||||
NZB_PROVIDERS = (headphones.CONFIG.HEADPHONES_INDEXER or
|
||||
headphones.CONFIG.NEWZNAB or
|
||||
headphones.CONFIG.NZBSORG or
|
||||
@@ -284,25 +305,33 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
|
||||
[album['AlbumID']])[0][0]
|
||||
|
||||
if headphones.CONFIG.PREFER_TORRENTS == 0 and not choose_specific_download:
|
||||
|
||||
if NZB_PROVIDERS and NZB_DOWNLOADERS:
|
||||
results = searchNZB(album, new, losslessOnly, albumlength)
|
||||
|
||||
if not results and TORRENT_PROVIDERS:
|
||||
results = searchTorrent(album, new, losslessOnly, albumlength)
|
||||
|
||||
elif headphones.CONFIG.PREFER_TORRENTS == 1 and not choose_specific_download:
|
||||
if not results and headphones.CONFIG.BANDCAMP:
|
||||
results = searchBandcamp(album, new, albumlength)
|
||||
|
||||
elif headphones.CONFIG.PREFER_TORRENTS == 1 and not choose_specific_download:
|
||||
if TORRENT_PROVIDERS:
|
||||
results = searchTorrent(album, new, losslessOnly, albumlength)
|
||||
|
||||
if not results and NZB_PROVIDERS and NZB_DOWNLOADERS:
|
||||
results = searchNZB(album, new, losslessOnly, albumlength)
|
||||
|
||||
if not results and headphones.CONFIG.BANDCAMP:
|
||||
results = searchBandcamp(album, new, albumlength)
|
||||
|
||||
elif headphones.CONFIG.PREFER_TORRENTS == 2 and not choose_specific_download:
|
||||
results = searchSoulseek(album, new, losslessOnly, albumlength)
|
||||
|
||||
else:
|
||||
|
||||
nzb_results = None
|
||||
torrent_results = None
|
||||
bandcamp_results = None
|
||||
|
||||
if NZB_PROVIDERS and NZB_DOWNLOADERS:
|
||||
nzb_results = searchNZB(album, new, losslessOnly, albumlength, choose_specific_download)
|
||||
@@ -311,13 +340,16 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
|
||||
torrent_results = searchTorrent(album, new, losslessOnly, albumlength,
|
||||
choose_specific_download)
|
||||
|
||||
if headphones.CONFIG.BANDCAMP:
|
||||
bandcamp_results = searchBandcamp(album, new, albumlength)
|
||||
|
||||
if not nzb_results:
|
||||
nzb_results = []
|
||||
|
||||
if not torrent_results:
|
||||
torrent_results = []
|
||||
|
||||
results = nzb_results + torrent_results
|
||||
results = nzb_results + torrent_results + bandcamp_results
|
||||
|
||||
if choose_specific_download:
|
||||
return results
|
||||
@@ -338,6 +370,7 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
|
||||
(data, result) = preprocess(sorted_search_results)
|
||||
|
||||
if data and result:
|
||||
#print(f'going to send stuff to downloader. data: {data}, album: {album}')
|
||||
send_to_downloader(data, result, album)
|
||||
|
||||
|
||||
@@ -360,7 +393,7 @@ def more_filtering(results, album, albumlength, new):
|
||||
logger.debug('Target bitrate: %s kbps' % headphones.CONFIG.PREFERRED_BITRATE)
|
||||
if albumlength:
|
||||
targetsize = albumlength / 1000 * int(headphones.CONFIG.PREFERRED_BITRATE) * 128
|
||||
logger.info('Target size: %s' % helpers.bytes_to_mb(targetsize))
|
||||
logger.info('Target size: %s' % bytes_to_mb(targetsize))
|
||||
if headphones.CONFIG.PREFERRED_BITRATE_LOW_BUFFER:
|
||||
low_size_limit = targetsize * int(
|
||||
headphones.CONFIG.PREFERRED_BITRATE_LOW_BUFFER) / 100
|
||||
@@ -377,14 +410,14 @@ def more_filtering(results, album, albumlength, new):
|
||||
if low_size_limit and result.size < low_size_limit:
|
||||
logger.info(
|
||||
f"{result.title} from {result.provider} is too small for this album. "
|
||||
f"(Size: {result.size}, MinSize: {helpers.bytes_to_mb(low_size_limit)})"
|
||||
f"(Size: {result.size}, MinSize: {bytes_to_mb(low_size_limit)})"
|
||||
)
|
||||
continue
|
||||
|
||||
if high_size_limit and result.size > high_size_limit:
|
||||
logger.info(
|
||||
f"{result.title} from {result.provider} is too large for this album. "
|
||||
f"(Size: {result.size}, MaxSize: {helpers.bytes_to_mb(high_size_limit)})"
|
||||
f"(Size: {result.size}, MaxSize: {bytes_to_mb(high_size_limit)})"
|
||||
)
|
||||
# Keep lossless results if there are no good lossy matches
|
||||
if not (allow_lossless and 'flac' in result.title.lower()):
|
||||
@@ -424,7 +457,7 @@ def sort_search_results(resultlist, album, new, albumlength):
|
||||
|
||||
# Add a priority if it has any of the preferred words
|
||||
results_with_priority = []
|
||||
preferred_words = helpers.split_string(headphones.CONFIG.PREFERRED_WORDS)
|
||||
preferred_words = split_string(headphones.CONFIG.PREFERRED_WORDS)
|
||||
for result in resultlist:
|
||||
priority = 0
|
||||
for word in preferred_words:
|
||||
@@ -502,6 +535,10 @@ def get_year_from_release_date(release_date):
|
||||
return year
|
||||
|
||||
|
||||
def searchBandcamp(album, new=False, albumlength=None):
|
||||
return bandcamp.search(album)
|
||||
|
||||
|
||||
def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
choose_specific_download=False):
|
||||
reldate = album['ReleaseDate']
|
||||
@@ -521,8 +558,8 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
':': ''
|
||||
}
|
||||
|
||||
cleanalbum = unidecode(helpers.replace_all(album['AlbumTitle'], replacements)).strip()
|
||||
cleanartist = unidecode(helpers.replace_all(album['ArtistName'], replacements)).strip()
|
||||
cleanalbum = unidecode(replace_all(album['AlbumTitle'], replacements)).strip()
|
||||
cleanartist = unidecode(replace_all(album['ArtistName'], replacements)).strip()
|
||||
|
||||
# Use the provided search term if available, otherwise build a search term
|
||||
if album['SearchTerm']:
|
||||
@@ -542,8 +579,8 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
term = cleanartist + ' ' + cleanalbum
|
||||
|
||||
# Replace bad characters in the term
|
||||
term = re.sub('[\.\-\/]', ' ', term)
|
||||
artistterm = re.sub('[\.\-\/]', ' ', cleanartist)
|
||||
term = re.sub(r'[\.\-\/]', r' ', term)
|
||||
artistterm = re.sub(r'[\.\-\/]', r' ', cleanartist)
|
||||
|
||||
# If Preferred Bitrate and High Limit and Allow Lossless then get both lossy and lossless
|
||||
if headphones.CONFIG.PREFERRED_QUALITY == 2 and headphones.CONFIG.PREFERRED_BITRATE and headphones.CONFIG.PREFERRED_BITRATE_HIGH_BUFFER and headphones.CONFIG.PREFERRED_BITRATE_ALLOW_LOSSLESS:
|
||||
@@ -599,7 +636,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
size = int(item.links[1]['length'])
|
||||
|
||||
resultlist.append(Result(title, size, url, provider, 'nzb', True))
|
||||
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
|
||||
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
|
||||
except Exception as e:
|
||||
logger.error("An unknown error occurred trying to parse the feed: %s" % e)
|
||||
|
||||
@@ -670,7 +707,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
size = int(item.links[1]['length'])
|
||||
if all(word.lower() in title.lower() for word in term.split()):
|
||||
logger.info(
|
||||
'Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
|
||||
'Found %s. Size: %s' % (title, bytes_to_mb(size)))
|
||||
resultlist.append(Result(title, size, url, provider, 'nzb', True))
|
||||
else:
|
||||
logger.info('Skipping %s, not all search term words found' % title)
|
||||
@@ -720,7 +757,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
size = int(item.links[1]['length'])
|
||||
|
||||
resultlist.append(Result(title, size, url, provider, 'nzb', True))
|
||||
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
|
||||
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
|
||||
except Exception as e:
|
||||
logger.exception("Unhandled exception while parsing feed")
|
||||
|
||||
@@ -767,7 +804,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
size = int(item['sizebytes'])
|
||||
|
||||
resultlist.append(Result(title, size, url, provider, 'nzb', True))
|
||||
logger.info('Found %s. Size: %s', title, helpers.bytes_to_mb(size))
|
||||
logger.info('Found %s. Size: %s', title, bytes_to_mb(size))
|
||||
except Exception as e:
|
||||
logger.exception("Unhandled exception")
|
||||
|
||||
@@ -790,7 +827,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
|
||||
def send_to_downloader(data, result, album):
|
||||
logger.info(
|
||||
f"Found best result from {result.provider}: <a href=\"{result.url}\">"
|
||||
f"{result.title}</a> - {helpers.bytes_to_mb(result.size)}"
|
||||
f"{result.title}</a> - {bytes_to_mb(result.size)}"
|
||||
)
|
||||
# Get rid of any dodgy chars here so we can prevent sab from renaming our downloads
|
||||
kind = result.kind
|
||||
@@ -798,7 +835,7 @@ def send_to_downloader(data, result, album):
|
||||
torrentid = None
|
||||
|
||||
if kind == 'nzb':
|
||||
folder_name = helpers.sab_sanitize_foldername(result.title)
|
||||
folder_name = sab_sanitize_foldername(result.title)
|
||||
|
||||
if headphones.CONFIG.NZB_DOWNLOADER == 1:
|
||||
|
||||
@@ -820,9 +857,9 @@ def send_to_downloader(data, result, album):
|
||||
(replace_spaces, replace_dots) = sab.checkConfig()
|
||||
|
||||
if replace_dots:
|
||||
folder_name = helpers.sab_replace_dots(folder_name)
|
||||
folder_name = sab_replace_dots(folder_name)
|
||||
if replace_spaces:
|
||||
folder_name = helpers.sab_replace_spaces(folder_name)
|
||||
folder_name = sab_replace_spaces(folder_name)
|
||||
|
||||
else:
|
||||
nzb_name = folder_name + '.nzb'
|
||||
@@ -839,6 +876,15 @@ def send_to_downloader(data, result, album):
|
||||
except Exception as e:
|
||||
logger.error('Couldn\'t write NZB file: %s', e)
|
||||
return
|
||||
elif kind == 'bandcamp':
|
||||
folder_name = bandcamp.download(album, result)
|
||||
logger.info("Setting folder_name to: {}".format(folder_name))
|
||||
|
||||
|
||||
elif kind == 'soulseek':
|
||||
soulseek.download(user=result.user, filelist=result.files)
|
||||
folder_name = result.folder
|
||||
|
||||
else:
|
||||
folder_name = '%s - %s [%s]' % (
|
||||
unidecode(album['ArtistName']).replace('/', '_'),
|
||||
@@ -849,7 +895,7 @@ def send_to_downloader(data, result, album):
|
||||
if headphones.CONFIG.TORRENT_DOWNLOADER == 0:
|
||||
|
||||
# Get torrent name from .torrent, this is usually used by the torrent client as the folder name
|
||||
torrent_name = helpers.replace_illegal_chars(folder_name) + '.torrent'
|
||||
torrent_name = replace_illegal_chars(folder_name) + '.torrent'
|
||||
download_path = os.path.join(headphones.CONFIG.TORRENTBLACKHOLE_DIR, torrent_name)
|
||||
|
||||
if result.url.lower().startswith("magnet:"):
|
||||
@@ -954,10 +1000,6 @@ def send_to_downloader(data, result, album):
|
||||
logger.error("Error sending torrent to Deluge. Are you sure it's running? Maybe the torrent already exists?")
|
||||
return
|
||||
|
||||
# This pauses the torrent right after it is added
|
||||
if headphones.CONFIG.DELUGE_PAUSED:
|
||||
deluge.setTorrentPause({'hash': torrentid})
|
||||
|
||||
# Set Label
|
||||
if headphones.CONFIG.DELUGE_LABEL:
|
||||
deluge.setTorrentLabel({'hash': torrentid})
|
||||
@@ -967,10 +1009,6 @@ def send_to_downloader(data, result, album):
|
||||
if seed_ratio is not None:
|
||||
deluge.setSeedRatio({'hash': torrentid, 'ratio': seed_ratio})
|
||||
|
||||
# Set move-to directory
|
||||
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
|
||||
deluge.setTorrentPath({'hash': torrentid})
|
||||
|
||||
# Get folder name from Deluge, it's usually the torrent name
|
||||
folder_name = deluge.getTorrentFolder({'hash': torrentid})
|
||||
if folder_name:
|
||||
@@ -1156,7 +1194,7 @@ def send_to_downloader(data, result, album):
|
||||
|
||||
|
||||
def verifyresult(title, artistterm, term, lossless):
|
||||
title = re.sub('[\.\-\/\_]', ' ', title)
|
||||
title = re.sub(r'[\.\-\/\_]', r' ', title)
|
||||
|
||||
# if artistterm != 'Various Artists':
|
||||
#
|
||||
@@ -1188,16 +1226,16 @@ def verifyresult(title, artistterm, term, lossless):
|
||||
return False
|
||||
|
||||
if headphones.CONFIG.IGNORED_WORDS:
|
||||
for each_word in helpers.split_string(headphones.CONFIG.IGNORED_WORDS):
|
||||
for each_word in split_string(headphones.CONFIG.IGNORED_WORDS):
|
||||
if each_word.lower() in title.lower():
|
||||
logger.info("Removed '%s' from results because it contains ignored word: '%s'",
|
||||
title, each_word)
|
||||
return False
|
||||
|
||||
if headphones.CONFIG.REQUIRED_WORDS:
|
||||
for each_word in helpers.split_string(headphones.CONFIG.REQUIRED_WORDS):
|
||||
for each_word in split_string(headphones.CONFIG.REQUIRED_WORDS):
|
||||
if ' OR ' in each_word:
|
||||
or_words = helpers.split_string(each_word, 'OR')
|
||||
or_words = split_string(each_word, 'OR')
|
||||
if any(word.lower() in title.lower() for word in or_words):
|
||||
continue
|
||||
else:
|
||||
@@ -1219,23 +1257,23 @@ def verifyresult(title, artistterm, term, lossless):
|
||||
title, each_word)
|
||||
return False
|
||||
|
||||
tokens = re.split('\W', term, re.IGNORECASE | re.UNICODE)
|
||||
tokens = re.split(r'\W', term, re.IGNORECASE | re.UNICODE)
|
||||
|
||||
for token in tokens:
|
||||
|
||||
if not token:
|
||||
continue
|
||||
if token == 'Various' or token == 'Artists' or token == 'VA':
|
||||
continue
|
||||
if not re.search('(?:\W|^)+' + token + '(?:\W|$)+', title, re.IGNORECASE | re.UNICODE):
|
||||
if not has_token(title, token):
|
||||
cleantoken = ''.join(c for c in token if c not in string.punctuation)
|
||||
if not not re.search('(?:\W|^)+' + cleantoken + '(?:\W|$)+', title,
|
||||
re.IGNORECASE | re.UNICODE):
|
||||
if not has_token(title, cleantoken):
|
||||
dic = {'!': 'i', '$': 's'}
|
||||
dumbtoken = helpers.replace_all(token, dic)
|
||||
if not not re.search('(?:\W|^)+' + dumbtoken + '(?:\W|$)+', title,
|
||||
re.IGNORECASE | re.UNICODE):
|
||||
logger.info("Removed from results: %s (missing tokens: %s and %s)", title,
|
||||
token, cleantoken)
|
||||
dumbtoken = replace_all(token, dic)
|
||||
if not has_token(title, dumbtoken):
|
||||
logger.info(
|
||||
"Removed from results: %s (missing tokens: [%s, %s, %s])",
|
||||
title, token, cleantoken, dumbtoken)
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -1264,9 +1302,9 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
'*': ''
|
||||
}
|
||||
|
||||
semi_cleanalbum = helpers.replace_all(album['AlbumTitle'], replacements)
|
||||
semi_cleanalbum = replace_all(album['AlbumTitle'], replacements)
|
||||
cleanalbum = unidecode(semi_cleanalbum)
|
||||
semi_cleanartist = helpers.replace_all(album['ArtistName'], replacements)
|
||||
semi_cleanartist = replace_all(album['ArtistName'], replacements)
|
||||
cleanartist = unidecode(semi_cleanartist)
|
||||
|
||||
# Use provided term if available, otherwise build our own (this code needs to be cleaned up since a lot
|
||||
@@ -1293,12 +1331,12 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
else:
|
||||
usersearchterm = ''
|
||||
|
||||
semi_clean_artist_term = re.sub('[\.\-\/]', ' ', semi_cleanartist)
|
||||
semi_clean_album_term = re.sub('[\.\-\/]', ' ', semi_cleanalbum)
|
||||
semi_clean_artist_term = re.sub(r'[\.\-\/]', r' ', semi_cleanartist)
|
||||
semi_clean_album_term = re.sub(r'[\.\-\/]', r' ', semi_cleanalbum)
|
||||
# Replace bad characters in the term
|
||||
term = re.sub('[\.\-\/]', ' ', term)
|
||||
artistterm = re.sub('[\.\-\/]', ' ', cleanartist)
|
||||
albumterm = re.sub('[\.\-\/]', ' ', cleanalbum)
|
||||
term = re.sub(r'[\.\-\/]', r' ', term)
|
||||
artistterm = re.sub(r'[\.\-\/]', r' ', cleanartist)
|
||||
albumterm = re.sub(r'[\.\-\/]', r' ', cleanalbum)
|
||||
|
||||
# If Preferred Bitrate and High Limit and Allow Lossless then get both lossy and lossless
|
||||
if headphones.CONFIG.PREFERRED_QUALITY == 2 and headphones.CONFIG.PREFERRED_BITRATE and headphones.CONFIG.PREFERRED_BITRATE_HIGH_BUFFER and headphones.CONFIG.PREFERRED_BITRATE_ALLOW_LOSSLESS:
|
||||
@@ -1401,7 +1439,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
|
||||
if all(word.lower() in title.lower() for word in term.split()):
|
||||
if size < maxsize and minimumseeders < seeders:
|
||||
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
|
||||
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
|
||||
resultlist.append(Result(title, size, url, provider, 'torrent', True))
|
||||
else:
|
||||
logger.info(
|
||||
@@ -1477,7 +1515,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
size = int(desc_match.group(1))
|
||||
url = item.link
|
||||
resultlist.append(Result(title, size, url, provider, 'torrent', True))
|
||||
logger.info('Found %s. Size: %s', title, helpers.bytes_to_mb(size))
|
||||
logger.info('Found %s. Size: %s', title, bytes_to_mb(size))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"An error occurred while trying to parse the response from Waffles.ch: %s",
|
||||
@@ -1761,7 +1799,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
# Pirate Bay
|
||||
if headphones.CONFIG.PIRATEBAY:
|
||||
provider = "The Pirate Bay"
|
||||
tpb_term = term.replace("!", "").replace("'", " ")
|
||||
tpb_term = term.replace("!", "").replace("'", " ").replace(" ", "%20")
|
||||
|
||||
# Use proxy if specified
|
||||
if headphones.CONFIG.PIRATEBAY_PROXY_URL:
|
||||
@@ -1793,6 +1831,8 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
# Process content
|
||||
if data:
|
||||
rows = data.select('table tbody tr')
|
||||
if not rows:
|
||||
rows = data.select('table tr')
|
||||
|
||||
if not rows:
|
||||
logger.info("No results found from The Pirate Bay using term: %s" % tpb_term)
|
||||
@@ -1820,7 +1860,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
|
||||
formatted_size = re.search('Size (.*),', str(item)).group(1).replace(
|
||||
'\xa0', ' ')
|
||||
size = helpers.piratesize(formatted_size)
|
||||
size = piratesize(formatted_size)
|
||||
|
||||
if size < maxsize and minimumseeders < seeds and url is not None:
|
||||
match = True
|
||||
@@ -1874,7 +1914,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
"href"] # Magnet link. The actual download link is not based on the URL
|
||||
|
||||
formatted_size = item.select("td.size-row")[0].text
|
||||
size = helpers.piratesize(formatted_size)
|
||||
size = piratesize(formatted_size)
|
||||
|
||||
if size < maxsize and minimumseeders < seeds and url is not None:
|
||||
match = True
|
||||
@@ -1901,22 +1941,49 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
|
||||
return results
|
||||
|
||||
|
||||
def searchSoulseek(album, new=False, losslessOnly=False, albumlength=None):
|
||||
# Not using some of the input stuff for now or ever
|
||||
replacements = {
|
||||
'...': '',
|
||||
' & ': ' ',
|
||||
' = ': ' ',
|
||||
'?': '',
|
||||
'$': '',
|
||||
' + ': ' ',
|
||||
'"': '',
|
||||
',': '',
|
||||
'*': '',
|
||||
'.': '',
|
||||
':': ''
|
||||
}
|
||||
|
||||
num_tracks = get_album_track_count(album['AlbumID'])
|
||||
year = get_year_from_release_date(album['ReleaseDate'])
|
||||
cleanalbum = unidecode(helpers.replace_all(album['AlbumTitle'], replacements)).strip()
|
||||
cleanartist = unidecode(helpers.replace_all(album['ArtistName'], replacements)).strip()
|
||||
|
||||
results = soulseek.search(artist=cleanartist, album=cleanalbum, year=year, losslessOnly=losslessOnly, num_tracks=num_tracks)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_album_track_count(album_id):
|
||||
# Not sure if this should be considered a helper function.
|
||||
myDB = db.DBConnection()
|
||||
track_count = myDB.select('SELECT COUNT(*) as count FROM tracks WHERE AlbumID=?', [album_id])[0]['count']
|
||||
return track_count
|
||||
|
||||
|
||||
# THIS IS KIND OF A MESS AND PROBABLY NEEDS TO BE CLEANED UP
|
||||
|
||||
|
||||
def preprocess(resultlist):
|
||||
for result in resultlist:
|
||||
headers = {'User-Agent': USER_AGENT}
|
||||
|
||||
if result.provider in ["The Pirate Bay", "Old Pirate Bay"]:
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 6.3; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) \
|
||||
Chrome/41.0.2243.2 Safari/537.36'
|
||||
}
|
||||
else:
|
||||
headers = {'User-Agent': USER_AGENT}
|
||||
|
||||
if result.kind == 'soulseek':
|
||||
return True, result
|
||||
|
||||
if result.kind == 'torrent':
|
||||
|
||||
# rutracker always needs the torrent data
|
||||
@@ -1962,12 +2029,24 @@ def preprocess(resultlist):
|
||||
return True, result
|
||||
|
||||
# Download the torrent file
|
||||
|
||||
if result.provider in ["The Pirate Bay", "Old Pirate Bay"]:
|
||||
headers = {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 6.3; Win64; x64) \
|
||||
AppleWebKit/537.36 (KHTML, like Gecko) \
|
||||
Chrome/41.0.2243.2 Safari/537.36'
|
||||
}
|
||||
|
||||
return request.request_content(url=result.url, headers=headers), result
|
||||
|
||||
if result.kind == 'magnet':
|
||||
elif result.kind == 'magnet':
|
||||
magnet_link = result.url
|
||||
return "d10:magnet-uri%d:%se" % (len(magnet_link), magnet_link), result
|
||||
|
||||
elif result.kind == 'bandcamp':
|
||||
return True, result
|
||||
|
||||
else:
|
||||
if result.provider == 'headphones':
|
||||
return request.request_content(
|
||||
|
||||
185
headphones/soulseek.py
Normal file
185
headphones/soulseek.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from collections import defaultdict, namedtuple
|
||||
import os
|
||||
import time
|
||||
import slskd_api
|
||||
import headphones
|
||||
from headphones import logger
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
Result = namedtuple('Result', ['title', 'size', 'user', 'provider', 'type', 'matches', 'bandwidth', 'hasFreeUploadSlot', 'queueLength', 'files', 'kind', 'url', 'folder'])
|
||||
|
||||
def initialize_soulseek_client():
|
||||
host = headphones.CONFIG.SOULSEEK_API_URL
|
||||
api_key = headphones.CONFIG.SOULSEEK_API_KEY
|
||||
return slskd_api.SlskdClient(host=host, api_key=api_key)
|
||||
|
||||
# Search logic, calling search and processing fucntions
|
||||
def search(artist, album, year, num_tracks, losslessOnly):
|
||||
client = initialize_soulseek_client()
|
||||
|
||||
# Stage 1: Search with artist, album, year, and num_tracks
|
||||
results = execute_search(client, artist, album, year, losslessOnly)
|
||||
processed_results = process_results(results, losslessOnly, num_tracks)
|
||||
if processed_results:
|
||||
return processed_results
|
||||
|
||||
# Stage 2: If Stage 1 fails, search with artist, album, and num_tracks (excluding year)
|
||||
logger.info("Soulseek search stage 1 did not meet criteria. Retrying without year...")
|
||||
results = execute_search(client, artist, album, None, losslessOnly)
|
||||
processed_results = process_results(results, losslessOnly, num_tracks)
|
||||
if processed_results:
|
||||
return processed_results
|
||||
|
||||
# Stage 3: Final attempt, search only with artist and album
|
||||
logger.info("Soulseek search stage 2 did not meet criteria. Final attempt with only artist and album.")
|
||||
results = execute_search(client, artist, album, None, losslessOnly)
|
||||
processed_results = process_results(results, losslessOnly, num_tracks, ignore_track_count=True)
|
||||
|
||||
return processed_results
|
||||
|
||||
def execute_search(client, artist, album, year, losslessOnly):
|
||||
search_text = f"{artist} {album}"
|
||||
if year:
|
||||
search_text += f" {year}"
|
||||
if losslessOnly:
|
||||
search_text += ".flac"
|
||||
|
||||
# Actual search
|
||||
search_response = client.searches.search_text(searchText=search_text, filterResponses=True)
|
||||
search_id = search_response.get('id')
|
||||
|
||||
# Wait for search completion and return response
|
||||
while not client.searches.state(id=search_id).get('isComplete'):
|
||||
time.sleep(2)
|
||||
|
||||
return client.searches.search_responses(id=search_id)
|
||||
|
||||
# Processing the search result passed
|
||||
def process_results(results, losslessOnly, num_tracks, ignore_track_count=False):
|
||||
valid_extensions = {'.flac'} if losslessOnly else {'.mp3', '.flac'}
|
||||
albums = defaultdict(lambda: {'files': [], 'user': None, 'hasFreeUploadSlot': None, 'queueLength': None, 'uploadSpeed': None})
|
||||
|
||||
# Extract info from the api response and combine files at album level
|
||||
for result in results:
|
||||
user = result.get('username')
|
||||
hasFreeUploadSlot = result.get('hasFreeUploadSlot')
|
||||
queueLength = result.get('queueLength')
|
||||
uploadSpeed = result.get('uploadSpeed')
|
||||
|
||||
# Only handle .mp3 and .flac
|
||||
for file in result.get('files', []):
|
||||
filename = file.get('filename')
|
||||
file_extension = os.path.splitext(filename)[1].lower()
|
||||
if file_extension in valid_extensions:
|
||||
album_directory = os.path.dirname(filename)
|
||||
albums[album_directory]['files'].append(file)
|
||||
|
||||
# Update metadata only once per album_directory
|
||||
if albums[album_directory]['user'] is None:
|
||||
albums[album_directory].update({
|
||||
'user': user,
|
||||
'hasFreeUploadSlot': hasFreeUploadSlot,
|
||||
'queueLength': queueLength,
|
||||
'uploadSpeed': uploadSpeed,
|
||||
})
|
||||
|
||||
# Filter albums based on num_tracks, add bunch of useful info to the compiled album
|
||||
final_results = []
|
||||
for directory, album_data in albums.items():
|
||||
if ignore_track_count or len(album_data['files']) == num_tracks:
|
||||
album_title = os.path.basename(directory)
|
||||
total_size = sum(file.get('size', 0) for file in album_data['files'])
|
||||
final_results.append(Result(
|
||||
title=album_title,
|
||||
size=int(total_size),
|
||||
user=album_data['user'],
|
||||
provider="soulseek",
|
||||
type="soulseek",
|
||||
matches=True,
|
||||
bandwidth=album_data['uploadSpeed'],
|
||||
hasFreeUploadSlot=album_data['hasFreeUploadSlot'],
|
||||
queueLength=album_data['queueLength'],
|
||||
files=album_data['files'],
|
||||
kind='soulseek',
|
||||
url='http://thisisnot.needed', # URL is needed in other parts of the program.
|
||||
folder=os.path.basename(directory)
|
||||
))
|
||||
|
||||
return final_results
|
||||
|
||||
|
||||
def download(user, filelist):
|
||||
client = initialize_soulseek_client()
|
||||
client.transfers.enqueue(username=user, files=filelist)
|
||||
|
||||
|
||||
def download_completed():
|
||||
client = initialize_soulseek_client()
|
||||
all_downloads = client.transfers.get_all_downloads(includeRemoved=False)
|
||||
album_completion_tracker = {} # Tracks completion state of each album's songs
|
||||
album_errored_tracker = {} # Tracks albums with errored downloads
|
||||
|
||||
# Anything older than 24 hours will be canceled
|
||||
cutoff_time = datetime.now() - timedelta(hours=24)
|
||||
|
||||
# Identify errored and completed albums
|
||||
for download in all_downloads:
|
||||
directories = download.get('directories', [])
|
||||
for directory in directories:
|
||||
album_part = directory.get('directory', '').split('\\')[-1]
|
||||
files = directory.get('files', [])
|
||||
for file_data in files:
|
||||
state = file_data.get('state', '')
|
||||
requested_at_str = file_data.get('requestedAt', '1900-01-01 00:00:00')
|
||||
requested_at = parse_datetime(requested_at_str)
|
||||
|
||||
# Initialize or update album entry in trackers
|
||||
if album_part not in album_completion_tracker:
|
||||
album_completion_tracker[album_part] = {'total': 0, 'completed': 0, 'errored': 0}
|
||||
if album_part not in album_errored_tracker:
|
||||
album_errored_tracker[album_part] = False
|
||||
|
||||
album_completion_tracker[album_part]['total'] += 1
|
||||
|
||||
if 'Completed, Succeeded' in state:
|
||||
album_completion_tracker[album_part]['completed'] += 1
|
||||
elif 'Completed, Errored' in state or requested_at < cutoff_time:
|
||||
album_completion_tracker[album_part]['errored'] += 1
|
||||
album_errored_tracker[album_part] = True # Mark album as having errored downloads
|
||||
|
||||
# Identify errored albums
|
||||
errored_albums = {album for album, errored in album_errored_tracker.items() if errored}
|
||||
|
||||
# Cancel downloads for errored albums
|
||||
for download in all_downloads:
|
||||
directories = download.get('directories', [])
|
||||
for directory in directories:
|
||||
album_part = directory.get('directory', '').split('\\')[-1]
|
||||
files = directory.get('files', [])
|
||||
for file_data in files:
|
||||
if album_part in errored_albums:
|
||||
# Extract 'id' and 'username' for each file to cancel the download
|
||||
file_id = file_data.get('id', '')
|
||||
username = file_data.get('username', '')
|
||||
success = client.transfers.cancel_download(username, file_id)
|
||||
if not success:
|
||||
print(f"Failed to cancel download for file ID: {file_id}")
|
||||
|
||||
# Clear completed/canceled/errored stuff from client downloads
|
||||
try:
|
||||
client.transfers.remove_completed_downloads()
|
||||
except Exception as e:
|
||||
print(f"Failed to remove completed downloads: {e}")
|
||||
|
||||
# Identify completed albums
|
||||
completed_albums = {album for album, counts in album_completion_tracker.items() if counts['total'] == counts['completed']}
|
||||
|
||||
# Return both completed and errored albums
|
||||
return completed_albums, errored_albums
|
||||
|
||||
|
||||
def parse_datetime(datetime_string):
|
||||
# Parse the datetime api response
|
||||
if '.' in datetime_string:
|
||||
datetime_string = datetime_string[:datetime_string.index('.')+7]
|
||||
return datetime.strptime(datetime_string, '%Y-%m-%dT%H:%M:%S.%f')
|
||||
@@ -1183,6 +1183,7 @@ class WebInterface(object):
|
||||
"deluge_password": headphones.CONFIG.DELUGE_PASSWORD,
|
||||
"deluge_label": headphones.CONFIG.DELUGE_LABEL,
|
||||
"deluge_done_directory": headphones.CONFIG.DELUGE_DONE_DIRECTORY,
|
||||
"deluge_download_directory": headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY,
|
||||
"deluge_paused": checked(headphones.CONFIG.DELUGE_PAUSED),
|
||||
"utorrent_host": headphones.CONFIG.UTORRENT_HOST,
|
||||
"utorrent_username": headphones.CONFIG.UTORRENT_USERNAME,
|
||||
@@ -1197,6 +1198,8 @@ class WebInterface(object):
|
||||
"torrent_downloader_deluge": radio(headphones.CONFIG.TORRENT_DOWNLOADER, 3),
|
||||
"torrent_downloader_qbittorrent": radio(headphones.CONFIG.TORRENT_DOWNLOADER, 4),
|
||||
"download_dir": headphones.CONFIG.DOWNLOAD_DIR,
|
||||
"soulseek_download_dir": headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR,
|
||||
"soulseek_incomplete_download_dir": headphones.CONFIG.SOULSEEK_INCOMPLETE_DOWNLOAD_DIR,
|
||||
"use_blackhole": checked(headphones.CONFIG.BLACKHOLE),
|
||||
"blackhole_dir": headphones.CONFIG.BLACKHOLE_DIR,
|
||||
"usenet_retention": headphones.CONFIG.USENET_RETENTION,
|
||||
@@ -1296,6 +1299,7 @@ class WebInterface(object):
|
||||
"prefer_torrents_0": radio(headphones.CONFIG.PREFER_TORRENTS, 0),
|
||||
"prefer_torrents_1": radio(headphones.CONFIG.PREFER_TORRENTS, 1),
|
||||
"prefer_torrents_2": radio(headphones.CONFIG.PREFER_TORRENTS, 2),
|
||||
"prefer_torrents_3": radio(headphones.CONFIG.PREFER_TORRENTS, 3),
|
||||
"magnet_links_0": radio(headphones.CONFIG.MAGNET_LINKS, 0),
|
||||
"magnet_links_1": radio(headphones.CONFIG.MAGNET_LINKS, 1),
|
||||
"magnet_links_2": radio(headphones.CONFIG.MAGNET_LINKS, 2),
|
||||
@@ -1413,7 +1417,12 @@ class WebInterface(object):
|
||||
"join_enabled": checked(headphones.CONFIG.JOIN_ENABLED),
|
||||
"join_onsnatch": checked(headphones.CONFIG.JOIN_ONSNATCH),
|
||||
"join_apikey": headphones.CONFIG.JOIN_APIKEY,
|
||||
"join_deviceid": headphones.CONFIG.JOIN_DEVICEID
|
||||
"join_deviceid": headphones.CONFIG.JOIN_DEVICEID,
|
||||
"use_bandcamp": checked(headphones.CONFIG.BANDCAMP),
|
||||
"bandcamp_dir": headphones.CONFIG.BANDCAMP_DIR,
|
||||
'soulseek_api_url': headphones.CONFIG.SOULSEEK_API_URL,
|
||||
'soulseek_api_key': headphones.CONFIG.SOULSEEK_API_KEY,
|
||||
'use_soulseek': checked(headphones.CONFIG.SOULSEEK)
|
||||
}
|
||||
|
||||
for k, v in config.items():
|
||||
@@ -1482,7 +1491,7 @@ class WebInterface(object):
|
||||
"songkick_enabled", "songkick_filter_enabled",
|
||||
"mpc_enabled", "email_enabled", "email_ssl", "email_tls", "email_onsnatch",
|
||||
"customauth", "idtag", "deluge_paused",
|
||||
"join_enabled", "join_onsnatch"
|
||||
"join_enabled", "join_onsnatch", "use_bandcamp"
|
||||
]
|
||||
for checked_config in checked_configs:
|
||||
if checked_config not in kwargs:
|
||||
|
||||
@@ -57,9 +57,11 @@ These API's are described in the `CherryPy specification
|
||||
"""
|
||||
|
||||
try:
|
||||
import pkg_resources
|
||||
import importlib.metadata as importlib_metadata
|
||||
except ImportError:
|
||||
pass
|
||||
# fall back for python <= 3.7
|
||||
# This try/except can be removed with py <= 3.7 support
|
||||
import importlib_metadata
|
||||
|
||||
from threading import local as _local
|
||||
|
||||
@@ -109,7 +111,7 @@ tree = _cptree.Tree()
|
||||
|
||||
|
||||
try:
|
||||
__version__ = pkg_resources.require('cherrypy')[0].version
|
||||
__version__ = importlib_metadata.version('cherrypy')
|
||||
except Exception:
|
||||
__version__ = 'unknown'
|
||||
|
||||
@@ -181,24 +183,28 @@ def quickstart(root=None, script_name='', config=None):
|
||||
class _Serving(_local):
|
||||
"""An interface for registering request and response objects.
|
||||
|
||||
Rather than have a separate "thread local" object for the request and
|
||||
the response, this class works as a single threadlocal container for
|
||||
both objects (and any others which developers wish to define). In this
|
||||
way, we can easily dump those objects when we stop/start a new HTTP
|
||||
conversation, yet still refer to them as module-level globals in a
|
||||
thread-safe way.
|
||||
Rather than have a separate "thread local" object for the request
|
||||
and the response, this class works as a single threadlocal container
|
||||
for both objects (and any others which developers wish to define).
|
||||
In this way, we can easily dump those objects when we stop/start a
|
||||
new HTTP conversation, yet still refer to them as module-level
|
||||
globals in a thread-safe way.
|
||||
"""
|
||||
|
||||
request = _cprequest.Request(_httputil.Host('127.0.0.1', 80),
|
||||
_httputil.Host('127.0.0.1', 1111))
|
||||
"""The request object for the current thread.
|
||||
|
||||
In the main thread, and any threads which are not receiving HTTP
|
||||
requests, this is None.
|
||||
"""
|
||||
The request object for the current thread. In the main thread,
|
||||
and any threads which are not receiving HTTP requests, this is None."""
|
||||
|
||||
response = _cprequest.Response()
|
||||
"""The response object for the current thread.
|
||||
|
||||
In the main thread, and any threads which are not receiving HTTP
|
||||
requests, this is None.
|
||||
"""
|
||||
The response object for the current thread. In the main thread,
|
||||
and any threads which are not receiving HTTP requests, this is None."""
|
||||
|
||||
def load(self, request, response):
|
||||
self.request = request
|
||||
@@ -316,8 +322,8 @@ class _GlobalLogManager(_cplogging.LogManager):
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Log the given message to the app.log or global log.
|
||||
|
||||
Log the given message to the app.log or global
|
||||
log as appropriate.
|
||||
Log the given message to the app.log or global log as
|
||||
appropriate.
|
||||
"""
|
||||
# Do NOT use try/except here. See
|
||||
# https://github.com/cherrypy/cherrypy/issues/945
|
||||
@@ -330,8 +336,8 @@ class _GlobalLogManager(_cplogging.LogManager):
|
||||
def access(self):
|
||||
"""Log an access message to the app.log or global log.
|
||||
|
||||
Log the given message to the app.log or global
|
||||
log as appropriate.
|
||||
Log the given message to the app.log or global log as
|
||||
appropriate.
|
||||
"""
|
||||
try:
|
||||
return request.app.log.access()
|
||||
|
||||
@@ -313,7 +313,10 @@ class Checker(object):
|
||||
|
||||
# -------------------- Specific config warnings -------------------- #
|
||||
def check_localhost(self):
|
||||
"""Warn if any socket_host is 'localhost'. See #711."""
|
||||
"""Warn if any socket_host is 'localhost'.
|
||||
|
||||
See #711.
|
||||
"""
|
||||
for k, v in cherrypy.config.items():
|
||||
if k == 'server.socket_host' and v == 'localhost':
|
||||
warnings.warn("The use of 'localhost' as a socket host can "
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
Configuration system for CherryPy.
|
||||
"""Configuration system for CherryPy.
|
||||
|
||||
Configuration in CherryPy is implemented via dictionaries. Keys are strings
|
||||
which name the mapped value, which may be of any type.
|
||||
@@ -132,8 +131,8 @@ def _if_filename_register_autoreload(ob):
|
||||
def merge(base, other):
|
||||
"""Merge one app config (from a dict, file, or filename) into another.
|
||||
|
||||
If the given config is a filename, it will be appended to
|
||||
the list of files to monitor for "autoreload" changes.
|
||||
If the given config is a filename, it will be appended to the list
|
||||
of files to monitor for "autoreload" changes.
|
||||
"""
|
||||
_if_filename_register_autoreload(other)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""CherryPy dispatchers.
|
||||
|
||||
A 'dispatcher' is the object which looks up the 'page handler' callable
|
||||
and collects config for the current request based on the path_info, other
|
||||
request attributes, and the application architecture. The core calls the
|
||||
dispatcher as early as possible, passing it a 'path_info' argument.
|
||||
and collects config for the current request based on the path_info,
|
||||
other request attributes, and the application architecture. The core
|
||||
calls the dispatcher as early as possible, passing it a 'path_info'
|
||||
argument.
|
||||
|
||||
The default dispatcher discovers the page handler by matching path_info
|
||||
to a hierarchical arrangement of objects, starting at request.app.root.
|
||||
@@ -21,7 +22,6 @@ import cherrypy
|
||||
|
||||
|
||||
class PageHandler(object):
|
||||
|
||||
"""Callable which sets response.body."""
|
||||
|
||||
def __init__(self, callable, *args, **kwargs):
|
||||
@@ -64,8 +64,7 @@ class PageHandler(object):
|
||||
|
||||
|
||||
def test_callable_spec(callable, callable_args, callable_kwargs):
|
||||
"""
|
||||
Inspect callable and test to see if the given args are suitable for it.
|
||||
"""Inspect callable and test to see if the given args are suitable for it.
|
||||
|
||||
When an error occurs during the handler's invoking stage there are 2
|
||||
erroneous cases:
|
||||
@@ -252,16 +251,16 @@ else:
|
||||
|
||||
|
||||
class Dispatcher(object):
|
||||
|
||||
"""CherryPy Dispatcher which walks a tree of objects to find a handler.
|
||||
|
||||
The tree is rooted at cherrypy.request.app.root, and each hierarchical
|
||||
component in the path_info argument is matched to a corresponding nested
|
||||
attribute of the root object. Matching handlers must have an 'exposed'
|
||||
attribute which evaluates to True. The special method name "index"
|
||||
matches a URI which ends in a slash ("/"). The special method name
|
||||
"default" may match a portion of the path_info (but only when no longer
|
||||
substring of the path_info matches some other object).
|
||||
The tree is rooted at cherrypy.request.app.root, and each
|
||||
hierarchical component in the path_info argument is matched to a
|
||||
corresponding nested attribute of the root object. Matching handlers
|
||||
must have an 'exposed' attribute which evaluates to True. The
|
||||
special method name "index" matches a URI which ends in a slash
|
||||
("/"). The special method name "default" may match a portion of the
|
||||
path_info (but only when no longer substring of the path_info
|
||||
matches some other object).
|
||||
|
||||
This is the default, built-in dispatcher for CherryPy.
|
||||
"""
|
||||
@@ -306,9 +305,9 @@ class Dispatcher(object):
|
||||
|
||||
The second object returned will be a list of names which are
|
||||
'virtual path' components: parts of the URL which are dynamic,
|
||||
and were not used when looking up the handler.
|
||||
These virtual path components are passed to the handler as
|
||||
positional arguments.
|
||||
and were not used when looking up the handler. These virtual
|
||||
path components are passed to the handler as positional
|
||||
arguments.
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
app = request.app
|
||||
@@ -448,13 +447,11 @@ class Dispatcher(object):
|
||||
|
||||
|
||||
class MethodDispatcher(Dispatcher):
|
||||
|
||||
"""Additional dispatch based on cherrypy.request.method.upper().
|
||||
|
||||
Methods named GET, POST, etc will be called on an exposed class.
|
||||
The method names must be all caps; the appropriate Allow header
|
||||
will be output showing all capitalized method names as allowable
|
||||
HTTP verbs.
|
||||
Methods named GET, POST, etc will be called on an exposed class. The
|
||||
method names must be all caps; the appropriate Allow header will be
|
||||
output showing all capitalized method names as allowable HTTP verbs.
|
||||
|
||||
Note that the containing class must be exposed, not the methods.
|
||||
"""
|
||||
@@ -492,16 +489,14 @@ class MethodDispatcher(Dispatcher):
|
||||
|
||||
|
||||
class RoutesDispatcher(object):
|
||||
|
||||
"""A Routes based dispatcher for CherryPy."""
|
||||
|
||||
def __init__(self, full_result=False, **mapper_options):
|
||||
"""
|
||||
Routes dispatcher
|
||||
"""Routes dispatcher.
|
||||
|
||||
Set full_result to True if you wish the controller
|
||||
and the action to be passed on to the page handler
|
||||
parameters. By default they won't be.
|
||||
Set full_result to True if you wish the controller and the
|
||||
action to be passed on to the page handler parameters. By
|
||||
default they won't be.
|
||||
"""
|
||||
import routes
|
||||
self.full_result = full_result
|
||||
@@ -617,8 +612,7 @@ def XMLRPCDispatcher(next_dispatcher=Dispatcher()):
|
||||
|
||||
def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True,
|
||||
**domains):
|
||||
"""
|
||||
Select a different handler based on the Host header.
|
||||
"""Select a different handler based on the Host header.
|
||||
|
||||
This can be useful when running multiple sites within one CP server.
|
||||
It allows several domains to point to different parts of a single
|
||||
|
||||
@@ -136,19 +136,17 @@ from cherrypy.lib import httputil as _httputil
|
||||
|
||||
|
||||
class CherryPyException(Exception):
|
||||
|
||||
"""A base class for CherryPy exceptions."""
|
||||
pass
|
||||
|
||||
|
||||
class InternalRedirect(CherryPyException):
|
||||
|
||||
"""Exception raised to switch to the handler for a different URL.
|
||||
|
||||
This exception will redirect processing to another path within the site
|
||||
(without informing the client). Provide the new path as an argument when
|
||||
raising the exception. Provide any params in the querystring for the new
|
||||
URL.
|
||||
This exception will redirect processing to another path within the
|
||||
site (without informing the client). Provide the new path as an
|
||||
argument when raising the exception. Provide any params in the
|
||||
querystring for the new URL.
|
||||
"""
|
||||
|
||||
def __init__(self, path, query_string=''):
|
||||
@@ -173,7 +171,6 @@ class InternalRedirect(CherryPyException):
|
||||
|
||||
|
||||
class HTTPRedirect(CherryPyException):
|
||||
|
||||
"""Exception raised when the request should be redirected.
|
||||
|
||||
This exception will force a HTTP redirect to the URL or URL's you give it.
|
||||
@@ -202,7 +199,7 @@ class HTTPRedirect(CherryPyException):
|
||||
"""The list of URL's to emit."""
|
||||
|
||||
encoding = 'utf-8'
|
||||
"""The encoding when passed urls are not native strings"""
|
||||
"""The encoding when passed urls are not native strings."""
|
||||
|
||||
def __init__(self, urls, status=None, encoding=None):
|
||||
self.urls = abs_urls = [
|
||||
@@ -230,8 +227,7 @@ class HTTPRedirect(CherryPyException):
|
||||
|
||||
@classproperty
|
||||
def default_status(cls):
|
||||
"""
|
||||
The default redirect status for the request.
|
||||
"""The default redirect status for the request.
|
||||
|
||||
RFC 2616 indicates a 301 response code fits our goal; however,
|
||||
browser support for 301 is quite messy. Use 302/303 instead. See
|
||||
@@ -249,8 +245,9 @@ class HTTPRedirect(CherryPyException):
|
||||
"""Modify cherrypy.response status, headers, and body to represent
|
||||
self.
|
||||
|
||||
CherryPy uses this internally, but you can also use it to create an
|
||||
HTTPRedirect object and set its output without *raising* the exception.
|
||||
CherryPy uses this internally, but you can also use it to create
|
||||
an HTTPRedirect object and set its output without *raising* the
|
||||
exception.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
response.status = status = self.status
|
||||
@@ -339,7 +336,6 @@ def clean_headers(status):
|
||||
|
||||
|
||||
class HTTPError(CherryPyException):
|
||||
|
||||
"""Exception used to return an HTTP error code (4xx-5xx) to the client.
|
||||
|
||||
This exception can be used to automatically send a response using a
|
||||
@@ -358,7 +354,9 @@ class HTTPError(CherryPyException):
|
||||
"""
|
||||
|
||||
status = None
|
||||
"""The HTTP status code. May be of type int or str (with a Reason-Phrase).
|
||||
"""The HTTP status code.
|
||||
|
||||
May be of type int or str (with a Reason-Phrase).
|
||||
"""
|
||||
|
||||
code = None
|
||||
@@ -386,8 +384,9 @@ class HTTPError(CherryPyException):
|
||||
"""Modify cherrypy.response status, headers, and body to represent
|
||||
self.
|
||||
|
||||
CherryPy uses this internally, but you can also use it to create an
|
||||
HTTPError object and set its output without *raising* the exception.
|
||||
CherryPy uses this internally, but you can also use it to create
|
||||
an HTTPError object and set its output without *raising* the
|
||||
exception.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
|
||||
@@ -426,11 +425,10 @@ class HTTPError(CherryPyException):
|
||||
|
||||
|
||||
class NotFound(HTTPError):
|
||||
|
||||
"""Exception raised when a URL could not be mapped to any handler (404).
|
||||
|
||||
This is equivalent to raising
|
||||
:class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`.
|
||||
This is equivalent to raising :class:`HTTPError("404 Not Found")
|
||||
<cherrypy._cperror.HTTPError>`.
|
||||
"""
|
||||
|
||||
def __init__(self, path=None):
|
||||
@@ -477,8 +475,8 @@ _HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
|
||||
def get_error_page(status, **kwargs):
|
||||
"""Return an HTML page, containing a pretty error response.
|
||||
|
||||
status should be an int or a str.
|
||||
kwargs will be interpolated into the page template.
|
||||
status should be an int or a str. kwargs will be interpolated into
|
||||
the page template.
|
||||
"""
|
||||
try:
|
||||
code, reason, message = _httputil.valid_status(status)
|
||||
@@ -595,8 +593,8 @@ def bare_error(extrabody=None):
|
||||
"""Produce status, headers, body for a critical error.
|
||||
|
||||
Returns a triple without calling any other questionable functions,
|
||||
so it should be as error-free as possible. Call it from an HTTP server
|
||||
if you get errors outside of the request.
|
||||
so it should be as error-free as possible. Call it from an HTTP
|
||||
server if you get errors outside of the request.
|
||||
|
||||
If extrabody is None, a friendly but rather unhelpful error message
|
||||
is set in the body. If extrabody is a string, it will be appended
|
||||
|
||||
@@ -123,7 +123,6 @@ logfmt = logging.Formatter('%(message)s')
|
||||
|
||||
|
||||
class NullHandler(logging.Handler):
|
||||
|
||||
"""A no-op logging handler to silence the logging.lastResort handler."""
|
||||
|
||||
def handle(self, record):
|
||||
@@ -137,15 +136,16 @@ class NullHandler(logging.Handler):
|
||||
|
||||
|
||||
class LogManager(object):
|
||||
|
||||
"""An object to assist both simple and advanced logging.
|
||||
|
||||
``cherrypy.log`` is an instance of this class.
|
||||
"""
|
||||
|
||||
appid = None
|
||||
"""The id() of the Application object which owns this log manager. If this
|
||||
is a global log manager, appid is None."""
|
||||
"""The id() of the Application object which owns this log manager.
|
||||
|
||||
If this is a global log manager, appid is None.
|
||||
"""
|
||||
|
||||
error_log = None
|
||||
"""The actual :class:`logging.Logger` instance for error messages."""
|
||||
@@ -317,8 +317,8 @@ class LogManager(object):
|
||||
def screen(self):
|
||||
"""Turn stderr/stdout logging on or off.
|
||||
|
||||
If you set this to True, it'll add the appropriate StreamHandler for
|
||||
you. If you set it to False, it will remove the handler.
|
||||
If you set this to True, it'll add the appropriate StreamHandler
|
||||
for you. If you set it to False, it will remove the handler.
|
||||
"""
|
||||
h = self._get_builtin_handler
|
||||
has_h = h(self.error_log, 'screen') or h(self.access_log, 'screen')
|
||||
@@ -414,7 +414,6 @@ class LogManager(object):
|
||||
|
||||
|
||||
class WSGIErrorHandler(logging.Handler):
|
||||
|
||||
"A handler class which writes logging records to environ['wsgi.errors']."
|
||||
|
||||
def flush(self):
|
||||
@@ -452,6 +451,8 @@ class WSGIErrorHandler(logging.Handler):
|
||||
|
||||
class LazyRfc3339UtcTime(object):
|
||||
def __str__(self):
|
||||
"""Return now() in RFC3339 UTC Format."""
|
||||
now = datetime.datetime.now()
|
||||
return now.isoformat('T') + 'Z'
|
||||
"""Return datetime in RFC3339 UTC Format."""
|
||||
iso_formatted_now = datetime.datetime.now(
|
||||
datetime.timezone.utc,
|
||||
).isoformat('T')
|
||||
return f'{iso_formatted_now!s}Z'
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Native adapter for serving CherryPy via mod_python
|
||||
"""Native adapter for serving CherryPy via mod_python.
|
||||
|
||||
Basic usage:
|
||||
|
||||
|
||||
@@ -120,10 +120,10 @@ class NativeGateway(cheroot.server.Gateway):
|
||||
class CPHTTPServer(cheroot.server.HTTPServer):
|
||||
"""Wrapper for cheroot.server.HTTPServer.
|
||||
|
||||
cheroot has been designed to not reference CherryPy in any way,
|
||||
so that it can be used in other frameworks and applications.
|
||||
Therefore, we wrap it here, so we can apply some attributes
|
||||
from config -> cherrypy.server -> HTTPServer.
|
||||
cheroot has been designed to not reference CherryPy in any way, so
|
||||
that it can be used in other frameworks and applications. Therefore,
|
||||
we wrap it here, so we can apply some attributes from config ->
|
||||
cherrypy.server -> HTTPServer.
|
||||
"""
|
||||
|
||||
def __init__(self, server_adapter=cherrypy.server):
|
||||
|
||||
@@ -248,7 +248,10 @@ def process_multipart_form_data(entity):
|
||||
|
||||
|
||||
def _old_process_multipart(entity):
|
||||
"""The behavior of 3.2 and lower. Deprecated and will be changed in 3.3."""
|
||||
"""The behavior of 3.2 and lower.
|
||||
|
||||
Deprecated and will be changed in 3.3.
|
||||
"""
|
||||
process_multipart(entity)
|
||||
|
||||
params = entity.params
|
||||
@@ -277,7 +280,6 @@ def _old_process_multipart(entity):
|
||||
|
||||
# -------------------------------- Entities --------------------------------- #
|
||||
class Entity(object):
|
||||
|
||||
"""An HTTP request body, or MIME multipart body.
|
||||
|
||||
This class collects information about the HTTP request entity. When a
|
||||
@@ -346,13 +348,15 @@ class Entity(object):
|
||||
content_type = None
|
||||
"""The value of the Content-Type request header.
|
||||
|
||||
If the Entity is part of a multipart payload, this will be the Content-Type
|
||||
given in the MIME headers for this part.
|
||||
If the Entity is part of a multipart payload, this will be the
|
||||
Content-Type given in the MIME headers for this part.
|
||||
"""
|
||||
|
||||
default_content_type = 'application/x-www-form-urlencoded'
|
||||
"""This defines a default ``Content-Type`` to use if no Content-Type header
|
||||
is given. The empty string is used for RequestBody, which results in the
|
||||
is given.
|
||||
|
||||
The empty string is used for RequestBody, which results in the
|
||||
request body not being read or parsed at all. This is by design; a missing
|
||||
``Content-Type`` header in the HTTP request entity is an error at best,
|
||||
and a security hole at worst. For multipart parts, however, the MIME spec
|
||||
@@ -402,8 +406,8 @@ class Entity(object):
|
||||
part_class = None
|
||||
"""The class used for multipart parts.
|
||||
|
||||
You can replace this with custom subclasses to alter the processing of
|
||||
multipart parts.
|
||||
You can replace this with custom subclasses to alter the processing
|
||||
of multipart parts.
|
||||
"""
|
||||
|
||||
def __init__(self, fp, headers, params=None, parts=None):
|
||||
@@ -509,7 +513,8 @@ class Entity(object):
|
||||
"""Return a file-like object into which the request body will be read.
|
||||
|
||||
By default, this will return a TemporaryFile. Override as needed.
|
||||
See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`."""
|
||||
See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`.
|
||||
"""
|
||||
return tempfile.TemporaryFile()
|
||||
|
||||
def fullvalue(self):
|
||||
@@ -525,7 +530,7 @@ class Entity(object):
|
||||
return value
|
||||
|
||||
def decode_entity(self, value):
|
||||
"""Return a given byte encoded value as a string"""
|
||||
"""Return a given byte encoded value as a string."""
|
||||
for charset in self.attempt_charsets:
|
||||
try:
|
||||
value = value.decode(charset)
|
||||
@@ -569,7 +574,6 @@ class Entity(object):
|
||||
|
||||
|
||||
class Part(Entity):
|
||||
|
||||
"""A MIME part entity, part of a multipart entity."""
|
||||
|
||||
# "The default character set, which must be assumed in the absence of a
|
||||
@@ -653,8 +657,8 @@ class Part(Entity):
|
||||
def read_lines_to_boundary(self, fp_out=None):
|
||||
"""Read bytes from self.fp and return or write them to a file.
|
||||
|
||||
If the 'fp_out' argument is None (the default), all bytes read are
|
||||
returned in a single byte string.
|
||||
If the 'fp_out' argument is None (the default), all bytes read
|
||||
are returned in a single byte string.
|
||||
|
||||
If the 'fp_out' argument is not None, it must be a file-like
|
||||
object that supports the 'write' method; all bytes read will be
|
||||
@@ -755,15 +759,15 @@ class SizedReader:
|
||||
def read(self, size=None, fp_out=None):
|
||||
"""Read bytes from the request body and return or write them to a file.
|
||||
|
||||
A number of bytes less than or equal to the 'size' argument are read
|
||||
off the socket. The actual number of bytes read are tracked in
|
||||
self.bytes_read. The number may be smaller than 'size' when 1) the
|
||||
client sends fewer bytes, 2) the 'Content-Length' request header
|
||||
specifies fewer bytes than requested, or 3) the number of bytes read
|
||||
exceeds self.maxbytes (in which case, 413 is raised).
|
||||
A number of bytes less than or equal to the 'size' argument are
|
||||
read off the socket. The actual number of bytes read are tracked
|
||||
in self.bytes_read. The number may be smaller than 'size' when
|
||||
1) the client sends fewer bytes, 2) the 'Content-Length' request
|
||||
header specifies fewer bytes than requested, or 3) the number of
|
||||
bytes read exceeds self.maxbytes (in which case, 413 is raised).
|
||||
|
||||
If the 'fp_out' argument is None (the default), all bytes read are
|
||||
returned in a single byte string.
|
||||
If the 'fp_out' argument is None (the default), all bytes read
|
||||
are returned in a single byte string.
|
||||
|
||||
If the 'fp_out' argument is not None, it must be a file-like
|
||||
object that supports the 'write' method; all bytes read will be
|
||||
@@ -918,7 +922,6 @@ class SizedReader:
|
||||
|
||||
|
||||
class RequestBody(Entity):
|
||||
|
||||
"""The entity of the HTTP request."""
|
||||
|
||||
bufsize = 8 * 1024
|
||||
|
||||
@@ -16,7 +16,6 @@ from cherrypy.lib import httputil, reprconf, encoding
|
||||
|
||||
|
||||
class Hook(object):
|
||||
|
||||
"""A callback and its metadata: failsafe, priority, and kwargs."""
|
||||
|
||||
callback = None
|
||||
@@ -30,10 +29,12 @@ class Hook(object):
|
||||
from the same call point raise exceptions."""
|
||||
|
||||
priority = 50
|
||||
"""Defines the order of execution for a list of Hooks.
|
||||
|
||||
Priority numbers should be limited to the closed interval [0, 100],
|
||||
but values outside this range are acceptable, as are fractional
|
||||
values.
|
||||
"""
|
||||
Defines the order of execution for a list of Hooks. Priority numbers
|
||||
should be limited to the closed interval [0, 100], but values outside
|
||||
this range are acceptable, as are fractional values."""
|
||||
|
||||
kwargs = {}
|
||||
"""
|
||||
@@ -74,7 +75,6 @@ class Hook(object):
|
||||
|
||||
|
||||
class HookMap(dict):
|
||||
|
||||
"""A map of call points to lists of callbacks (Hook objects)."""
|
||||
|
||||
def __new__(cls, points=None):
|
||||
@@ -190,23 +190,23 @@ hookpoints = ['on_start_resource', 'before_request_body',
|
||||
|
||||
|
||||
class Request(object):
|
||||
|
||||
"""An HTTP request.
|
||||
|
||||
This object represents the metadata of an HTTP request message;
|
||||
that is, it contains attributes which describe the environment
|
||||
in which the request URL, headers, and body were sent (if you
|
||||
want tools to interpret the headers and body, those are elsewhere,
|
||||
mostly in Tools). This 'metadata' consists of socket data,
|
||||
transport characteristics, and the Request-Line. This object
|
||||
also contains data regarding the configuration in effect for
|
||||
the given URL, and the execution plan for generating a response.
|
||||
This object represents the metadata of an HTTP request message; that
|
||||
is, it contains attributes which describe the environment in which
|
||||
the request URL, headers, and body were sent (if you want tools to
|
||||
interpret the headers and body, those are elsewhere, mostly in
|
||||
Tools). This 'metadata' consists of socket data, transport
|
||||
characteristics, and the Request-Line. This object also contains
|
||||
data regarding the configuration in effect for the given URL, and
|
||||
the execution plan for generating a response.
|
||||
"""
|
||||
|
||||
prev = None
|
||||
"""The previous Request object (if any).
|
||||
|
||||
This should be None unless we are processing an InternalRedirect.
|
||||
"""
|
||||
The previous Request object (if any). This should be None
|
||||
unless we are processing an InternalRedirect."""
|
||||
|
||||
# Conversation/connection attributes
|
||||
local = httputil.Host('127.0.0.1', 80)
|
||||
@@ -216,9 +216,10 @@ class Request(object):
|
||||
'An httputil.Host(ip, port, hostname) object for the client socket.'
|
||||
|
||||
scheme = 'http'
|
||||
"""The protocol used between client and server.
|
||||
|
||||
In most cases, this will be either 'http' or 'https'.
|
||||
"""
|
||||
The protocol used between client and server. In most cases,
|
||||
this will be either 'http' or 'https'."""
|
||||
|
||||
server_protocol = 'HTTP/1.1'
|
||||
"""
|
||||
@@ -227,25 +228,30 @@ class Request(object):
|
||||
|
||||
base = ''
|
||||
"""The (scheme://host) portion of the requested URL.
|
||||
|
||||
In some cases (e.g. when proxying via mod_rewrite), this may contain
|
||||
path segments which cherrypy.url uses when constructing url's, but
|
||||
which otherwise are ignored by CherryPy. Regardless, this value
|
||||
MUST NOT end in a slash."""
|
||||
which otherwise are ignored by CherryPy. Regardless, this value MUST
|
||||
NOT end in a slash.
|
||||
"""
|
||||
|
||||
# Request-Line attributes
|
||||
request_line = ''
|
||||
"""The complete Request-Line received from the client.
|
||||
|
||||
This is a single string consisting of the request method, URI, and
|
||||
protocol version (joined by spaces). Any final CRLF is removed.
|
||||
"""
|
||||
The complete Request-Line received from the client. This is a
|
||||
single string consisting of the request method, URI, and protocol
|
||||
version (joined by spaces). Any final CRLF is removed."""
|
||||
|
||||
method = 'GET'
|
||||
"""Indicates the HTTP method to be performed on the resource identified by
|
||||
the Request-URI.
|
||||
|
||||
Common methods include GET, HEAD, POST, PUT, and DELETE. CherryPy
|
||||
allows any extension method; however, various HTTP servers and
|
||||
gateways may restrict the set of allowable methods. CherryPy
|
||||
applications SHOULD restrict the set (on a per-URI basis).
|
||||
"""
|
||||
Indicates the HTTP method to be performed on the resource identified
|
||||
by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
|
||||
DELETE. CherryPy allows any extension method; however, various HTTP
|
||||
servers and gateways may restrict the set of allowable methods.
|
||||
CherryPy applications SHOULD restrict the set (on a per-URI basis)."""
|
||||
|
||||
query_string = ''
|
||||
"""
|
||||
@@ -277,22 +283,26 @@ class Request(object):
|
||||
A dict which combines query string (GET) and request entity (POST)
|
||||
variables. This is populated in two stages: GET params are added
|
||||
before the 'on_start_resource' hook, and POST params are added
|
||||
between the 'before_request_body' and 'before_handler' hooks."""
|
||||
between the 'before_request_body' and 'before_handler' hooks.
|
||||
"""
|
||||
|
||||
# Message attributes
|
||||
header_list = []
|
||||
"""A list of the HTTP request headers as (name, value) tuples.
|
||||
|
||||
In general, you should use request.headers (a dict) instead.
|
||||
"""
|
||||
A list of the HTTP request headers as (name, value) tuples.
|
||||
In general, you should use request.headers (a dict) instead."""
|
||||
|
||||
headers = httputil.HeaderMap()
|
||||
"""
|
||||
A dict-like object containing the request headers. Keys are header
|
||||
"""A dict-like object containing the request headers.
|
||||
|
||||
Keys are header
|
||||
names (in Title-Case format); however, you may get and set them in
|
||||
a case-insensitive manner. That is, headers['Content-Type'] and
|
||||
headers['content-type'] refer to the same value. Values are header
|
||||
values (decoded according to :rfc:`2047` if necessary). See also:
|
||||
httputil.HeaderMap, httputil.HeaderElement."""
|
||||
httputil.HeaderMap, httputil.HeaderElement.
|
||||
"""
|
||||
|
||||
cookie = SimpleCookie()
|
||||
"""See help(Cookie)."""
|
||||
@@ -336,7 +346,8 @@ class Request(object):
|
||||
or multipart, this will be None. Otherwise, this will be an instance
|
||||
of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you
|
||||
can .read()); this value is set between the 'before_request_body' and
|
||||
'before_handler' hooks (assuming that process_request_body is True)."""
|
||||
'before_handler' hooks (assuming that process_request_body is True).
|
||||
"""
|
||||
|
||||
# Dispatch attributes
|
||||
dispatch = cherrypy.dispatch.Dispatcher()
|
||||
@@ -347,23 +358,24 @@ class Request(object):
|
||||
calls the dispatcher as early as possible, passing it a 'path_info'
|
||||
argument.
|
||||
|
||||
The default dispatcher discovers the page handler by matching path_info
|
||||
to a hierarchical arrangement of objects, starting at request.app.root.
|
||||
See help(cherrypy.dispatch) for more information."""
|
||||
The default dispatcher discovers the page handler by matching
|
||||
path_info to a hierarchical arrangement of objects, starting at
|
||||
request.app.root. See help(cherrypy.dispatch) for more information.
|
||||
"""
|
||||
|
||||
script_name = ''
|
||||
"""
|
||||
The 'mount point' of the application which is handling this request.
|
||||
"""The 'mount point' of the application which is handling this request.
|
||||
|
||||
This attribute MUST NOT end in a slash. If the script_name refers to
|
||||
the root of the URI, it MUST be an empty string (not "/").
|
||||
"""
|
||||
|
||||
path_info = '/'
|
||||
"""The 'relative path' portion of the Request-URI.
|
||||
|
||||
This is relative to the script_name ('mount point') of the
|
||||
application which is handling this request.
|
||||
"""
|
||||
The 'relative path' portion of the Request-URI. This is relative
|
||||
to the script_name ('mount point') of the application which is
|
||||
handling this request."""
|
||||
|
||||
login = None
|
||||
"""
|
||||
@@ -391,14 +403,16 @@ class Request(object):
|
||||
of the form: {Toolbox.namespace: {Tool.name: config dict}}."""
|
||||
|
||||
config = None
|
||||
"""A flat dict of all configuration entries which apply to the current
|
||||
request.
|
||||
|
||||
These entries are collected from global config, application config
|
||||
(based on request.path_info), and from handler config (exactly how
|
||||
is governed by the request.dispatch object in effect for this
|
||||
request; by default, handler config can be attached anywhere in the
|
||||
tree between request.app.root and the final handler, and inherits
|
||||
downward).
|
||||
"""
|
||||
A flat dict of all configuration entries which apply to the
|
||||
current request. These entries are collected from global config,
|
||||
application config (based on request.path_info), and from handler
|
||||
config (exactly how is governed by the request.dispatch object in
|
||||
effect for this request; by default, handler config can be attached
|
||||
anywhere in the tree between request.app.root and the final handler,
|
||||
and inherits downward)."""
|
||||
|
||||
is_index = None
|
||||
"""
|
||||
@@ -409,13 +423,14 @@ class Request(object):
|
||||
the trailing slash. See cherrypy.tools.trailing_slash."""
|
||||
|
||||
hooks = HookMap(hookpoints)
|
||||
"""
|
||||
A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
|
||||
"""A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
|
||||
|
||||
Each key is a str naming the hook point, and each value is a list
|
||||
of hooks which will be called at that hook point during this request.
|
||||
The list of hooks is generally populated as early as possible (mostly
|
||||
from Tools specified in config), but may be extended at any time.
|
||||
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools."""
|
||||
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools.
|
||||
"""
|
||||
|
||||
error_response = cherrypy.HTTPError(500).set_response
|
||||
"""
|
||||
@@ -428,12 +443,11 @@ class Request(object):
|
||||
error response to the user-agent."""
|
||||
|
||||
error_page = {}
|
||||
"""
|
||||
A dict of {error code: response filename or callable} pairs.
|
||||
"""A dict of {error code: response filename or callable} pairs.
|
||||
|
||||
The error code must be an int representing a given HTTP error code,
|
||||
or the string 'default', which will be used if no matching entry
|
||||
is found for a given numeric code.
|
||||
or the string 'default', which will be used if no matching entry is
|
||||
found for a given numeric code.
|
||||
|
||||
If a filename is provided, the file should contain a Python string-
|
||||
formatting template, and can expect by default to receive format
|
||||
@@ -447,8 +461,8 @@ class Request(object):
|
||||
iterable of strings which will be set to response.body. It may also
|
||||
override headers or perform any other processing.
|
||||
|
||||
If no entry is given for an error code, and no 'default' entry exists,
|
||||
a default template will be used.
|
||||
If no entry is given for an error code, and no 'default' entry
|
||||
exists, a default template will be used.
|
||||
"""
|
||||
|
||||
show_tracebacks = True
|
||||
@@ -473,9 +487,10 @@ class Request(object):
|
||||
"""True once the close method has been called, False otherwise."""
|
||||
|
||||
stage = None
|
||||
"""A string containing the stage reached in the request-handling process.
|
||||
|
||||
This is useful when debugging a live server with hung requests.
|
||||
"""
|
||||
A string containing the stage reached in the request-handling process.
|
||||
This is useful when debugging a live server with hung requests."""
|
||||
|
||||
unique_id = None
|
||||
"""A lazy object generating and memorizing UUID4 on ``str()`` render."""
|
||||
@@ -492,9 +507,10 @@ class Request(object):
|
||||
server_protocol='HTTP/1.1'):
|
||||
"""Populate a new Request object.
|
||||
|
||||
local_host should be an httputil.Host object with the server info.
|
||||
remote_host should be an httputil.Host object with the client info.
|
||||
scheme should be a string, either "http" or "https".
|
||||
local_host should be an httputil.Host object with the server
|
||||
info. remote_host should be an httputil.Host object with the
|
||||
client info. scheme should be a string, either "http" or
|
||||
"https".
|
||||
"""
|
||||
self.local = local_host
|
||||
self.remote = remote_host
|
||||
@@ -514,7 +530,10 @@ class Request(object):
|
||||
self.unique_id = LazyUUID4()
|
||||
|
||||
def close(self):
|
||||
"""Run cleanup code. (Core)"""
|
||||
"""Run cleanup code.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
self.stage = 'on_end_request'
|
||||
@@ -551,7 +570,6 @@ class Request(object):
|
||||
|
||||
Consumer code (HTTP servers) should then access these response
|
||||
attributes to build the outbound stream.
|
||||
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
self.stage = 'run'
|
||||
@@ -631,7 +649,10 @@ class Request(object):
|
||||
return response
|
||||
|
||||
def respond(self, path_info):
|
||||
"""Generate a response for the resource at self.path_info. (Core)"""
|
||||
"""Generate a response for the resource at self.path_info.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
try:
|
||||
try:
|
||||
try:
|
||||
@@ -702,7 +723,10 @@ class Request(object):
|
||||
response.finalize()
|
||||
|
||||
def process_query_string(self):
|
||||
"""Parse the query string into Python structures. (Core)"""
|
||||
"""Parse the query string into Python structures.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
try:
|
||||
p = httputil.parse_query_string(
|
||||
self.query_string, encoding=self.query_string_encoding)
|
||||
@@ -715,7 +739,10 @@ class Request(object):
|
||||
self.params.update(p)
|
||||
|
||||
def process_headers(self):
|
||||
"""Parse HTTP header data into Python structures. (Core)"""
|
||||
"""Parse HTTP header data into Python structures.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
# Process the headers into self.headers
|
||||
headers = self.headers
|
||||
for name, value in self.header_list:
|
||||
@@ -751,7 +778,10 @@ class Request(object):
|
||||
self.base = '%s://%s' % (self.scheme, host)
|
||||
|
||||
def get_resource(self, path):
|
||||
"""Call a dispatcher (which sets self.handler and .config). (Core)"""
|
||||
"""Call a dispatcher (which sets self.handler and .config).
|
||||
|
||||
(Core)
|
||||
"""
|
||||
# First, see if there is a custom dispatch at this URI. Custom
|
||||
# dispatchers can only be specified in app.config, not in _cp_config
|
||||
# (since custom dispatchers may not even have an app.root).
|
||||
@@ -762,7 +792,10 @@ class Request(object):
|
||||
dispatch(path)
|
||||
|
||||
def handle_error(self):
|
||||
"""Handle the last unanticipated exception. (Core)"""
|
||||
"""Handle the last unanticipated exception.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
try:
|
||||
self.hooks.run('before_error_response')
|
||||
if self.error_response:
|
||||
@@ -776,7 +809,6 @@ class Request(object):
|
||||
|
||||
|
||||
class ResponseBody(object):
|
||||
|
||||
"""The body of the HTTP response (the response entity)."""
|
||||
|
||||
unicode_err = ('Page handlers MUST return bytes. Use tools.encode '
|
||||
@@ -802,18 +834,18 @@ class ResponseBody(object):
|
||||
|
||||
|
||||
class Response(object):
|
||||
|
||||
"""An HTTP Response, including status, headers, and body."""
|
||||
|
||||
status = ''
|
||||
"""The HTTP Status-Code and Reason-Phrase."""
|
||||
|
||||
header_list = []
|
||||
"""
|
||||
A list of the HTTP response headers as (name, value) tuples.
|
||||
"""A list of the HTTP response headers as (name, value) tuples.
|
||||
|
||||
In general, you should use response.headers (a dict) instead. This
|
||||
attribute is generated from response.headers and is not valid until
|
||||
after the finalize phase."""
|
||||
after the finalize phase.
|
||||
"""
|
||||
|
||||
headers = httputil.HeaderMap()
|
||||
"""
|
||||
@@ -833,7 +865,10 @@ class Response(object):
|
||||
"""The body (entity) of the HTTP response."""
|
||||
|
||||
time = None
|
||||
"""The value of time.time() when created. Use in HTTP dates."""
|
||||
"""The value of time.time() when created.
|
||||
|
||||
Use in HTTP dates.
|
||||
"""
|
||||
|
||||
stream = False
|
||||
"""If False, buffer the response body."""
|
||||
@@ -861,15 +896,15 @@ class Response(object):
|
||||
return new_body
|
||||
|
||||
def _flush_body(self):
|
||||
"""
|
||||
Discard self.body but consume any generator such that
|
||||
any finalization can occur, such as is required by
|
||||
caching.tee_output().
|
||||
"""
|
||||
"""Discard self.body but consume any generator such that any
|
||||
finalization can occur, such as is required by caching.tee_output()."""
|
||||
consume(iter(self.body))
|
||||
|
||||
def finalize(self):
|
||||
"""Transform headers (and cookies) into self.header_list. (Core)"""
|
||||
"""Transform headers (and cookies) into self.header_list.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
try:
|
||||
code, reason, _ = httputil.valid_status(self.status)
|
||||
except ValueError:
|
||||
|
||||
@@ -50,7 +50,8 @@ class Server(ServerAdapter):
|
||||
"""If given, the name of the UNIX socket to use instead of TCP/IP.
|
||||
|
||||
When this option is not None, the `socket_host` and `socket_port` options
|
||||
are ignored."""
|
||||
are ignored.
|
||||
"""
|
||||
|
||||
socket_queue_size = 5
|
||||
"""The 'backlog' argument to socket.listen(); specifies the maximum number
|
||||
@@ -79,17 +80,24 @@ class Server(ServerAdapter):
|
||||
"""The number of worker threads to start up in the pool."""
|
||||
|
||||
thread_pool_max = -1
|
||||
"""The maximum size of the worker-thread pool. Use -1 to indicate no limit.
|
||||
"""The maximum size of the worker-thread pool.
|
||||
|
||||
Use -1 to indicate no limit.
|
||||
"""
|
||||
|
||||
max_request_header_size = 500 * 1024
|
||||
"""The maximum number of bytes allowable in the request headers.
|
||||
If exceeded, the HTTP server should return "413 Request Entity Too Large".
|
||||
|
||||
If exceeded, the HTTP server should return "413 Request Entity Too
|
||||
Large".
|
||||
"""
|
||||
|
||||
max_request_body_size = 100 * 1024 * 1024
|
||||
"""The maximum number of bytes allowable in the request body. If exceeded,
|
||||
the HTTP server should return "413 Request Entity Too Large"."""
|
||||
"""The maximum number of bytes allowable in the request body.
|
||||
|
||||
If exceeded, the HTTP server should return "413 Request Entity Too
|
||||
Large".
|
||||
"""
|
||||
|
||||
instance = None
|
||||
"""If not None, this should be an HTTP server instance (such as
|
||||
@@ -119,7 +127,8 @@ class Server(ServerAdapter):
|
||||
the builtin WSGI server. Builtin options are: 'builtin' (to
|
||||
use the SSL library built into recent versions of Python).
|
||||
You may also register your own classes in the
|
||||
cheroot.server.ssl_adapters dict."""
|
||||
cheroot.server.ssl_adapters dict.
|
||||
"""
|
||||
|
||||
statistics = False
|
||||
"""Turns statistics-gathering on or off for aware HTTP servers."""
|
||||
@@ -129,11 +138,13 @@ class Server(ServerAdapter):
|
||||
|
||||
wsgi_version = (1, 0)
|
||||
"""The WSGI version tuple to use with the builtin WSGI server.
|
||||
The provided options are (1, 0) [which includes support for PEP 3333,
|
||||
which declares it covers WSGI version 1.0.1 but still mandates the
|
||||
wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
|
||||
You may create and register your own experimental versions of the WSGI
|
||||
protocol by adding custom classes to the cheroot.server.wsgi_gateways dict.
|
||||
|
||||
The provided options are (1, 0) [which includes support for PEP
|
||||
3333, which declares it covers WSGI version 1.0.1 but still mandates
|
||||
the wsgi.version (1, 0)] and ('u', 0), an experimental unicode
|
||||
version. You may create and register your own experimental versions
|
||||
of the WSGI protocol by adding custom classes to the
|
||||
cheroot.server.wsgi_gateways dict.
|
||||
"""
|
||||
|
||||
peercreds = False
|
||||
@@ -184,7 +195,8 @@ class Server(ServerAdapter):
|
||||
def bind_addr(self):
|
||||
"""Return bind address.
|
||||
|
||||
A (host, port) tuple for TCP sockets or a str for Unix domain sockts.
|
||||
A (host, port) tuple for TCP sockets or a str for Unix domain
|
||||
sockets.
|
||||
"""
|
||||
if self.socket_file:
|
||||
return self.socket_file
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""CherryPy tools. A "tool" is any helper, adapted to CP.
|
||||
|
||||
Tools are usually designed to be used in a variety of ways (although some
|
||||
may only offer one if they choose):
|
||||
Tools are usually designed to be used in a variety of ways (although
|
||||
some may only offer one if they choose):
|
||||
|
||||
Library calls
|
||||
All tools are callables that can be used wherever needed.
|
||||
@@ -48,10 +48,10 @@ _attr_error = (
|
||||
|
||||
|
||||
class Tool(object):
|
||||
|
||||
"""A registered function for use with CherryPy request-processing hooks.
|
||||
|
||||
help(tool.callable) should give you more information about this Tool.
|
||||
help(tool.callable) should give you more information about this
|
||||
Tool.
|
||||
"""
|
||||
|
||||
namespace = 'tools'
|
||||
@@ -135,8 +135,8 @@ class Tool(object):
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
The standard CherryPy request object will automatically call
|
||||
this method when the tool is "turned on" in config.
|
||||
"""
|
||||
conf = self._merged_args()
|
||||
p = conf.pop('priority', None)
|
||||
@@ -147,15 +147,15 @@ class Tool(object):
|
||||
|
||||
|
||||
class HandlerTool(Tool):
|
||||
|
||||
"""Tool which is called 'before main', that may skip normal handlers.
|
||||
|
||||
If the tool successfully handles the request (by setting response.body),
|
||||
if should return True. This will cause CherryPy to skip any 'normal' page
|
||||
handler. If the tool did not handle the request, it should return False
|
||||
to tell CherryPy to continue on and call the normal page handler. If the
|
||||
tool is declared AS a page handler (see the 'handler' method), returning
|
||||
False will raise NotFound.
|
||||
If the tool successfully handles the request (by setting
|
||||
response.body), if should return True. This will cause CherryPy to
|
||||
skip any 'normal' page handler. If the tool did not handle the
|
||||
request, it should return False to tell CherryPy to continue on and
|
||||
call the normal page handler. If the tool is declared AS a page
|
||||
handler (see the 'handler' method), returning False will raise
|
||||
NotFound.
|
||||
"""
|
||||
|
||||
def __init__(self, callable, name=None):
|
||||
@@ -185,8 +185,8 @@ class HandlerTool(Tool):
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
The standard CherryPy request object will automatically call
|
||||
this method when the tool is "turned on" in config.
|
||||
"""
|
||||
conf = self._merged_args()
|
||||
p = conf.pop('priority', None)
|
||||
@@ -197,7 +197,6 @@ class HandlerTool(Tool):
|
||||
|
||||
|
||||
class HandlerWrapperTool(Tool):
|
||||
|
||||
"""Tool which wraps request.handler in a provided wrapper function.
|
||||
|
||||
The 'newhandler' arg must be a handler wrapper function that takes a
|
||||
@@ -232,7 +231,6 @@ class HandlerWrapperTool(Tool):
|
||||
|
||||
|
||||
class ErrorTool(Tool):
|
||||
|
||||
"""Tool which is used to replace the default request.error_response."""
|
||||
|
||||
def __init__(self, callable, name=None):
|
||||
@@ -244,8 +242,8 @@ class ErrorTool(Tool):
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
The standard CherryPy request object will automatically call
|
||||
this method when the tool is "turned on" in config.
|
||||
"""
|
||||
cherrypy.serving.request.error_response = self._wrapper
|
||||
|
||||
@@ -254,7 +252,6 @@ class ErrorTool(Tool):
|
||||
|
||||
|
||||
class SessionTool(Tool):
|
||||
|
||||
"""Session Tool for CherryPy.
|
||||
|
||||
sessions.locking
|
||||
@@ -282,8 +279,8 @@ class SessionTool(Tool):
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
The standard CherryPy request object will automatically call
|
||||
this method when the tool is "turned on" in config.
|
||||
"""
|
||||
hooks = cherrypy.serving.request.hooks
|
||||
|
||||
@@ -325,7 +322,6 @@ class SessionTool(Tool):
|
||||
|
||||
|
||||
class XMLRPCController(object):
|
||||
|
||||
"""A Controller (page handler collection) for XML-RPC.
|
||||
|
||||
To use it, have your controllers subclass this base class (it will
|
||||
@@ -392,7 +388,6 @@ class SessionAuthTool(HandlerTool):
|
||||
|
||||
|
||||
class CachingTool(Tool):
|
||||
|
||||
"""Caching Tool for CherryPy."""
|
||||
|
||||
def _wrapper(self, **kwargs):
|
||||
@@ -416,11 +411,11 @@ class CachingTool(Tool):
|
||||
|
||||
|
||||
class Toolbox(object):
|
||||
|
||||
"""A collection of Tools.
|
||||
|
||||
This object also functions as a config namespace handler for itself.
|
||||
Custom toolboxes should be added to each Application's toolboxes dict.
|
||||
Custom toolboxes should be added to each Application's toolboxes
|
||||
dict.
|
||||
"""
|
||||
|
||||
def __init__(self, namespace):
|
||||
|
||||
@@ -10,19 +10,22 @@ from cherrypy.lib import httputil, reprconf
|
||||
class Application(object):
|
||||
"""A CherryPy Application.
|
||||
|
||||
Servers and gateways should not instantiate Request objects directly.
|
||||
Instead, they should ask an Application object for a request object.
|
||||
Servers and gateways should not instantiate Request objects
|
||||
directly. Instead, they should ask an Application object for a
|
||||
request object.
|
||||
|
||||
An instance of this class may also be used as a WSGI callable
|
||||
(WSGI application object) for itself.
|
||||
An instance of this class may also be used as a WSGI callable (WSGI
|
||||
application object) for itself.
|
||||
"""
|
||||
|
||||
root = None
|
||||
"""The top-most container of page handlers for this app. Handlers should
|
||||
be arranged in a hierarchy of attributes, matching the expected URI
|
||||
hierarchy; the default dispatcher then searches this hierarchy for a
|
||||
matching handler. When using a dispatcher other than the default,
|
||||
this value may be None."""
|
||||
"""The top-most container of page handlers for this app.
|
||||
|
||||
Handlers should be arranged in a hierarchy of attributes, matching
|
||||
the expected URI hierarchy; the default dispatcher then searches
|
||||
this hierarchy for a matching handler. When using a dispatcher other
|
||||
than the default, this value may be None.
|
||||
"""
|
||||
|
||||
config = {}
|
||||
"""A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict
|
||||
@@ -32,10 +35,16 @@ class Application(object):
|
||||
toolboxes = {'tools': cherrypy.tools}
|
||||
|
||||
log = None
|
||||
"""A LogManager instance. See _cplogging."""
|
||||
"""A LogManager instance.
|
||||
|
||||
See _cplogging.
|
||||
"""
|
||||
|
||||
wsgiapp = None
|
||||
"""A CPWSGIApp instance. See _cpwsgi."""
|
||||
"""A CPWSGIApp instance.
|
||||
|
||||
See _cpwsgi.
|
||||
"""
|
||||
|
||||
request_class = _cprequest.Request
|
||||
response_class = _cprequest.Response
|
||||
@@ -82,12 +91,15 @@ class Application(object):
|
||||
def script_name(self): # noqa: D401; irrelevant for properties
|
||||
"""The URI "mount point" for this app.
|
||||
|
||||
A mount point is that portion of the URI which is constant for all URIs
|
||||
that are serviced by this application; it does not include scheme,
|
||||
host, or proxy ("virtual host") portions of the URI.
|
||||
A mount point is that portion of the URI which is constant for
|
||||
all URIs that are serviced by this application; it does not
|
||||
include scheme, host, or proxy ("virtual host") portions of the
|
||||
URI.
|
||||
|
||||
For example, if script_name is "/my/cool/app", then the URL
|
||||
"http://www.example.com/my/cool/app/page1" might be handled by a
|
||||
For example, if script_name is "/my/cool/app", then the URL "
|
||||
|
||||
http://www.example.com/my/cool/app/page1"
|
||||
might be handled by a
|
||||
"page1" method on the root object.
|
||||
|
||||
The value of script_name MUST NOT end in a slash. If the script_name
|
||||
@@ -171,9 +183,9 @@ class Application(object):
|
||||
class Tree(object):
|
||||
"""A registry of CherryPy applications, mounted at diverse points.
|
||||
|
||||
An instance of this class may also be used as a WSGI callable
|
||||
(WSGI application object), in which case it dispatches to all
|
||||
mounted apps.
|
||||
An instance of this class may also be used as a WSGI callable (WSGI
|
||||
application object), in which case it dispatches to all mounted
|
||||
apps.
|
||||
"""
|
||||
|
||||
apps = {}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""WSGI interface (see PEP 333 and 3333).
|
||||
|
||||
Note that WSGI environ keys and values are 'native strings'; that is,
|
||||
whatever the type of "" is. For Python 2, that's a byte string; for Python 3,
|
||||
it's a unicode string. But PEP 3333 says: "even if Python's str type is
|
||||
actually Unicode "under the hood", the content of native strings must
|
||||
still be translatable to bytes via the Latin-1 encoding!"
|
||||
whatever the type of "" is. For Python 2, that's a byte string; for
|
||||
Python 3, it's a unicode string. But PEP 3333 says: "even if Python's
|
||||
str type is actually Unicode "under the hood", the content of native
|
||||
strings must still be translatable to bytes via the Latin-1 encoding!"
|
||||
"""
|
||||
|
||||
import sys as _sys
|
||||
@@ -34,7 +34,6 @@ def downgrade_wsgi_ux_to_1x(environ):
|
||||
|
||||
|
||||
class VirtualHost(object):
|
||||
|
||||
"""Select a different WSGI application based on the Host header.
|
||||
|
||||
This can be useful when running multiple sites within one CP server.
|
||||
@@ -56,7 +55,10 @@ class VirtualHost(object):
|
||||
cherrypy.tree.graft(vhost)
|
||||
"""
|
||||
default = None
|
||||
"""Required. The default WSGI application."""
|
||||
"""Required.
|
||||
|
||||
The default WSGI application.
|
||||
"""
|
||||
|
||||
use_x_forwarded_host = True
|
||||
"""If True (the default), any "X-Forwarded-Host"
|
||||
@@ -65,11 +67,12 @@ class VirtualHost(object):
|
||||
|
||||
domains = {}
|
||||
"""A dict of {host header value: application} pairs.
|
||||
The incoming "Host" request header is looked up in this dict,
|
||||
and, if a match is found, the corresponding WSGI application
|
||||
will be called instead of the default. Note that you often need
|
||||
separate entries for "example.com" and "www.example.com".
|
||||
In addition, "Host" headers may contain the port number.
|
||||
|
||||
The incoming "Host" request header is looked up in this dict, and,
|
||||
if a match is found, the corresponding WSGI application will be
|
||||
called instead of the default. Note that you often need separate
|
||||
entries for "example.com" and "www.example.com". In addition, "Host"
|
||||
headers may contain the port number.
|
||||
"""
|
||||
|
||||
def __init__(self, default, domains=None, use_x_forwarded_host=True):
|
||||
@@ -89,7 +92,6 @@ class VirtualHost(object):
|
||||
|
||||
|
||||
class InternalRedirector(object):
|
||||
|
||||
"""WSGI middleware that handles raised cherrypy.InternalRedirect."""
|
||||
|
||||
def __init__(self, nextapp, recursive=False):
|
||||
@@ -137,7 +139,6 @@ class InternalRedirector(object):
|
||||
|
||||
|
||||
class ExceptionTrapper(object):
|
||||
|
||||
"""WSGI middleware that traps exceptions."""
|
||||
|
||||
def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)):
|
||||
@@ -226,7 +227,6 @@ class _TrappedResponse(object):
|
||||
|
||||
|
||||
class AppResponse(object):
|
||||
|
||||
"""WSGI response iterable for CherryPy applications."""
|
||||
|
||||
def __init__(self, environ, start_response, cpapp):
|
||||
@@ -277,7 +277,10 @@ class AppResponse(object):
|
||||
return next(self.iter_response)
|
||||
|
||||
def close(self):
|
||||
"""Close and de-reference the current request and response. (Core)"""
|
||||
"""Close and de-reference the current request and response.
|
||||
|
||||
(Core)
|
||||
"""
|
||||
streaming = _cherrypy.serving.response.stream
|
||||
self.cpapp.release_serving()
|
||||
|
||||
@@ -380,18 +383,20 @@ class AppResponse(object):
|
||||
|
||||
|
||||
class CPWSGIApp(object):
|
||||
|
||||
"""A WSGI application object for a CherryPy Application."""
|
||||
|
||||
pipeline = [
|
||||
('ExceptionTrapper', ExceptionTrapper),
|
||||
('InternalRedirector', InternalRedirector),
|
||||
]
|
||||
"""A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a
|
||||
constructor that takes an initial, positional 'nextapp' argument,
|
||||
plus optional keyword arguments, and returns a WSGI application
|
||||
(that takes environ and start_response arguments). The 'name' can
|
||||
be any you choose, and will correspond to keys in self.config."""
|
||||
"""A list of (name, wsgiapp) pairs.
|
||||
|
||||
Each 'wsgiapp' MUST be a constructor that takes an initial,
|
||||
positional 'nextapp' argument, plus optional keyword arguments, and
|
||||
returns a WSGI application (that takes environ and start_response
|
||||
arguments). The 'name' can be any you choose, and will correspond to
|
||||
keys in self.config.
|
||||
"""
|
||||
|
||||
head = None
|
||||
"""Rather than nest all apps in the pipeline on each call, it's only
|
||||
@@ -399,9 +404,12 @@ class CPWSGIApp(object):
|
||||
this to None again if you change self.pipeline after calling self."""
|
||||
|
||||
config = {}
|
||||
"""A dict whose keys match names listed in the pipeline. Each
|
||||
value is a further dict which will be passed to the corresponding
|
||||
named WSGI callable (from the pipeline) as keyword arguments."""
|
||||
"""A dict whose keys match names listed in the pipeline.
|
||||
|
||||
Each value is a further dict which will be passed to the
|
||||
corresponding named WSGI callable (from the pipeline) as keyword
|
||||
arguments.
|
||||
"""
|
||||
|
||||
response_class = AppResponse
|
||||
"""The class to instantiate and return as the next app in the WSGI chain.
|
||||
@@ -417,8 +425,8 @@ class CPWSGIApp(object):
|
||||
def tail(self, environ, start_response):
|
||||
"""WSGI application callable for the actual CherryPy application.
|
||||
|
||||
You probably shouldn't call this; call self.__call__ instead,
|
||||
so that any WSGI middleware in self.pipeline can run first.
|
||||
You probably shouldn't call this; call self.__call__ instead, so
|
||||
that any WSGI middleware in self.pipeline can run first.
|
||||
"""
|
||||
return self.response_class(environ, start_response, self.cpapp)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
WSGI server interface (see PEP 333).
|
||||
"""WSGI server interface (see PEP 333).
|
||||
|
||||
This adds some CP-specific bits to the framework-agnostic cheroot package.
|
||||
This adds some CP-specific bits to the framework-agnostic cheroot
|
||||
package.
|
||||
"""
|
||||
import sys
|
||||
|
||||
@@ -35,10 +35,11 @@ class CPWSGIHTTPRequest(cheroot.server.HTTPRequest):
|
||||
class CPWSGIServer(cheroot.wsgi.Server):
|
||||
"""Wrapper for cheroot.wsgi.Server.
|
||||
|
||||
cheroot has been designed to not reference CherryPy in any way,
|
||||
so that it can be used in other frameworks and applications. Therefore,
|
||||
we wrap it here, so we can set our own mount points from cherrypy.tree
|
||||
and apply some attributes from config -> cherrypy.server -> wsgi.Server.
|
||||
cheroot has been designed to not reference CherryPy in any way, so
|
||||
that it can be used in other frameworks and applications. Therefore,
|
||||
we wrap it here, so we can set our own mount points from
|
||||
cherrypy.tree and apply some attributes from config ->
|
||||
cherrypy.server -> wsgi.Server.
|
||||
"""
|
||||
|
||||
fmt = 'CherryPy/{cherrypy.__version__} {cheroot.wsgi.Server.version}'
|
||||
|
||||
@@ -137,7 +137,6 @@ def popargs(*args, **kwargs):
|
||||
class Root:
|
||||
def index(self):
|
||||
#...
|
||||
|
||||
"""
|
||||
# Since keyword arg comes after *args, we have to process it ourselves
|
||||
# for lower versions of python.
|
||||
@@ -201,16 +200,17 @@ def url(path='', qs='', script_name=None, base=None, relative=None):
|
||||
If it does not start with a slash, this returns
|
||||
(base + script_name [+ request.path_info] + path + qs).
|
||||
|
||||
If script_name is None, cherrypy.request will be used
|
||||
to find a script_name, if available.
|
||||
If script_name is None, cherrypy.request will be used to find a
|
||||
script_name, if available.
|
||||
|
||||
If base is None, cherrypy.request.base will be used (if available).
|
||||
Note that you can use cherrypy.tools.proxy to change this.
|
||||
|
||||
Finally, note that this function can be used to obtain an absolute URL
|
||||
for the current request path (minus the querystring) by passing no args.
|
||||
If you call url(qs=cherrypy.request.query_string), you should get the
|
||||
original browser URL (assuming no internal redirections).
|
||||
Finally, note that this function can be used to obtain an absolute
|
||||
URL for the current request path (minus the querystring) by passing
|
||||
no args. If you call url(qs=cherrypy.request.query_string), you
|
||||
should get the original browser URL (assuming no internal
|
||||
redirections).
|
||||
|
||||
If relative is None or not provided, request.app.relative_urls will
|
||||
be used (if available, else False). If False, the output will be an
|
||||
@@ -320,8 +320,8 @@ def normalize_path(path):
|
||||
class _ClassPropertyDescriptor(object):
|
||||
"""Descript for read-only class-based property.
|
||||
|
||||
Turns a classmethod-decorated func into a read-only property of that class
|
||||
type (means the value cannot be set).
|
||||
Turns a classmethod-decorated func into a read-only property of that
|
||||
class type (means the value cannot be set).
|
||||
"""
|
||||
|
||||
def __init__(self, fget, fset=None):
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""
|
||||
JSON support.
|
||||
"""JSON support.
|
||||
|
||||
Expose preferred json module as json and provide encode/decode
|
||||
convenience functions.
|
||||
|
||||
@@ -6,8 +6,8 @@ def is_iterator(obj):
|
||||
|
||||
(i.e. like a generator).
|
||||
|
||||
This will return False for objects which are iterable,
|
||||
but not iterators themselves.
|
||||
This will return False for objects which are iterable, but not
|
||||
iterators themselves.
|
||||
"""
|
||||
from types import GeneratorType
|
||||
if isinstance(obj, GeneratorType):
|
||||
|
||||
@@ -18,7 +18,6 @@ as the credentials store::
|
||||
'tools.auth_basic.accept_charset': 'UTF-8',
|
||||
}
|
||||
app_config = { '/' : basic_auth }
|
||||
|
||||
"""
|
||||
|
||||
import binascii
|
||||
|
||||
@@ -55,7 +55,7 @@ def TRACE(msg):
|
||||
|
||||
|
||||
def get_ha1_dict_plain(user_password_dict):
|
||||
"""Returns a get_ha1 function which obtains a plaintext password from a
|
||||
"""Return a get_ha1 function which obtains a plaintext password from a
|
||||
dictionary of the form: {username : password}.
|
||||
|
||||
If you want a simple dictionary-based authentication scheme, with plaintext
|
||||
@@ -72,7 +72,7 @@ def get_ha1_dict_plain(user_password_dict):
|
||||
|
||||
|
||||
def get_ha1_dict(user_ha1_dict):
|
||||
"""Returns a get_ha1 function which obtains a HA1 password hash from a
|
||||
"""Return a get_ha1 function which obtains a HA1 password hash from a
|
||||
dictionary of the form: {username : HA1}.
|
||||
|
||||
If you want a dictionary-based authentication scheme, but with
|
||||
@@ -87,7 +87,7 @@ def get_ha1_dict(user_ha1_dict):
|
||||
|
||||
|
||||
def get_ha1_file_htdigest(filename):
|
||||
"""Returns a get_ha1 function which obtains a HA1 password hash from a
|
||||
"""Return a get_ha1 function which obtains a HA1 password hash from a
|
||||
flat file with lines of the same format as that produced by the Apache
|
||||
htdigest utility. For example, for realm 'wonderland', username 'alice',
|
||||
and password '4x5istwelve', the htdigest line would be::
|
||||
@@ -135,7 +135,7 @@ def synthesize_nonce(s, key, timestamp=None):
|
||||
|
||||
|
||||
def H(s):
|
||||
"""The hash function H"""
|
||||
"""The hash function H."""
|
||||
return md5_hex(s)
|
||||
|
||||
|
||||
@@ -259,10 +259,11 @@ class HttpDigestAuthorization(object):
|
||||
return False
|
||||
|
||||
def is_nonce_stale(self, max_age_seconds=600):
|
||||
"""Returns True if a validated nonce is stale. The nonce contains a
|
||||
timestamp in plaintext and also a secure hash of the timestamp.
|
||||
You should first validate the nonce to ensure the plaintext
|
||||
timestamp is not spoofed.
|
||||
"""Return True if a validated nonce is stale.
|
||||
|
||||
The nonce contains a timestamp in plaintext and also a secure
|
||||
hash of the timestamp. You should first validate the nonce to
|
||||
ensure the plaintext timestamp is not spoofed.
|
||||
"""
|
||||
try:
|
||||
timestamp, hashpart = self.nonce.split(':', 1)
|
||||
@@ -275,7 +276,10 @@ class HttpDigestAuthorization(object):
|
||||
return True
|
||||
|
||||
def HA2(self, entity_body=''):
|
||||
"""Returns the H(A2) string. See :rfc:`2617` section 3.2.2.3."""
|
||||
"""Return the H(A2) string.
|
||||
|
||||
See :rfc:`2617` section 3.2.2.3.
|
||||
"""
|
||||
# RFC 2617 3.2.2.3
|
||||
# If the "qop" directive's value is "auth" or is unspecified,
|
||||
# then A2 is:
|
||||
@@ -306,7 +310,6 @@ class HttpDigestAuthorization(object):
|
||||
4.3. This refers to the entity the user agent sent in the
|
||||
request which has the Authorization header. Typically GET
|
||||
requests don't have an entity, and POST requests do.
|
||||
|
||||
"""
|
||||
ha2 = self.HA2(entity_body)
|
||||
# Request-Digest -- RFC 2617 3.2.2.1
|
||||
@@ -395,7 +398,6 @@ def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'):
|
||||
key
|
||||
A secret string known only to the server, used in the synthesis
|
||||
of nonces.
|
||||
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
|
||||
@@ -447,9 +449,7 @@ def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'):
|
||||
|
||||
|
||||
def _respond_401(realm, key, accept_charset, debug, **kwargs):
|
||||
"""
|
||||
Respond with 401 status and a WWW-Authenticate header
|
||||
"""
|
||||
"""Respond with 401 status and a WWW-Authenticate header."""
|
||||
header = www_authenticate(
|
||||
realm, key,
|
||||
accept_charset=accept_charset,
|
||||
|
||||
@@ -42,7 +42,6 @@ from cherrypy.lib import cptools, httputil
|
||||
|
||||
|
||||
class Cache(object):
|
||||
|
||||
"""Base class for Cache implementations."""
|
||||
|
||||
def get(self):
|
||||
@@ -64,17 +63,16 @@ class Cache(object):
|
||||
|
||||
# ------------------------------ Memory Cache ------------------------------- #
|
||||
class AntiStampedeCache(dict):
|
||||
|
||||
"""A storage system for cached items which reduces stampede collisions."""
|
||||
|
||||
def wait(self, key, timeout=5, debug=False):
|
||||
"""Return the cached value for the given key, or None.
|
||||
|
||||
If timeout is not None, and the value is already
|
||||
being calculated by another thread, wait until the given timeout has
|
||||
elapsed. If the value is available before the timeout expires, it is
|
||||
returned. If not, None is returned, and a sentinel placed in the cache
|
||||
to signal other threads to wait.
|
||||
If timeout is not None, and the value is already being
|
||||
calculated by another thread, wait until the given timeout has
|
||||
elapsed. If the value is available before the timeout expires,
|
||||
it is returned. If not, None is returned, and a sentinel placed
|
||||
in the cache to signal other threads to wait.
|
||||
|
||||
If timeout is None, no waiting is performed nor sentinels used.
|
||||
"""
|
||||
@@ -127,7 +125,6 @@ class AntiStampedeCache(dict):
|
||||
|
||||
|
||||
class MemoryCache(Cache):
|
||||
|
||||
"""An in-memory cache for varying response content.
|
||||
|
||||
Each key in self.store is a URI, and each value is an AntiStampedeCache.
|
||||
@@ -381,7 +378,10 @@ def get(invalid_methods=('POST', 'PUT', 'DELETE'), debug=False, **kwargs):
|
||||
|
||||
|
||||
def tee_output():
|
||||
"""Tee response output to cache storage. Internal."""
|
||||
"""Tee response output to cache storage.
|
||||
|
||||
Internal.
|
||||
"""
|
||||
# Used by CachingTool by attaching to request.hooks
|
||||
|
||||
request = cherrypy.serving.request
|
||||
@@ -441,7 +441,6 @@ def expires(secs=0, force=False, debug=False):
|
||||
* Expires
|
||||
|
||||
If any are already present, none of the above response headers are set.
|
||||
|
||||
"""
|
||||
|
||||
response = cherrypy.serving.response
|
||||
|
||||
@@ -184,7 +184,6 @@ To report statistics::
|
||||
To format statistics reports::
|
||||
|
||||
See 'Reporting', above.
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -254,7 +253,6 @@ def proc_time(s):
|
||||
|
||||
|
||||
class ByteCountWrapper(object):
|
||||
|
||||
"""Wraps a file-like object, counting the number of bytes read."""
|
||||
|
||||
def __init__(self, rfile):
|
||||
@@ -307,7 +305,6 @@ def _get_threading_ident():
|
||||
|
||||
|
||||
class StatsTool(cherrypy.Tool):
|
||||
|
||||
"""Record various information about the current request."""
|
||||
|
||||
def __init__(self):
|
||||
@@ -316,8 +313,8 @@ class StatsTool(cherrypy.Tool):
|
||||
def _setup(self):
|
||||
"""Hook this tool into cherrypy.request.
|
||||
|
||||
The standard CherryPy request object will automatically call this
|
||||
method when the tool is "turned on" in config.
|
||||
The standard CherryPy request object will automatically call
|
||||
this method when the tool is "turned on" in config.
|
||||
"""
|
||||
if appstats.get('Enabled', False):
|
||||
cherrypy.Tool._setup(self)
|
||||
|
||||
@@ -94,8 +94,8 @@ def validate_etags(autotags=False, debug=False):
|
||||
def validate_since():
|
||||
"""Validate the current Last-Modified against If-Modified-Since headers.
|
||||
|
||||
If no code has set the Last-Modified response header, then no validation
|
||||
will be performed.
|
||||
If no code has set the Last-Modified response header, then no
|
||||
validation will be performed.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
lastmod = response.headers.get('Last-Modified')
|
||||
@@ -123,9 +123,9 @@ def validate_since():
|
||||
def allow(methods=None, debug=False):
|
||||
"""Raise 405 if request.method not in methods (default ['GET', 'HEAD']).
|
||||
|
||||
The given methods are case-insensitive, and may be in any order.
|
||||
If only one method is allowed, you may supply a single string;
|
||||
if more than one, supply a list of strings.
|
||||
The given methods are case-insensitive, and may be in any order. If
|
||||
only one method is allowed, you may supply a single string; if more
|
||||
than one, supply a list of strings.
|
||||
|
||||
Regardless of whether the current method is allowed or not, this
|
||||
also emits an 'Allow' response header, containing the given methods.
|
||||
@@ -154,22 +154,23 @@ def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For',
|
||||
scheme='X-Forwarded-Proto', debug=False):
|
||||
"""Change the base URL (scheme://host[:port][/path]).
|
||||
|
||||
For running a CP server behind Apache, lighttpd, or other HTTP server.
|
||||
For running a CP server behind Apache, lighttpd, or other HTTP
|
||||
server.
|
||||
|
||||
For Apache and lighttpd, you should leave the 'local' argument at the
|
||||
default value of 'X-Forwarded-Host'. For Squid, you probably want to set
|
||||
tools.proxy.local = 'Origin'.
|
||||
For Apache and lighttpd, you should leave the 'local' argument at
|
||||
the default value of 'X-Forwarded-Host'. For Squid, you probably
|
||||
want to set tools.proxy.local = 'Origin'.
|
||||
|
||||
If you want the new request.base to include path info (not just the host),
|
||||
you must explicitly set base to the full base path, and ALSO set 'local'
|
||||
to '', so that the X-Forwarded-Host request header (which never includes
|
||||
path info) does not override it. Regardless, the value for 'base' MUST
|
||||
NOT end in a slash.
|
||||
If you want the new request.base to include path info (not just the
|
||||
host), you must explicitly set base to the full base path, and ALSO
|
||||
set 'local' to '', so that the X-Forwarded-Host request header
|
||||
(which never includes path info) does not override it. Regardless,
|
||||
the value for 'base' MUST NOT end in a slash.
|
||||
|
||||
cherrypy.request.remote.ip (the IP address of the client) will be
|
||||
rewritten if the header specified by the 'remote' arg is valid.
|
||||
By default, 'remote' is set to 'X-Forwarded-For'. If you do not
|
||||
want to rewrite remote.ip, set the 'remote' arg to an empty string.
|
||||
rewritten if the header specified by the 'remote' arg is valid. By
|
||||
default, 'remote' is set to 'X-Forwarded-For'. If you do not want to
|
||||
rewrite remote.ip, set the 'remote' arg to an empty string.
|
||||
"""
|
||||
|
||||
request = cherrypy.serving.request
|
||||
@@ -217,8 +218,8 @@ def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For',
|
||||
def ignore_headers(headers=('Range',), debug=False):
|
||||
"""Delete request headers whose field names are included in 'headers'.
|
||||
|
||||
This is a useful tool for working behind certain HTTP servers;
|
||||
for example, Apache duplicates the work that CP does for 'Range'
|
||||
This is a useful tool for working behind certain HTTP servers; for
|
||||
example, Apache duplicates the work that CP does for 'Range'
|
||||
headers, and will doubly-truncate the response.
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
@@ -281,7 +282,6 @@ def referer(pattern, accept=True, accept_missing=False, error=403,
|
||||
|
||||
|
||||
class SessionAuth(object):
|
||||
|
||||
"""Assert that the user is logged in."""
|
||||
|
||||
session_key = 'username'
|
||||
@@ -319,7 +319,10 @@ Message: %(error_msg)s
|
||||
</body></html>""") % vars()).encode('utf-8')
|
||||
|
||||
def do_login(self, username, password, from_page='..', **kwargs):
|
||||
"""Login. May raise redirect, or return True if request handled."""
|
||||
"""Login.
|
||||
|
||||
May raise redirect, or return True if request handled.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
error_msg = self.check_username_and_password(username, password)
|
||||
if error_msg:
|
||||
@@ -336,7 +339,10 @@ Message: %(error_msg)s
|
||||
raise cherrypy.HTTPRedirect(from_page or '/')
|
||||
|
||||
def do_logout(self, from_page='..', **kwargs):
|
||||
"""Logout. May raise redirect, or return True if request handled."""
|
||||
"""Logout.
|
||||
|
||||
May raise redirect, or return True if request handled.
|
||||
"""
|
||||
sess = cherrypy.session
|
||||
username = sess.get(self.session_key)
|
||||
sess[self.session_key] = None
|
||||
@@ -346,7 +352,9 @@ Message: %(error_msg)s
|
||||
raise cherrypy.HTTPRedirect(from_page)
|
||||
|
||||
def do_check(self):
|
||||
"""Assert username. Raise redirect, or return True if request handled.
|
||||
"""Assert username.
|
||||
|
||||
Raise redirect, or return True if request handled.
|
||||
"""
|
||||
sess = cherrypy.session
|
||||
request = cherrypy.serving.request
|
||||
@@ -408,8 +416,7 @@ def session_auth(**kwargs):
|
||||
|
||||
Any attribute of the SessionAuth class may be overridden
|
||||
via a keyword arg to this function:
|
||||
|
||||
""" + '\n '.join(
|
||||
""" + '\n' + '\n '.join(
|
||||
'{!s}: {!s}'.format(k, type(getattr(SessionAuth, k)).__name__)
|
||||
for k in dir(SessionAuth)
|
||||
if not k.startswith('__')
|
||||
@@ -490,8 +497,8 @@ def trailing_slash(missing=True, extra=False, status=None, debug=False):
|
||||
def flatten(debug=False):
|
||||
"""Wrap response.body in a generator that recursively iterates over body.
|
||||
|
||||
This allows cherrypy.response.body to consist of 'nested generators';
|
||||
that is, a set of generators that yield generators.
|
||||
This allows cherrypy.response.body to consist of 'nested
|
||||
generators'; that is, a set of generators that yield generators.
|
||||
"""
|
||||
def flattener(input):
|
||||
numchunks = 0
|
||||
@@ -622,13 +629,15 @@ def autovary(ignore=None, debug=False):
|
||||
|
||||
|
||||
def convert_params(exception=ValueError, error=400):
|
||||
"""Convert request params based on function annotations, with error handling.
|
||||
"""Convert request params based on function annotations.
|
||||
|
||||
exception
|
||||
Exception class to catch.
|
||||
This function also processes errors that are subclasses of ``exception``.
|
||||
|
||||
status
|
||||
The HTTP error code to return to the client on failure.
|
||||
:param BaseException exception: Exception class to catch.
|
||||
:type exception: BaseException
|
||||
|
||||
:param error: The HTTP status code to return to the client on failure.
|
||||
:type error: int
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
types = request.handler.callable.__annotations__
|
||||
|
||||
@@ -261,9 +261,7 @@ class ResponseEncoder:
|
||||
|
||||
|
||||
def prepare_iter(value):
|
||||
"""
|
||||
Ensure response body is iterable and resolves to False when empty.
|
||||
"""
|
||||
"""Ensure response body is iterable and resolves to False when empty."""
|
||||
if isinstance(value, text_or_bytes):
|
||||
# strings get wrapped in a list because iterating over a single
|
||||
# item list is much faster than iterating over every character
|
||||
@@ -360,7 +358,6 @@ def gzip(compress_level=5, mime_types=['text/html', 'text/plain'],
|
||||
* No 'gzip' or 'x-gzip' is present in the Accept-Encoding header
|
||||
* No 'gzip' or 'x-gzip' with a qvalue > 0 is present
|
||||
* The 'identity' value is given with a qvalue > 0.
|
||||
|
||||
"""
|
||||
request = cherrypy.serving.request
|
||||
response = cherrypy.serving.response
|
||||
|
||||
@@ -14,7 +14,6 @@ from cherrypy.process.plugins import SimplePlugin
|
||||
|
||||
|
||||
class ReferrerTree(object):
|
||||
|
||||
"""An object which gathers all referrers of an object to a given depth."""
|
||||
|
||||
peek_length = 40
|
||||
@@ -132,7 +131,6 @@ def get_context(obj):
|
||||
|
||||
|
||||
class GCRoot(object):
|
||||
|
||||
"""A CherryPy page handler for testing reference leaks."""
|
||||
|
||||
classes = [
|
||||
|
||||
@@ -71,10 +71,10 @@ def protocol_from_http(protocol_str):
|
||||
def get_ranges(headervalue, content_length):
|
||||
"""Return a list of (start, stop) indices from a Range header, or None.
|
||||
|
||||
Each (start, stop) tuple will be composed of two ints, which are suitable
|
||||
for use in a slicing operation. That is, the header "Range: bytes=3-6",
|
||||
if applied against a Python string, is requesting resource[3:7]. This
|
||||
function will return the list [(3, 7)].
|
||||
Each (start, stop) tuple will be composed of two ints, which are
|
||||
suitable for use in a slicing operation. That is, the header "Range:
|
||||
bytes=3-6", if applied against a Python string, is requesting
|
||||
resource[3:7]. This function will return the list [(3, 7)].
|
||||
|
||||
If this function returns an empty list, you should return HTTP 416.
|
||||
"""
|
||||
@@ -127,7 +127,6 @@ def get_ranges(headervalue, content_length):
|
||||
|
||||
|
||||
class HeaderElement(object):
|
||||
|
||||
"""An element (with parameters) from an HTTP header's element list."""
|
||||
|
||||
def __init__(self, value, params=None):
|
||||
@@ -169,14 +168,14 @@ q_separator = re.compile(r'; *q *=')
|
||||
|
||||
|
||||
class AcceptElement(HeaderElement):
|
||||
|
||||
"""An element (with parameters) from an Accept* header's element list.
|
||||
|
||||
AcceptElement objects are comparable; the more-preferred object will be
|
||||
"less than" the less-preferred object. They are also therefore sortable;
|
||||
if you sort a list of AcceptElement objects, they will be listed in
|
||||
priority order; the most preferred value will be first. Yes, it should
|
||||
have been the other way around, but it's too late to fix now.
|
||||
AcceptElement objects are comparable; the more-preferred object will
|
||||
be "less than" the less-preferred object. They are also therefore
|
||||
sortable; if you sort a list of AcceptElement objects, they will be
|
||||
listed in priority order; the most preferred value will be first.
|
||||
Yes, it should have been the other way around, but it's too late to
|
||||
fix now.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@@ -249,8 +248,7 @@ def header_elements(fieldname, fieldvalue):
|
||||
|
||||
|
||||
def decode_TEXT(value):
|
||||
r"""
|
||||
Decode :rfc:`2047` TEXT
|
||||
r"""Decode :rfc:`2047` TEXT.
|
||||
|
||||
>>> decode_TEXT("=?utf-8?q?f=C3=BCr?=") == b'f\xfcr'.decode('latin-1')
|
||||
True
|
||||
@@ -265,9 +263,7 @@ def decode_TEXT(value):
|
||||
|
||||
|
||||
def decode_TEXT_maybe(value):
|
||||
"""
|
||||
Decode the text but only if '=?' appears in it.
|
||||
"""
|
||||
"""Decode the text but only if '=?' appears in it."""
|
||||
return decode_TEXT(value) if '=?' in value else value
|
||||
|
||||
|
||||
@@ -388,7 +384,6 @@ def parse_query_string(query_string, keep_blank_values=True, encoding='utf-8'):
|
||||
|
||||
|
||||
class CaseInsensitiveDict(jaraco.collections.KeyTransformingDict):
|
||||
|
||||
"""A case-insensitive dict subclass.
|
||||
|
||||
Each key is changed on entry to title case.
|
||||
@@ -417,7 +412,6 @@ else:
|
||||
|
||||
|
||||
class HeaderMap(CaseInsensitiveDict):
|
||||
|
||||
"""A dict subclass for HTTP request and response headers.
|
||||
|
||||
Each key is changed on entry to str(key).title(). This allows headers
|
||||
@@ -494,7 +488,6 @@ class HeaderMap(CaseInsensitiveDict):
|
||||
|
||||
|
||||
class Host(object):
|
||||
|
||||
"""An internet address.
|
||||
|
||||
name
|
||||
|
||||
@@ -7,22 +7,22 @@ class NeverExpires(object):
|
||||
|
||||
|
||||
class Timer(object):
|
||||
"""
|
||||
A simple timer that will indicate when an expiration time has passed.
|
||||
"""
|
||||
"""A simple timer that will indicate when an expiration time has passed."""
|
||||
def __init__(self, expiration):
|
||||
'Create a timer that expires at `expiration` (UTC datetime)'
|
||||
self.expiration = expiration
|
||||
|
||||
@classmethod
|
||||
def after(cls, elapsed):
|
||||
"""
|
||||
Return a timer that will expire after `elapsed` passes.
|
||||
"""
|
||||
return cls(datetime.datetime.utcnow() + elapsed)
|
||||
"""Return a timer that will expire after `elapsed` passes."""
|
||||
return cls(
|
||||
datetime.datetime.now(datetime.timezone.utc) + elapsed,
|
||||
)
|
||||
|
||||
def expired(self):
|
||||
return datetime.datetime.utcnow() >= self.expiration
|
||||
return datetime.datetime.now(
|
||||
datetime.timezone.utc,
|
||||
) >= self.expiration
|
||||
|
||||
|
||||
class LockTimeout(Exception):
|
||||
@@ -30,9 +30,7 @@ class LockTimeout(Exception):
|
||||
|
||||
|
||||
class LockChecker(object):
|
||||
"""
|
||||
Keep track of the time and detect if a timeout has expired
|
||||
"""
|
||||
"""Keep track of the time and detect if a timeout has expired."""
|
||||
def __init__(self, session_id, timeout):
|
||||
self.session_id = session_id
|
||||
if timeout:
|
||||
|
||||
@@ -30,7 +30,6 @@ to get a quick sanity-check on overall CP performance. Use the
|
||||
``--profile`` flag when running the test suite. Then, use the ``serve()``
|
||||
function to browse the results in a web browser. If you run this
|
||||
module from the command line, it will call ``serve()`` for you.
|
||||
|
||||
"""
|
||||
|
||||
import io
|
||||
@@ -47,7 +46,9 @@ try:
|
||||
import pstats
|
||||
|
||||
def new_func_strip_path(func_name):
|
||||
"""Make profiler output more readable by adding `__init__` modules' parents
|
||||
"""Add ``__init__`` modules' parents.
|
||||
|
||||
This makes the profiler output more readable.
|
||||
"""
|
||||
filename, line, name = func_name
|
||||
if filename.endswith('__init__.py'):
|
||||
|
||||
@@ -27,18 +27,17 @@ from cherrypy._cpcompat import text_or_bytes
|
||||
|
||||
|
||||
class NamespaceSet(dict):
|
||||
|
||||
"""A dict of config namespace names and handlers.
|
||||
|
||||
Each config entry should begin with a namespace name; the corresponding
|
||||
namespace handler will be called once for each config entry in that
|
||||
namespace, and will be passed two arguments: the config key (with the
|
||||
namespace removed) and the config value.
|
||||
Each config entry should begin with a namespace name; the
|
||||
corresponding namespace handler will be called once for each config
|
||||
entry in that namespace, and will be passed two arguments: the
|
||||
config key (with the namespace removed) and the config value.
|
||||
|
||||
Namespace handlers may be any Python callable; they may also be
|
||||
context managers, in which case their __enter__
|
||||
method should return a callable to be used as the handler.
|
||||
See cherrypy.tools (the Toolbox class) for an example.
|
||||
context managers, in which case their __enter__ method should return
|
||||
a callable to be used as the handler. See cherrypy.tools (the
|
||||
Toolbox class) for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, config):
|
||||
@@ -48,9 +47,10 @@ class NamespaceSet(dict):
|
||||
A flat dict, where keys use dots to separate
|
||||
namespaces, and values are arbitrary.
|
||||
|
||||
The first name in each config key is used to look up the corresponding
|
||||
namespace handler. For example, a config entry of {'tools.gzip.on': v}
|
||||
will call the 'tools' namespace handler with the args: ('gzip.on', v)
|
||||
The first name in each config key is used to look up the
|
||||
corresponding namespace handler. For example, a config entry of
|
||||
{'tools.gzip.on': v} will call the 'tools' namespace handler
|
||||
with the args: ('gzip.on', v)
|
||||
"""
|
||||
# Separate the given config into namespaces
|
||||
ns_confs = {}
|
||||
@@ -103,7 +103,6 @@ class NamespaceSet(dict):
|
||||
|
||||
|
||||
class Config(dict):
|
||||
|
||||
"""A dict-like set of configuration data, with defaults and namespaces.
|
||||
|
||||
May take a file, filename, or dict.
|
||||
@@ -167,7 +166,7 @@ class Parser(configparser.ConfigParser):
|
||||
self._read(fp, filename)
|
||||
|
||||
def as_dict(self, raw=False, vars=None):
|
||||
"""Convert an INI file to a dictionary"""
|
||||
"""Convert an INI file to a dictionary."""
|
||||
# Load INI file into a dict
|
||||
result = {}
|
||||
for section in self.sections():
|
||||
@@ -188,7 +187,7 @@ class Parser(configparser.ConfigParser):
|
||||
|
||||
def dict_from_file(self, file):
|
||||
if hasattr(file, 'read'):
|
||||
self.readfp(file)
|
||||
self.read_file(file)
|
||||
else:
|
||||
self.read(file)
|
||||
return self.as_dict()
|
||||
|
||||
@@ -120,7 +120,6 @@ missing = object()
|
||||
|
||||
|
||||
class Session(object):
|
||||
|
||||
"""A CherryPy dict-like Session object (one per request)."""
|
||||
|
||||
_id = None
|
||||
@@ -148,9 +147,11 @@ class Session(object):
|
||||
to session data."""
|
||||
|
||||
loaded = False
|
||||
"""If True, data has been retrieved from storage.
|
||||
|
||||
This should happen automatically on the first attempt to access
|
||||
session data.
|
||||
"""
|
||||
If True, data has been retrieved from storage. This should happen
|
||||
automatically on the first attempt to access session data."""
|
||||
|
||||
clean_thread = None
|
||||
'Class-level Monitor which calls self.clean_up.'
|
||||
@@ -165,9 +166,10 @@ class Session(object):
|
||||
'True if the session requested by the client did not exist.'
|
||||
|
||||
regenerated = False
|
||||
"""True if the application called session.regenerate().
|
||||
|
||||
This is not set by internal calls to regenerate the session id.
|
||||
"""
|
||||
True if the application called session.regenerate(). This is not set by
|
||||
internal calls to regenerate the session id."""
|
||||
|
||||
debug = False
|
||||
'If True, log debug information.'
|
||||
@@ -335,8 +337,9 @@ class Session(object):
|
||||
|
||||
def pop(self, key, default=missing):
|
||||
"""Remove the specified key and return the corresponding value.
|
||||
If key is not found, default is returned if given,
|
||||
otherwise KeyError is raised.
|
||||
|
||||
If key is not found, default is returned if given, otherwise
|
||||
KeyError is raised.
|
||||
"""
|
||||
if not self.loaded:
|
||||
self.load()
|
||||
@@ -351,13 +354,19 @@ class Session(object):
|
||||
return key in self._data
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
|
||||
"""D.get(k[,d]) -> D[k] if k in D, else d.
|
||||
|
||||
d defaults to None.
|
||||
"""
|
||||
if not self.loaded:
|
||||
self.load()
|
||||
return self._data.get(key, default)
|
||||
|
||||
def update(self, d):
|
||||
"""D.update(E) -> None. Update D from E: for k in E: D[k] = E[k]."""
|
||||
"""D.update(E) -> None.
|
||||
|
||||
Update D from E: for k in E: D[k] = E[k].
|
||||
"""
|
||||
if not self.loaded:
|
||||
self.load()
|
||||
self._data.update(d)
|
||||
@@ -369,7 +378,10 @@ class Session(object):
|
||||
return self._data.setdefault(key, default)
|
||||
|
||||
def clear(self):
|
||||
"""D.clear() -> None. Remove all items from D."""
|
||||
"""D.clear() -> None.
|
||||
|
||||
Remove all items from D.
|
||||
"""
|
||||
if not self.loaded:
|
||||
self.load()
|
||||
self._data.clear()
|
||||
@@ -492,7 +504,8 @@ class FileSession(Session):
|
||||
"""Set up the storage system for file-based sessions.
|
||||
|
||||
This should only be called once per process; this will be done
|
||||
automatically when using sessions.init (as the built-in Tool does).
|
||||
automatically when using sessions.init (as the built-in Tool
|
||||
does).
|
||||
"""
|
||||
# The 'storage_path' arg is required for file-based sessions.
|
||||
kwargs['storage_path'] = os.path.abspath(kwargs['storage_path'])
|
||||
@@ -616,7 +629,8 @@ class MemcachedSession(Session):
|
||||
"""Set up the storage system for memcached-based sessions.
|
||||
|
||||
This should only be called once per process; this will be done
|
||||
automatically when using sessions.init (as the built-in Tool does).
|
||||
automatically when using sessions.init (as the built-in Tool
|
||||
does).
|
||||
"""
|
||||
for k, v in kwargs.items():
|
||||
setattr(cls, k, v)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""Module with helpers for serving static files."""
|
||||
|
||||
import mimetypes
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import stat
|
||||
import mimetypes
|
||||
import urllib.parse
|
||||
import unicodedata
|
||||
|
||||
import urllib.parse
|
||||
from email.generator import _make_boundary as make_boundary
|
||||
from io import UnsupportedOperation
|
||||
|
||||
import cherrypy
|
||||
from cherrypy._cpcompat import ntob
|
||||
from cherrypy.lib import cptools, httputil, file_generator_limited
|
||||
from cherrypy.lib import cptools, file_generator_limited, httputil
|
||||
|
||||
|
||||
def _setup_mimetypes():
|
||||
@@ -57,15 +56,15 @@ def serve_file(path, content_type=None, disposition=None, name=None,
|
||||
debug=False):
|
||||
"""Set status, headers, and body in order to serve the given path.
|
||||
|
||||
The Content-Type header will be set to the content_type arg, if provided.
|
||||
If not provided, the Content-Type will be guessed by the file extension
|
||||
of the 'path' argument.
|
||||
The Content-Type header will be set to the content_type arg, if
|
||||
provided. If not provided, the Content-Type will be guessed by the
|
||||
file extension of the 'path' argument.
|
||||
|
||||
If disposition is not None, the Content-Disposition header will be set
|
||||
to "<disposition>; filename=<name>; filename*=utf-8''<name>"
|
||||
as described in :rfc:`6266#appendix-D`.
|
||||
If name is None, it will be set to the basename of path.
|
||||
If disposition is None, no Content-Disposition header will be written.
|
||||
If disposition is not None, the Content-Disposition header will be
|
||||
set to "<disposition>; filename=<name>; filename*=utf-8''<name>" as
|
||||
described in :rfc:`6266#appendix-D`. If name is None, it will be set
|
||||
to the basename of path. If disposition is None, no Content-
|
||||
Disposition header will be written.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
|
||||
@@ -185,7 +184,10 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None,
|
||||
|
||||
|
||||
def _serve_fileobj(fileobj, content_type, content_length, debug=False):
|
||||
"""Internal. Set response.body to the given file object, perhaps ranged."""
|
||||
"""Set ``response.body`` to the given file object, perhaps ranged.
|
||||
|
||||
Internal helper.
|
||||
"""
|
||||
response = cherrypy.serving.response
|
||||
|
||||
# HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code
|
||||
|
||||
@@ -31,7 +31,6 @@ _module__file__base = os.getcwd()
|
||||
|
||||
|
||||
class SimplePlugin(object):
|
||||
|
||||
"""Plugin base class which auto-subscribes methods for known channels."""
|
||||
|
||||
bus = None
|
||||
@@ -59,7 +58,6 @@ class SimplePlugin(object):
|
||||
|
||||
|
||||
class SignalHandler(object):
|
||||
|
||||
"""Register bus channels (and listeners) for system signals.
|
||||
|
||||
You can modify what signals your application listens for, and what it does
|
||||
@@ -171,8 +169,8 @@ class SignalHandler(object):
|
||||
If the optional 'listener' argument is provided, it will be
|
||||
subscribed as a listener for the given signal's channel.
|
||||
|
||||
If the given signal name or number is not available on the current
|
||||
platform, ValueError is raised.
|
||||
If the given signal name or number is not available on the
|
||||
current platform, ValueError is raised.
|
||||
"""
|
||||
if isinstance(signal, text_or_bytes):
|
||||
signum = getattr(_signal, signal, None)
|
||||
@@ -218,11 +216,10 @@ except ImportError:
|
||||
|
||||
|
||||
class DropPrivileges(SimplePlugin):
|
||||
|
||||
"""Drop privileges. uid/gid arguments not available on Windows.
|
||||
|
||||
Special thanks to `Gavin Baker
|
||||
<http://antonym.org/2005/12/dropping-privileges-in-python.html>`_
|
||||
<http://antonym.org/2005/12/dropping-privileges-in-python.html>`_.
|
||||
"""
|
||||
|
||||
def __init__(self, bus, umask=None, uid=None, gid=None):
|
||||
@@ -234,7 +231,10 @@ class DropPrivileges(SimplePlugin):
|
||||
|
||||
@property
|
||||
def uid(self):
|
||||
"""The uid under which to run. Availability: Unix."""
|
||||
"""The uid under which to run.
|
||||
|
||||
Availability: Unix.
|
||||
"""
|
||||
return self._uid
|
||||
|
||||
@uid.setter
|
||||
@@ -250,7 +250,10 @@ class DropPrivileges(SimplePlugin):
|
||||
|
||||
@property
|
||||
def gid(self):
|
||||
"""The gid under which to run. Availability: Unix."""
|
||||
"""The gid under which to run.
|
||||
|
||||
Availability: Unix.
|
||||
"""
|
||||
return self._gid
|
||||
|
||||
@gid.setter
|
||||
@@ -332,7 +335,6 @@ class DropPrivileges(SimplePlugin):
|
||||
|
||||
|
||||
class Daemonizer(SimplePlugin):
|
||||
|
||||
"""Daemonize the running script.
|
||||
|
||||
Use this with a Web Site Process Bus via::
|
||||
@@ -423,7 +425,6 @@ class Daemonizer(SimplePlugin):
|
||||
|
||||
|
||||
class PIDFile(SimplePlugin):
|
||||
|
||||
"""Maintain a PID file via a WSPBus."""
|
||||
|
||||
def __init__(self, bus, pidfile):
|
||||
@@ -453,12 +454,11 @@ class PIDFile(SimplePlugin):
|
||||
|
||||
|
||||
class PerpetualTimer(threading.Timer):
|
||||
|
||||
"""A responsive subclass of threading.Timer whose run() method repeats.
|
||||
|
||||
Use this timer only when you really need a very interruptible timer;
|
||||
this checks its 'finished' condition up to 20 times a second, which can
|
||||
results in pretty high CPU usage
|
||||
this checks its 'finished' condition up to 20 times a second, which
|
||||
can results in pretty high CPU usage
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -483,14 +483,14 @@ class PerpetualTimer(threading.Timer):
|
||||
|
||||
|
||||
class BackgroundTask(threading.Thread):
|
||||
|
||||
"""A subclass of threading.Thread whose run() method repeats.
|
||||
|
||||
Use this class for most repeating tasks. It uses time.sleep() to wait
|
||||
for each interval, which isn't very responsive; that is, even if you call
|
||||
self.cancel(), you'll have to wait until the sleep() call finishes before
|
||||
the thread stops. To compensate, it defaults to being daemonic, which means
|
||||
it won't delay stopping the whole process.
|
||||
Use this class for most repeating tasks. It uses time.sleep() to
|
||||
wait for each interval, which isn't very responsive; that is, even
|
||||
if you call self.cancel(), you'll have to wait until the sleep()
|
||||
call finishes before the thread stops. To compensate, it defaults to
|
||||
being daemonic, which means it won't delay stopping the whole
|
||||
process.
|
||||
"""
|
||||
|
||||
def __init__(self, interval, function, args=[], kwargs={}, bus=None):
|
||||
@@ -525,7 +525,6 @@ class BackgroundTask(threading.Thread):
|
||||
|
||||
|
||||
class Monitor(SimplePlugin):
|
||||
|
||||
"""WSPBus listener to periodically run a callback in its own thread."""
|
||||
|
||||
callback = None
|
||||
@@ -582,7 +581,6 @@ class Monitor(SimplePlugin):
|
||||
|
||||
|
||||
class Autoreloader(Monitor):
|
||||
|
||||
"""Monitor which re-executes the process when files change.
|
||||
|
||||
This :ref:`plugin<plugins>` restarts the process (via :func:`os.execv`)
|
||||
@@ -699,20 +697,20 @@ class Autoreloader(Monitor):
|
||||
|
||||
|
||||
class ThreadManager(SimplePlugin):
|
||||
|
||||
"""Manager for HTTP request threads.
|
||||
|
||||
If you have control over thread creation and destruction, publish to
|
||||
the 'acquire_thread' and 'release_thread' channels (for each thread).
|
||||
This will register/unregister the current thread and publish to
|
||||
'start_thread' and 'stop_thread' listeners in the bus as needed.
|
||||
the 'acquire_thread' and 'release_thread' channels (for each
|
||||
thread). This will register/unregister the current thread and
|
||||
publish to 'start_thread' and 'stop_thread' listeners in the bus as
|
||||
needed.
|
||||
|
||||
If threads are created and destroyed by code you do not control
|
||||
(e.g., Apache), then, at the beginning of every HTTP request,
|
||||
publish to 'acquire_thread' only. You should not publish to
|
||||
'release_thread' in this case, since you do not know whether
|
||||
the thread will be re-used or not. The bus will call
|
||||
'stop_thread' listeners for you when it stops.
|
||||
'release_thread' in this case, since you do not know whether the
|
||||
thread will be re-used or not. The bus will call 'stop_thread'
|
||||
listeners for you when it stops.
|
||||
"""
|
||||
|
||||
threads = None
|
||||
|
||||
@@ -132,7 +132,6 @@ class Timeouts:
|
||||
|
||||
|
||||
class ServerAdapter(object):
|
||||
|
||||
"""Adapter for an HTTP server.
|
||||
|
||||
If you need to start more than one HTTP server (to serve on multiple
|
||||
@@ -188,9 +187,7 @@ class ServerAdapter(object):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
"""
|
||||
A description about where this server is bound.
|
||||
"""
|
||||
"""A description about where this server is bound."""
|
||||
if self.bind_addr is None:
|
||||
on_what = 'unknown interface (dynamic?)'
|
||||
elif isinstance(self.bind_addr, tuple):
|
||||
@@ -292,7 +289,6 @@ class ServerAdapter(object):
|
||||
|
||||
|
||||
class FlupCGIServer(object):
|
||||
|
||||
"""Adapter for a flup.server.cgi.WSGIServer."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -316,7 +312,6 @@ class FlupCGIServer(object):
|
||||
|
||||
|
||||
class FlupFCGIServer(object):
|
||||
|
||||
"""Adapter for a flup.server.fcgi.WSGIServer."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -362,7 +357,6 @@ class FlupFCGIServer(object):
|
||||
|
||||
|
||||
class FlupSCGIServer(object):
|
||||
|
||||
"""Adapter for a flup.server.scgi.WSGIServer."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
"""Windows service. Requires pywin32."""
|
||||
"""Windows service.
|
||||
|
||||
Requires pywin32.
|
||||
"""
|
||||
|
||||
import os
|
||||
import win32api
|
||||
@@ -11,7 +14,6 @@ from cherrypy.process import wspbus, plugins
|
||||
|
||||
|
||||
class ConsoleCtrlHandler(plugins.SimplePlugin):
|
||||
|
||||
"""A WSPBus plugin for handling Win32 console events (like Ctrl-C)."""
|
||||
|
||||
def __init__(self, bus):
|
||||
@@ -69,10 +71,10 @@ class ConsoleCtrlHandler(plugins.SimplePlugin):
|
||||
|
||||
|
||||
class Win32Bus(wspbus.Bus):
|
||||
|
||||
"""A Web Site Process Bus implementation for Win32.
|
||||
|
||||
Instead of time.sleep, this bus blocks using native win32event objects.
|
||||
Instead of time.sleep, this bus blocks using native win32event
|
||||
objects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -120,7 +122,6 @@ class Win32Bus(wspbus.Bus):
|
||||
|
||||
|
||||
class _ControlCodes(dict):
|
||||
|
||||
"""Control codes used to "signal" a service via ControlService.
|
||||
|
||||
User-defined control codes are in the range 128-255. We generally use
|
||||
@@ -152,7 +153,6 @@ def signal_child(service, command):
|
||||
|
||||
|
||||
class PyWebService(win32serviceutil.ServiceFramework):
|
||||
|
||||
"""Python Web Service."""
|
||||
|
||||
_svc_name_ = 'Python Web Service'
|
||||
|
||||
@@ -57,7 +57,6 @@ the new state.::
|
||||
| \ |
|
||||
| V V
|
||||
STARTED <-- STARTING
|
||||
|
||||
"""
|
||||
|
||||
import atexit
|
||||
@@ -65,7 +64,7 @@ import atexit
|
||||
try:
|
||||
import ctypes
|
||||
except ImportError:
|
||||
"""Google AppEngine is shipped without ctypes
|
||||
"""Google AppEngine is shipped without ctypes.
|
||||
|
||||
:seealso: http://stackoverflow.com/a/6523777/70170
|
||||
"""
|
||||
@@ -165,8 +164,8 @@ class Bus(object):
|
||||
All listeners for a given channel are guaranteed to be called even
|
||||
if others at the same channel fail. Each failure is logged, but
|
||||
execution proceeds on to the next listener. The only way to stop all
|
||||
processing from inside a listener is to raise SystemExit and stop the
|
||||
whole server.
|
||||
processing from inside a listener is to raise SystemExit and stop
|
||||
the whole server.
|
||||
"""
|
||||
|
||||
states = states
|
||||
@@ -312,8 +311,9 @@ class Bus(object):
|
||||
def restart(self):
|
||||
"""Restart the process (may close connections).
|
||||
|
||||
This method does not restart the process from the calling thread;
|
||||
instead, it stops the bus and asks the main thread to call execv.
|
||||
This method does not restart the process from the calling
|
||||
thread; instead, it stops the bus and asks the main thread to
|
||||
call execv.
|
||||
"""
|
||||
self.execv = True
|
||||
self.exit()
|
||||
@@ -327,10 +327,11 @@ class Bus(object):
|
||||
"""Wait for the EXITING state, KeyboardInterrupt or SystemExit.
|
||||
|
||||
This function is intended to be called only by the main thread.
|
||||
After waiting for the EXITING state, it also waits for all threads
|
||||
to terminate, and then calls os.execv if self.execv is True. This
|
||||
design allows another thread to call bus.restart, yet have the main
|
||||
thread perform the actual execv call (required on some platforms).
|
||||
After waiting for the EXITING state, it also waits for all
|
||||
threads to terminate, and then calls os.execv if self.execv is
|
||||
True. This design allows another thread to call bus.restart, yet
|
||||
have the main thread perform the actual execv call (required on
|
||||
some platforms).
|
||||
"""
|
||||
try:
|
||||
self.wait(states.EXITING, interval=interval, channel='main')
|
||||
@@ -379,13 +380,14 @@ class Bus(object):
|
||||
def _do_execv(self):
|
||||
"""Re-execute the current process.
|
||||
|
||||
This must be called from the main thread, because certain platforms
|
||||
(OS X) don't allow execv to be called in a child thread very well.
|
||||
This must be called from the main thread, because certain
|
||||
platforms (OS X) don't allow execv to be called in a child
|
||||
thread very well.
|
||||
"""
|
||||
try:
|
||||
args = self._get_true_argv()
|
||||
except NotImplementedError:
|
||||
"""It's probably win32 or GAE"""
|
||||
"""It's probably win32 or GAE."""
|
||||
args = [sys.executable] + self._get_interpreter_argv() + sys.argv
|
||||
|
||||
self.log('Re-spawning %s' % ' '.join(args))
|
||||
@@ -472,7 +474,7 @@ class Bus(object):
|
||||
c_ind = None
|
||||
|
||||
if is_module:
|
||||
"""It's containing `-m -m` sequence of arguments"""
|
||||
"""It's containing `-m -m` sequence of arguments."""
|
||||
if is_command and c_ind < m_ind:
|
||||
"""There's `-c -c` before `-m`"""
|
||||
raise RuntimeError(
|
||||
@@ -481,7 +483,7 @@ class Bus(object):
|
||||
# Survive module argument here
|
||||
original_module = sys.argv[0]
|
||||
if not os.access(original_module, os.R_OK):
|
||||
"""There's no such module exist"""
|
||||
"""There's no such module exist."""
|
||||
raise AttributeError(
|
||||
"{} doesn't seem to be a module "
|
||||
'accessible by current user'.format(original_module))
|
||||
@@ -489,12 +491,12 @@ class Bus(object):
|
||||
# ... and substitute it with the original module path:
|
||||
_argv.insert(m_ind, original_module)
|
||||
elif is_command:
|
||||
"""It's containing just `-c -c` sequence of arguments"""
|
||||
"""It's containing just `-c -c` sequence of arguments."""
|
||||
raise RuntimeError(
|
||||
"Cannot reconstruct command from '-c'. "
|
||||
'Ref: https://github.com/cherrypy/cherrypy/issues/1545')
|
||||
except AttributeError:
|
||||
"""It looks Py_GetArgcArgv is completely absent in some environments
|
||||
"""It looks Py_GetArgcArgv's completely absent in some environments
|
||||
|
||||
It is known, that there's no Py_GetArgcArgv in MS Windows and
|
||||
``ctypes`` module is completely absent in Google AppEngine
|
||||
@@ -512,13 +514,13 @@ class Bus(object):
|
||||
"""Prepend current working dir to PATH environment variable if needed.
|
||||
|
||||
If sys.path[0] is an empty string, the interpreter was likely
|
||||
invoked with -m and the effective path is about to change on
|
||||
re-exec. Add the current directory to $PYTHONPATH to ensure
|
||||
that the new process sees the same path.
|
||||
invoked with -m and the effective path is about to change on re-
|
||||
exec. Add the current directory to $PYTHONPATH to ensure that
|
||||
the new process sees the same path.
|
||||
|
||||
This issue cannot be addressed in the general case because
|
||||
Python cannot reliably reconstruct the
|
||||
original command line (http://bugs.python.org/issue14208).
|
||||
Python cannot reliably reconstruct the original command line (
|
||||
http://bugs.python.org/issue14208).
|
||||
|
||||
(This idea filched from tornado.autoreload)
|
||||
"""
|
||||
@@ -536,10 +538,10 @@ class Bus(object):
|
||||
"""Set the CLOEXEC flag on all open files (except stdin/out/err).
|
||||
|
||||
If self.max_cloexec_files is an integer (the default), then on
|
||||
platforms which support it, it represents the max open files setting
|
||||
for the operating system. This function will be called just before
|
||||
the process is restarted via os.execv() to prevent open files
|
||||
from persisting into the new process.
|
||||
platforms which support it, it represents the max open files
|
||||
setting for the operating system. This function will be called
|
||||
just before the process is restarted via os.execv() to prevent
|
||||
open files from persisting into the new process.
|
||||
|
||||
Set self.max_cloexec_files to 0 to disable this behavior.
|
||||
"""
|
||||
@@ -578,7 +580,10 @@ class Bus(object):
|
||||
return t
|
||||
|
||||
def log(self, msg='', level=20, traceback=False):
|
||||
"""Log the given message. Append the last traceback if requested."""
|
||||
"""Log the given message.
|
||||
|
||||
Append the last traceback if requested.
|
||||
"""
|
||||
if traceback:
|
||||
msg += '\n' + ''.join(_traceback.format_exception(*sys.exc_info()))
|
||||
self.publish('log', msg, level)
|
||||
|
||||
@@ -9,7 +9,6 @@ Even before any tweaking, this should serve a few demonstration pages.
|
||||
Change to this directory and run:
|
||||
|
||||
cherryd -c site.conf
|
||||
|
||||
"""
|
||||
|
||||
import cherrypy
|
||||
|
||||
542
lib/munkres.py
542
lib/munkres.py
@@ -1,8 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: iso-8859-1 -*-
|
||||
|
||||
# Documentation is intended to be processed by Epydoc.
|
||||
|
||||
"""
|
||||
Introduction
|
||||
============
|
||||
@@ -11,266 +6,10 @@ The Munkres module provides an implementation of the Munkres algorithm
|
||||
(also called the Hungarian algorithm or the Kuhn-Munkres algorithm),
|
||||
useful for solving the Assignment Problem.
|
||||
|
||||
Assignment Problem
|
||||
==================
|
||||
|
||||
Let *C* be an *n*\ x\ *n* matrix representing the costs of each of *n* workers
|
||||
to perform any of *n* jobs. The assignment problem is to assign jobs to
|
||||
workers in a way that minimizes the total cost. Since each worker can perform
|
||||
only one job and each job can be assigned to only one worker the assignments
|
||||
represent an independent set of the matrix *C*.
|
||||
|
||||
One way to generate the optimal set is to create all permutations of
|
||||
the indexes necessary to traverse the matrix so that no row and column
|
||||
are used more than once. For instance, given this matrix (expressed in
|
||||
Python)::
|
||||
|
||||
matrix = [[5, 9, 1],
|
||||
[10, 3, 2],
|
||||
[8, 7, 4]]
|
||||
|
||||
You could use this code to generate the traversal indexes::
|
||||
|
||||
def permute(a, results):
|
||||
if len(a) == 1:
|
||||
results.insert(len(results), a)
|
||||
|
||||
else:
|
||||
for i in range(0, len(a)):
|
||||
element = a[i]
|
||||
a_copy = [a[j] for j in range(0, len(a)) if j != i]
|
||||
subresults = []
|
||||
permute(a_copy, subresults)
|
||||
for subresult in subresults:
|
||||
result = [element] + subresult
|
||||
results.insert(len(results), result)
|
||||
|
||||
results = []
|
||||
permute(range(len(matrix)), results) # [0, 1, 2] for a 3x3 matrix
|
||||
|
||||
After the call to permute(), the results matrix would look like this::
|
||||
|
||||
[[0, 1, 2],
|
||||
[0, 2, 1],
|
||||
[1, 0, 2],
|
||||
[1, 2, 0],
|
||||
[2, 0, 1],
|
||||
[2, 1, 0]]
|
||||
|
||||
You could then use that index matrix to loop over the original cost matrix
|
||||
and calculate the smallest cost of the combinations::
|
||||
|
||||
n = len(matrix)
|
||||
minval = sys.maxsize
|
||||
for row in range(n):
|
||||
cost = 0
|
||||
for col in range(n):
|
||||
cost += matrix[row][col]
|
||||
minval = min(cost, minval)
|
||||
|
||||
print minval
|
||||
|
||||
While this approach works fine for small matrices, it does not scale. It
|
||||
executes in O(*n*!) time: Calculating the permutations for an *n*\ x\ *n*
|
||||
matrix requires *n*! operations. For a 12x12 matrix, that's 479,001,600
|
||||
traversals. Even if you could manage to perform each traversal in just one
|
||||
millisecond, it would still take more than 133 hours to perform the entire
|
||||
traversal. A 20x20 matrix would take 2,432,902,008,176,640,000 operations. At
|
||||
an optimistic millisecond per operation, that's more than 77 million years.
|
||||
|
||||
The Munkres algorithm runs in O(*n*\ ^3) time, rather than O(*n*!). This
|
||||
package provides an implementation of that algorithm.
|
||||
|
||||
This version is based on
|
||||
http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html.
|
||||
|
||||
This version was written for Python by Brian Clapper from the (Ada) algorithm
|
||||
at the above web site. (The ``Algorithm::Munkres`` Perl version, in CPAN, was
|
||||
clearly adapted from the same web site.)
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
Construct a Munkres object::
|
||||
|
||||
from munkres import Munkres
|
||||
|
||||
m = Munkres()
|
||||
|
||||
Then use it to compute the lowest cost assignment from a cost matrix. Here's
|
||||
a sample program::
|
||||
|
||||
from munkres import Munkres, print_matrix
|
||||
|
||||
matrix = [[5, 9, 1],
|
||||
[10, 3, 2],
|
||||
[8, 7, 4]]
|
||||
m = Munkres()
|
||||
indexes = m.compute(matrix)
|
||||
print_matrix(matrix, msg='Lowest cost through this matrix:')
|
||||
total = 0
|
||||
for row, column in indexes:
|
||||
value = matrix[row][column]
|
||||
total += value
|
||||
print '(%d, %d) -> %d' % (row, column, value)
|
||||
print 'total cost: %d' % total
|
||||
|
||||
Running that program produces::
|
||||
|
||||
Lowest cost through this matrix:
|
||||
[5, 9, 1]
|
||||
[10, 3, 2]
|
||||
[8, 7, 4]
|
||||
(0, 0) -> 5
|
||||
(1, 1) -> 3
|
||||
(2, 2) -> 4
|
||||
total cost=12
|
||||
|
||||
The instantiated Munkres object can be used multiple times on different
|
||||
matrices.
|
||||
|
||||
Non-square Cost Matrices
|
||||
========================
|
||||
|
||||
The Munkres algorithm assumes that the cost matrix is square. However, it's
|
||||
possible to use a rectangular matrix if you first pad it with 0 values to make
|
||||
it square. This module automatically pads rectangular cost matrices to make
|
||||
them square.
|
||||
|
||||
Notes:
|
||||
|
||||
- The module operates on a *copy* of the caller's matrix, so any padding will
|
||||
not be seen by the caller.
|
||||
- The cost matrix must be rectangular or square. An irregular matrix will
|
||||
*not* work.
|
||||
|
||||
Calculating Profit, Rather than Cost
|
||||
====================================
|
||||
|
||||
The cost matrix is just that: A cost matrix. The Munkres algorithm finds
|
||||
the combination of elements (one from each row and column) that results in
|
||||
the smallest cost. It's also possible to use the algorithm to maximize
|
||||
profit. To do that, however, you have to convert your profit matrix to a
|
||||
cost matrix. The simplest way to do that is to subtract all elements from a
|
||||
large value. For example::
|
||||
|
||||
from munkres import Munkres, print_matrix
|
||||
|
||||
matrix = [[5, 9, 1],
|
||||
[10, 3, 2],
|
||||
[8, 7, 4]]
|
||||
cost_matrix = []
|
||||
for row in matrix:
|
||||
cost_row = []
|
||||
for col in row:
|
||||
cost_row += [sys.maxsize - col]
|
||||
cost_matrix += [cost_row]
|
||||
|
||||
m = Munkres()
|
||||
indexes = m.compute(cost_matrix)
|
||||
print_matrix(matrix, msg='Highest profit through this matrix:')
|
||||
total = 0
|
||||
for row, column in indexes:
|
||||
value = matrix[row][column]
|
||||
total += value
|
||||
print '(%d, %d) -> %d' % (row, column, value)
|
||||
|
||||
print 'total profit=%d' % total
|
||||
|
||||
Running that program produces::
|
||||
|
||||
Highest profit through this matrix:
|
||||
[5, 9, 1]
|
||||
[10, 3, 2]
|
||||
[8, 7, 4]
|
||||
(0, 1) -> 9
|
||||
(1, 0) -> 10
|
||||
(2, 2) -> 4
|
||||
total profit=23
|
||||
|
||||
The ``munkres`` module provides a convenience method for creating a cost
|
||||
matrix from a profit matrix. Since it doesn't know whether the matrix contains
|
||||
floating point numbers, decimals, or integers, you have to provide the
|
||||
conversion function; but the convenience method takes care of the actual
|
||||
creation of the cost matrix::
|
||||
|
||||
import munkres
|
||||
|
||||
cost_matrix = munkres.make_cost_matrix(matrix,
|
||||
lambda cost: sys.maxsize - cost)
|
||||
|
||||
So, the above profit-calculation program can be recast as::
|
||||
|
||||
from munkres import Munkres, print_matrix, make_cost_matrix
|
||||
|
||||
matrix = [[5, 9, 1],
|
||||
[10, 3, 2],
|
||||
[8, 7, 4]]
|
||||
cost_matrix = make_cost_matrix(matrix, lambda cost: sys.maxsize - cost)
|
||||
m = Munkres()
|
||||
indexes = m.compute(cost_matrix)
|
||||
print_matrix(matrix, msg='Lowest cost through this matrix:')
|
||||
total = 0
|
||||
for row, column in indexes:
|
||||
value = matrix[row][column]
|
||||
total += value
|
||||
print '(%d, %d) -> %d' % (row, column, value)
|
||||
print 'total profit=%d' % total
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
1. http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html
|
||||
|
||||
2. Harold W. Kuhn. The Hungarian Method for the assignment problem.
|
||||
*Naval Research Logistics Quarterly*, 2:83-97, 1955.
|
||||
|
||||
3. Harold W. Kuhn. Variants of the Hungarian method for assignment
|
||||
problems. *Naval Research Logistics Quarterly*, 3: 253-258, 1956.
|
||||
|
||||
4. Munkres, J. Algorithms for the Assignment and Transportation Problems.
|
||||
*Journal of the Society of Industrial and Applied Mathematics*,
|
||||
5(1):32-38, March, 1957.
|
||||
|
||||
5. http://en.wikipedia.org/wiki/Hungarian_algorithm
|
||||
|
||||
Copyright and License
|
||||
=====================
|
||||
|
||||
This software is released under a BSD license, adapted from
|
||||
<http://opensource.org/licenses/bsd-license.php>
|
||||
|
||||
Copyright (c) 2008 Brian M. Clapper
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name "clapper.org" nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
||||
For complete usage documentation, see: https://software.clapper.org/munkres/
|
||||
"""
|
||||
|
||||
__docformat__ = 'restructuredtext'
|
||||
__docformat__ = 'markdown'
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Imports
|
||||
@@ -278,23 +17,43 @@ __docformat__ = 'restructuredtext'
|
||||
|
||||
import sys
|
||||
import copy
|
||||
from typing import Union, NewType, Sequence, Tuple, Optional, Callable
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exports
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
__all__ = ['Munkres', 'make_cost_matrix']
|
||||
__all__ = ['Munkres', 'make_cost_matrix', 'DISALLOWED']
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Globals
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AnyNum = NewType('AnyNum', Union[int, float])
|
||||
Matrix = NewType('Matrix', Sequence[Sequence[AnyNum]])
|
||||
|
||||
# Info about the module
|
||||
__version__ = "1.0.6"
|
||||
__version__ = "1.1.4"
|
||||
__author__ = "Brian Clapper, bmc@clapper.org"
|
||||
__url__ = "http://software.clapper.org/munkres/"
|
||||
__copyright__ = "(c) 2008 Brian M. Clapper"
|
||||
__license__ = "BSD-style license"
|
||||
__url__ = "https://software.clapper.org/munkres/"
|
||||
__copyright__ = "(c) 2008-2020 Brian M. Clapper"
|
||||
__license__ = "Apache Software License"
|
||||
|
||||
# Constants
|
||||
class DISALLOWED_OBJ(object):
|
||||
pass
|
||||
DISALLOWED = DISALLOWED_OBJ()
|
||||
DISALLOWED_PRINTVAL = "D"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class UnsolvableMatrix(Exception):
|
||||
"""
|
||||
Exception raised for unsolvable matrices
|
||||
"""
|
||||
pass
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Classes
|
||||
@@ -317,30 +76,18 @@ class Munkres:
|
||||
self.marked = None
|
||||
self.path = None
|
||||
|
||||
def make_cost_matrix(profit_matrix, inversion_function):
|
||||
"""
|
||||
**DEPRECATED**
|
||||
|
||||
Please use the module function ``make_cost_matrix()``.
|
||||
"""
|
||||
import munkres
|
||||
return munkres.make_cost_matrix(profit_matrix, inversion_function)
|
||||
|
||||
make_cost_matrix = staticmethod(make_cost_matrix)
|
||||
|
||||
def pad_matrix(self, matrix, pad_value=0):
|
||||
def pad_matrix(self, matrix: Matrix, pad_value: int=0) -> Matrix:
|
||||
"""
|
||||
Pad a possibly non-square matrix to make it square.
|
||||
|
||||
:Parameters:
|
||||
matrix : list of lists
|
||||
matrix to pad
|
||||
**Parameters**
|
||||
|
||||
pad_value : int
|
||||
value to use to pad the matrix
|
||||
- `matrix` (list of lists of numbers): matrix to pad
|
||||
- `pad_value` (`int`): value to use to pad the matrix
|
||||
|
||||
:rtype: list of lists
|
||||
:return: a new, possibly padded, matrix
|
||||
**Returns**
|
||||
|
||||
a new, possibly padded, matrix
|
||||
"""
|
||||
max_columns = 0
|
||||
total_rows = len(matrix)
|
||||
@@ -356,34 +103,35 @@ class Munkres:
|
||||
new_row = row[:]
|
||||
if total_rows > row_len:
|
||||
# Row too short. Pad it.
|
||||
new_row += [0] * (total_rows - row_len)
|
||||
new_row += [pad_value] * (total_rows - row_len)
|
||||
new_matrix += [new_row]
|
||||
|
||||
while len(new_matrix) < total_rows:
|
||||
new_matrix += [[0] * total_rows]
|
||||
new_matrix += [[pad_value] * total_rows]
|
||||
|
||||
return new_matrix
|
||||
|
||||
def compute(self, cost_matrix):
|
||||
def compute(self, cost_matrix: Matrix) -> Sequence[Tuple[int, int]]:
|
||||
"""
|
||||
Compute the indexes for the lowest-cost pairings between rows and
|
||||
columns in the database. Returns a list of (row, column) tuples
|
||||
columns in the database. Returns a list of `(row, column)` tuples
|
||||
that can be used to traverse the matrix.
|
||||
|
||||
:Parameters:
|
||||
cost_matrix : list of lists
|
||||
The cost matrix. If this cost matrix is not square, it
|
||||
will be padded with zeros, via a call to ``pad_matrix()``.
|
||||
(This method does *not* modify the caller's matrix. It
|
||||
operates on a copy of the matrix.)
|
||||
**WARNING**: This code handles square and rectangular matrices. It
|
||||
does *not* handle irregular matrices.
|
||||
|
||||
**WARNING**: This code handles square and rectangular
|
||||
matrices. It does *not* handle irregular matrices.
|
||||
**Parameters**
|
||||
|
||||
:rtype: list
|
||||
:return: A list of ``(row, column)`` tuples that describe the lowest
|
||||
cost path through the matrix
|
||||
- `cost_matrix` (list of lists of numbers): The cost matrix. If this
|
||||
cost matrix is not square, it will be padded with zeros, via a call
|
||||
to `pad_matrix()`. (This method does *not* modify the caller's
|
||||
matrix. It operates on a copy of the matrix.)
|
||||
|
||||
|
||||
**Returns**
|
||||
|
||||
A list of `(row, column)` tuples that describe the lowest cost path
|
||||
through the matrix
|
||||
"""
|
||||
self.C = self.pad_matrix(cost_matrix)
|
||||
self.n = len(self.C)
|
||||
@@ -422,18 +170,18 @@ class Munkres:
|
||||
|
||||
return results
|
||||
|
||||
def __copy_matrix(self, matrix):
|
||||
def __copy_matrix(self, matrix: Matrix) -> Matrix:
|
||||
"""Return an exact copy of the supplied matrix"""
|
||||
return copy.deepcopy(matrix)
|
||||
|
||||
def __make_matrix(self, n, val):
|
||||
def __make_matrix(self, n: int, val: AnyNum) -> Matrix:
|
||||
"""Create an *n*x*n* matrix, populating it with the specific value."""
|
||||
matrix = []
|
||||
for i in range(n):
|
||||
matrix += [[val for j in range(n)]]
|
||||
return matrix
|
||||
|
||||
def __step1(self):
|
||||
def __step1(self) -> int:
|
||||
"""
|
||||
For each row of the matrix, find the smallest element and
|
||||
subtract it from every element in its row. Go to Step 2.
|
||||
@@ -441,15 +189,22 @@ class Munkres:
|
||||
C = self.C
|
||||
n = self.n
|
||||
for i in range(n):
|
||||
minval = min(self.C[i])
|
||||
vals = [x for x in self.C[i] if x is not DISALLOWED]
|
||||
if len(vals) == 0:
|
||||
# All values in this row are DISALLOWED. This matrix is
|
||||
# unsolvable.
|
||||
raise UnsolvableMatrix(
|
||||
"Row {0} is entirely DISALLOWED.".format(i)
|
||||
)
|
||||
minval = min(vals)
|
||||
# Find the minimum value for this row and subtract that minimum
|
||||
# from every element in the row.
|
||||
for j in range(n):
|
||||
self.C[i][j] -= minval
|
||||
|
||||
if self.C[i][j] is not DISALLOWED:
|
||||
self.C[i][j] -= minval
|
||||
return 2
|
||||
|
||||
def __step2(self):
|
||||
def __step2(self) -> int:
|
||||
"""
|
||||
Find a zero (Z) in the resulting matrix. If there is no starred
|
||||
zero in its row or column, star Z. Repeat for each element in the
|
||||
@@ -464,11 +219,12 @@ class Munkres:
|
||||
self.marked[i][j] = 1
|
||||
self.col_covered[j] = True
|
||||
self.row_covered[i] = True
|
||||
break
|
||||
|
||||
self.__clear_covers()
|
||||
return 3
|
||||
|
||||
def __step3(self):
|
||||
def __step3(self) -> int:
|
||||
"""
|
||||
Cover each column containing a starred zero. If K columns are
|
||||
covered, the starred zeros describe a complete set of unique
|
||||
@@ -478,7 +234,7 @@ class Munkres:
|
||||
count = 0
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
if self.marked[i][j] == 1:
|
||||
if self.marked[i][j] == 1 and not self.col_covered[j]:
|
||||
self.col_covered[j] = True
|
||||
count += 1
|
||||
|
||||
@@ -489,7 +245,7 @@ class Munkres:
|
||||
|
||||
return step
|
||||
|
||||
def __step4(self):
|
||||
def __step4(self) -> int:
|
||||
"""
|
||||
Find a noncovered zero and prime it. If there is no starred zero
|
||||
in the row containing this primed zero, Go to Step 5. Otherwise,
|
||||
@@ -499,11 +255,11 @@ class Munkres:
|
||||
"""
|
||||
step = 0
|
||||
done = False
|
||||
row = -1
|
||||
col = -1
|
||||
row = 0
|
||||
col = 0
|
||||
star_col = -1
|
||||
while not done:
|
||||
(row, col) = self.__find_a_zero()
|
||||
(row, col) = self.__find_a_zero(row, col)
|
||||
if row < 0:
|
||||
done = True
|
||||
step = 6
|
||||
@@ -522,7 +278,7 @@ class Munkres:
|
||||
|
||||
return step
|
||||
|
||||
def __step5(self):
|
||||
def __step5(self) -> int:
|
||||
"""
|
||||
Construct a series of alternating primed and starred zeros as
|
||||
follows. Let Z0 represent the uncovered primed zero found in Step 4.
|
||||
@@ -558,7 +314,7 @@ class Munkres:
|
||||
self.__erase_primes()
|
||||
return 3
|
||||
|
||||
def __step6(self):
|
||||
def __step6(self) -> int:
|
||||
"""
|
||||
Add the value found in Step 4 to every element of each covered
|
||||
row, and subtract it from every element of each uncovered column.
|
||||
@@ -566,34 +322,44 @@ class Munkres:
|
||||
lines.
|
||||
"""
|
||||
minval = self.__find_smallest()
|
||||
events = 0 # track actual changes to matrix
|
||||
for i in range(self.n):
|
||||
for j in range(self.n):
|
||||
if self.C[i][j] is DISALLOWED:
|
||||
continue
|
||||
if self.row_covered[i]:
|
||||
self.C[i][j] += minval
|
||||
events += 1
|
||||
if not self.col_covered[j]:
|
||||
self.C[i][j] -= minval
|
||||
events += 1
|
||||
if self.row_covered[i] and not self.col_covered[j]:
|
||||
events -= 2 # change reversed, no real difference
|
||||
if (events == 0):
|
||||
raise UnsolvableMatrix("Matrix cannot be solved!")
|
||||
return 4
|
||||
|
||||
def __find_smallest(self):
|
||||
def __find_smallest(self) -> AnyNum:
|
||||
"""Find the smallest uncovered value in the matrix."""
|
||||
minval = sys.maxsize
|
||||
for i in range(self.n):
|
||||
for j in range(self.n):
|
||||
if (not self.row_covered[i]) and (not self.col_covered[j]):
|
||||
if minval > self.C[i][j]:
|
||||
if self.C[i][j] is not DISALLOWED and minval > self.C[i][j]:
|
||||
minval = self.C[i][j]
|
||||
return minval
|
||||
|
||||
def __find_a_zero(self):
|
||||
|
||||
def __find_a_zero(self, i0: int = 0, j0: int = 0) -> Tuple[int, int]:
|
||||
"""Find the first uncovered element with value 0"""
|
||||
row = -1
|
||||
col = -1
|
||||
i = 0
|
||||
i = i0
|
||||
n = self.n
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
j = 0
|
||||
j = j0
|
||||
while True:
|
||||
if (self.C[i][j] == 0) and \
|
||||
(not self.row_covered[i]) and \
|
||||
@@ -601,16 +367,16 @@ class Munkres:
|
||||
row = i
|
||||
col = j
|
||||
done = True
|
||||
j += 1
|
||||
if j >= n:
|
||||
j = (j + 1) % n
|
||||
if j == j0:
|
||||
break
|
||||
i += 1
|
||||
if i >= n:
|
||||
i = (i + 1) % n
|
||||
if i == i0:
|
||||
done = True
|
||||
|
||||
return (row, col)
|
||||
|
||||
def __find_star_in_row(self, row):
|
||||
def __find_star_in_row(self, row: Sequence[AnyNum]) -> int:
|
||||
"""
|
||||
Find the first starred element in the specified row. Returns
|
||||
the column index, or -1 if no starred element was found.
|
||||
@@ -623,7 +389,7 @@ class Munkres:
|
||||
|
||||
return col
|
||||
|
||||
def __find_star_in_col(self, col):
|
||||
def __find_star_in_col(self, col: Sequence[AnyNum]) -> int:
|
||||
"""
|
||||
Find the first starred element in the specified row. Returns
|
||||
the row index, or -1 if no starred element was found.
|
||||
@@ -636,7 +402,7 @@ class Munkres:
|
||||
|
||||
return row
|
||||
|
||||
def __find_prime_in_row(self, row):
|
||||
def __find_prime_in_row(self, row) -> int:
|
||||
"""
|
||||
Find the first prime element in the specified row. Returns
|
||||
the column index, or -1 if no starred element was found.
|
||||
@@ -649,20 +415,22 @@ class Munkres:
|
||||
|
||||
return col
|
||||
|
||||
def __convert_path(self, path, count):
|
||||
def __convert_path(self,
|
||||
path: Sequence[Sequence[int]],
|
||||
count: int) -> None:
|
||||
for i in range(count+1):
|
||||
if self.marked[path[i][0]][path[i][1]] == 1:
|
||||
self.marked[path[i][0]][path[i][1]] = 0
|
||||
else:
|
||||
self.marked[path[i][0]][path[i][1]] = 1
|
||||
|
||||
def __clear_covers(self):
|
||||
def __clear_covers(self) -> None:
|
||||
"""Clear all covered matrix cells"""
|
||||
for i in range(self.n):
|
||||
self.row_covered[i] = False
|
||||
self.col_covered[i] = False
|
||||
|
||||
def __erase_primes(self):
|
||||
def __erase_primes(self) -> None:
|
||||
"""Erase all prime markings"""
|
||||
for i in range(self.n):
|
||||
for j in range(self.n):
|
||||
@@ -673,51 +441,56 @@ class Munkres:
|
||||
# Functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_cost_matrix(profit_matrix, inversion_function):
|
||||
def make_cost_matrix(
|
||||
profit_matrix: Matrix,
|
||||
inversion_function: Optional[Callable[[AnyNum], AnyNum]] = None
|
||||
) -> Matrix:
|
||||
"""
|
||||
Create a cost matrix from a profit matrix by calling
|
||||
'inversion_function' to invert each value. The inversion
|
||||
function must take one numeric argument (of any type) and return
|
||||
another numeric argument which is presumed to be the cost inverse
|
||||
of the original profit.
|
||||
Create a cost matrix from a profit matrix by calling `inversion_function()`
|
||||
to invert each value. The inversion function must take one numeric argument
|
||||
(of any type) and return another numeric argument which is presumed to be
|
||||
the cost inverse of the original profit value. If the inversion function
|
||||
is not provided, a given cell's inverted value is calculated as
|
||||
`max(matrix) - value`.
|
||||
|
||||
This is a static method. Call it like this:
|
||||
|
||||
.. python::
|
||||
|
||||
from munkres import Munkres
|
||||
cost_matrix = Munkres.make_cost_matrix(matrix, inversion_func)
|
||||
|
||||
For example:
|
||||
|
||||
.. python::
|
||||
|
||||
from munkres import Munkres
|
||||
cost_matrix = Munkres.make_cost_matrix(matrix, lambda x : sys.maxsize - x)
|
||||
|
||||
:Parameters:
|
||||
profit_matrix : list of lists
|
||||
The matrix to convert from a profit to a cost matrix
|
||||
**Parameters**
|
||||
|
||||
inversion_function : function
|
||||
The function to use to invert each entry in the profit matrix
|
||||
- `profit_matrix` (list of lists of numbers): The matrix to convert from
|
||||
profit to cost values.
|
||||
- `inversion_function` (`function`): The function to use to invert each
|
||||
entry in the profit matrix.
|
||||
|
||||
:rtype: list of lists
|
||||
:return: The converted matrix
|
||||
**Returns**
|
||||
|
||||
A new matrix representing the inversion of `profix_matrix`.
|
||||
"""
|
||||
if not inversion_function:
|
||||
maximum = max(max(row) for row in profit_matrix)
|
||||
inversion_function = lambda x: maximum - x
|
||||
|
||||
cost_matrix = []
|
||||
for row in profit_matrix:
|
||||
cost_matrix.append([inversion_function(value) for value in row])
|
||||
return cost_matrix
|
||||
|
||||
def print_matrix(matrix, msg=None):
|
||||
def print_matrix(matrix: Matrix, msg: Optional[str] = None) -> None:
|
||||
"""
|
||||
Convenience function: Displays the contents of a matrix of integers.
|
||||
Convenience function: Displays the contents of a matrix.
|
||||
|
||||
:Parameters:
|
||||
matrix : list of lists
|
||||
Matrix to print
|
||||
**Parameters**
|
||||
|
||||
msg : str
|
||||
Optional message to print before displaying the matrix
|
||||
- `matrix` (list of lists of numbers): The matrix to print
|
||||
- `msg` (`str`): Optional message to print before displaying the matrix
|
||||
"""
|
||||
import math
|
||||
|
||||
@@ -728,16 +501,21 @@ def print_matrix(matrix, msg=None):
|
||||
width = 0
|
||||
for row in matrix:
|
||||
for val in row:
|
||||
width = max(width, int(math.log10(val)) + 1)
|
||||
if val is DISALLOWED:
|
||||
val = DISALLOWED_PRINTVAL
|
||||
width = max(width, len(str(val)))
|
||||
|
||||
# Make the format string
|
||||
format = '%%%dd' % width
|
||||
format = ('%%%d' % width)
|
||||
|
||||
# Print the matrix
|
||||
for row in matrix:
|
||||
sep = '['
|
||||
for val in row:
|
||||
sys.stdout.write(sep + format % val)
|
||||
if val is DISALLOWED:
|
||||
val = DISALLOWED_PRINTVAL
|
||||
formatted = ((format + 's') % val)
|
||||
sys.stdout.write(sep + formatted)
|
||||
sep = ', '
|
||||
sys.stdout.write(']\n')
|
||||
|
||||
@@ -767,11 +545,51 @@ if __name__ == '__main__':
|
||||
[9, 7, 4]],
|
||||
18),
|
||||
|
||||
# Square variant with floating point value
|
||||
([[10.1, 10.2, 8.3],
|
||||
[9.4, 8.5, 1.6],
|
||||
[9.7, 7.8, 4.9]],
|
||||
19.5),
|
||||
|
||||
# Rectangular variant
|
||||
([[10, 10, 8, 11],
|
||||
[9, 8, 1, 1],
|
||||
[9, 7, 4, 10]],
|
||||
15)]
|
||||
15),
|
||||
|
||||
# Rectangular variant with floating point value
|
||||
([[10.01, 10.02, 8.03, 11.04],
|
||||
[9.05, 8.06, 1.07, 1.08],
|
||||
[9.09, 7.1, 4.11, 10.12]],
|
||||
15.2),
|
||||
|
||||
# Rectangular with DISALLOWED
|
||||
([[4, 5, 6, DISALLOWED],
|
||||
[1, 9, 12, 11],
|
||||
[DISALLOWED, 5, 4, DISALLOWED],
|
||||
[12, 12, 12, 10]],
|
||||
20),
|
||||
|
||||
# Rectangular variant with DISALLOWED and floating point value
|
||||
([[4.001, 5.002, 6.003, DISALLOWED],
|
||||
[1.004, 9.005, 12.006, 11.007],
|
||||
[DISALLOWED, 5.008, 4.009, DISALLOWED],
|
||||
[12.01, 12.011, 12.012, 10.013]],
|
||||
20.028),
|
||||
|
||||
# DISALLOWED to force pairings
|
||||
([[1, DISALLOWED, DISALLOWED, DISALLOWED],
|
||||
[DISALLOWED, 2, DISALLOWED, DISALLOWED],
|
||||
[DISALLOWED, DISALLOWED, 3, DISALLOWED],
|
||||
[DISALLOWED, DISALLOWED, DISALLOWED, 4]],
|
||||
10),
|
||||
|
||||
# DISALLOWED to force pairings with floating point value
|
||||
([[1.1, DISALLOWED, DISALLOWED, DISALLOWED],
|
||||
[DISALLOWED, 2.2, DISALLOWED, DISALLOWED],
|
||||
[DISALLOWED, DISALLOWED, 3.3, DISALLOWED],
|
||||
[DISALLOWED, DISALLOWED, DISALLOWED, 4.4]],
|
||||
11.0)]
|
||||
|
||||
m = Munkres()
|
||||
for cost_matrix, expected_total in matrices:
|
||||
@@ -781,6 +599,6 @@ if __name__ == '__main__':
|
||||
for r, c in indexes:
|
||||
x = cost_matrix[r][c]
|
||||
total_cost += x
|
||||
print(('(%d, %d) -> %d' % (r, c, x)))
|
||||
print(('lowest cost=%d' % total_cost))
|
||||
print(('(%d, %d) -> %s' % (r, c, x)))
|
||||
print(('lowest cost=%s' % total_cost))
|
||||
assert expected_total == total_cost
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,608 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2005-2010 ActiveState Software Inc.
|
||||
# Copyright (c) 2013 Eddy Petrișor
|
||||
|
||||
"""Utilities for determining application-specific dirs.
|
||||
|
||||
See <http://github.com/ActiveState/appdirs> for details and usage.
|
||||
"""
|
||||
# Dev Notes:
|
||||
# - MSDN on where to store app data files:
|
||||
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
|
||||
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
|
||||
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
|
||||
|
||||
__version_info__ = (1, 4, 3)
|
||||
__version__ = '.'.join(map(str, __version_info__))
|
||||
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
unicode = str
|
||||
|
||||
if sys.platform.startswith('java'):
|
||||
import platform
|
||||
os_name = platform.java_ver()[3][0]
|
||||
if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc.
|
||||
system = 'win32'
|
||||
elif os_name.startswith('Mac'): # "Mac OS X", etc.
|
||||
system = 'darwin'
|
||||
else: # "Linux", "SunOS", "FreeBSD", etc.
|
||||
# Setting this to "linux2" is not ideal, but only Windows or Mac
|
||||
# are actually checked for and the rest of the module expects
|
||||
# *sys.platform* style strings.
|
||||
system = 'linux2'
|
||||
else:
|
||||
system = sys.platform
|
||||
|
||||
|
||||
|
||||
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user data directories are:
|
||||
Mac OS X: ~/Library/Application Support/<AppName>
|
||||
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
|
||||
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
|
||||
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
|
||||
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
|
||||
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
|
||||
|
||||
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
|
||||
That means, by default "~/.local/share/<AppName>".
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
|
||||
path = os.path.normpath(_get_win_folder(const))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('~/Library/Application Support/')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
|
||||
r"""Return full path to the user-shared data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"multipath" is an optional parameter only applicable to *nix
|
||||
which indicates that the entire list of data dirs should be
|
||||
returned. By default, the first item from XDG_DATA_DIRS is
|
||||
returned, or '/usr/local/share/<AppName>',
|
||||
if XDG_DATA_DIRS is not set
|
||||
|
||||
Typical site data directories are:
|
||||
Mac OS X: /Library/Application Support/<AppName>
|
||||
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
|
||||
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
|
||||
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
||||
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
|
||||
|
||||
For Unix, this is using the $XDG_DATA_DIRS[0] default.
|
||||
|
||||
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('/Library/Application Support')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
# XDG default for $XDG_DATA_DIRS
|
||||
# only first, if multipath is False
|
||||
path = os.getenv('XDG_DATA_DIRS',
|
||||
os.pathsep.join(['/usr/local/share', '/usr/share']))
|
||||
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
|
||||
if appname:
|
||||
if version:
|
||||
appname = os.path.join(appname, version)
|
||||
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
||||
|
||||
if multipath:
|
||||
path = os.pathsep.join(pathlist)
|
||||
else:
|
||||
path = pathlist[0]
|
||||
return path
|
||||
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific config dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user config directories are:
|
||||
Mac OS X: same as user_data_dir
|
||||
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
|
||||
Win *: same as user_data_dir
|
||||
|
||||
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
|
||||
That means, by default "~/.config/<AppName>".
|
||||
"""
|
||||
if system in ["win32", "darwin"]:
|
||||
path = user_data_dir(appname, appauthor, None, roaming)
|
||||
else:
|
||||
path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
|
||||
r"""Return full path to the user-shared data dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"multipath" is an optional parameter only applicable to *nix
|
||||
which indicates that the entire list of config dirs should be
|
||||
returned. By default, the first item from XDG_CONFIG_DIRS is
|
||||
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
|
||||
|
||||
Typical site config directories are:
|
||||
Mac OS X: same as site_data_dir
|
||||
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
|
||||
$XDG_CONFIG_DIRS
|
||||
Win *: same as site_data_dir
|
||||
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
|
||||
|
||||
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
|
||||
|
||||
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
|
||||
"""
|
||||
if system in ["win32", "darwin"]:
|
||||
path = site_data_dir(appname, appauthor)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
else:
|
||||
# XDG default for $XDG_CONFIG_DIRS
|
||||
# only first, if multipath is False
|
||||
path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
|
||||
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
|
||||
if appname:
|
||||
if version:
|
||||
appname = os.path.join(appname, version)
|
||||
pathlist = [os.sep.join([x, appname]) for x in pathlist]
|
||||
|
||||
if multipath:
|
||||
path = os.pathsep.join(pathlist)
|
||||
else:
|
||||
path = pathlist[0]
|
||||
return path
|
||||
|
||||
|
||||
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
|
||||
r"""Return full path to the user-specific cache dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"opinion" (boolean) can be False to disable the appending of
|
||||
"Cache" to the base app data dir for Windows. See
|
||||
discussion below.
|
||||
|
||||
Typical user cache directories are:
|
||||
Mac OS X: ~/Library/Caches/<AppName>
|
||||
Unix: ~/.cache/<AppName> (XDG default)
|
||||
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
|
||||
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
|
||||
|
||||
On Windows the only suggestion in the MSDN docs is that local settings go in
|
||||
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
|
||||
app data dir (the default returned by `user_data_dir` above). Apps typically
|
||||
put cache data somewhere *under* the given dir here. Some examples:
|
||||
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
|
||||
...\Acme\SuperApp\Cache\1.0
|
||||
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
|
||||
This can be disabled with the `opinion=False` option.
|
||||
"""
|
||||
if system == "win32":
|
||||
if appauthor is None:
|
||||
appauthor = appname
|
||||
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
|
||||
if appname:
|
||||
if appauthor is not False:
|
||||
path = os.path.join(path, appauthor, appname)
|
||||
else:
|
||||
path = os.path.join(path, appname)
|
||||
if opinion:
|
||||
path = os.path.join(path, "Cache")
|
||||
elif system == 'darwin':
|
||||
path = os.path.expanduser('~/Library/Caches')
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
else:
|
||||
path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
|
||||
r"""Return full path to the user-specific state dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"roaming" (boolean, default False) can be set True to use the Windows
|
||||
roaming appdata directory. That means that for users on a Windows
|
||||
network setup for roaming profiles, this user data will be
|
||||
sync'd on login. See
|
||||
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
|
||||
for a discussion of issues.
|
||||
|
||||
Typical user state directories are:
|
||||
Mac OS X: same as user_data_dir
|
||||
Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
|
||||
Win *: same as user_data_dir
|
||||
|
||||
For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
|
||||
to extend the XDG spec and support $XDG_STATE_HOME.
|
||||
|
||||
That means, by default "~/.local/state/<AppName>".
|
||||
"""
|
||||
if system in ["win32", "darwin"]:
|
||||
path = user_data_dir(appname, appauthor, None, roaming)
|
||||
else:
|
||||
path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state"))
|
||||
if appname:
|
||||
path = os.path.join(path, appname)
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
|
||||
r"""Return full path to the user-specific log dir for this application.
|
||||
|
||||
"appname" is the name of application.
|
||||
If None, just the system directory is returned.
|
||||
"appauthor" (only used on Windows) is the name of the
|
||||
appauthor or distributing body for this application. Typically
|
||||
it is the owning company name. This falls back to appname. You may
|
||||
pass False to disable it.
|
||||
"version" is an optional version path element to append to the
|
||||
path. You might want to use this if you want multiple versions
|
||||
of your app to be able to run independently. If used, this
|
||||
would typically be "<major>.<minor>".
|
||||
Only applied when appname is present.
|
||||
"opinion" (boolean) can be False to disable the appending of
|
||||
"Logs" to the base app data dir for Windows, and "log" to the
|
||||
base cache dir for Unix. See discussion below.
|
||||
|
||||
Typical user log directories are:
|
||||
Mac OS X: ~/Library/Logs/<AppName>
|
||||
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
|
||||
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
|
||||
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
|
||||
|
||||
On Windows the only suggestion in the MSDN docs is that local settings
|
||||
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
|
||||
examples of what some windows apps use for a logs dir.)
|
||||
|
||||
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
|
||||
value for Windows and appends "log" to the user cache dir for Unix.
|
||||
This can be disabled with the `opinion=False` option.
|
||||
"""
|
||||
if system == "darwin":
|
||||
path = os.path.join(
|
||||
os.path.expanduser('~/Library/Logs'),
|
||||
appname)
|
||||
elif system == "win32":
|
||||
path = user_data_dir(appname, appauthor, version)
|
||||
version = False
|
||||
if opinion:
|
||||
path = os.path.join(path, "Logs")
|
||||
else:
|
||||
path = user_cache_dir(appname, appauthor, version)
|
||||
version = False
|
||||
if opinion:
|
||||
path = os.path.join(path, "log")
|
||||
if appname and version:
|
||||
path = os.path.join(path, version)
|
||||
return path
|
||||
|
||||
|
||||
class AppDirs(object):
|
||||
"""Convenience wrapper for getting application dirs."""
|
||||
def __init__(self, appname=None, appauthor=None, version=None,
|
||||
roaming=False, multipath=False):
|
||||
self.appname = appname
|
||||
self.appauthor = appauthor
|
||||
self.version = version
|
||||
self.roaming = roaming
|
||||
self.multipath = multipath
|
||||
|
||||
@property
|
||||
def user_data_dir(self):
|
||||
return user_data_dir(self.appname, self.appauthor,
|
||||
version=self.version, roaming=self.roaming)
|
||||
|
||||
@property
|
||||
def site_data_dir(self):
|
||||
return site_data_dir(self.appname, self.appauthor,
|
||||
version=self.version, multipath=self.multipath)
|
||||
|
||||
@property
|
||||
def user_config_dir(self):
|
||||
return user_config_dir(self.appname, self.appauthor,
|
||||
version=self.version, roaming=self.roaming)
|
||||
|
||||
@property
|
||||
def site_config_dir(self):
|
||||
return site_config_dir(self.appname, self.appauthor,
|
||||
version=self.version, multipath=self.multipath)
|
||||
|
||||
@property
|
||||
def user_cache_dir(self):
|
||||
return user_cache_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
@property
|
||||
def user_state_dir(self):
|
||||
return user_state_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
@property
|
||||
def user_log_dir(self):
|
||||
return user_log_dir(self.appname, self.appauthor,
|
||||
version=self.version)
|
||||
|
||||
|
||||
#---- internal support stuff
|
||||
|
||||
def _get_win_folder_from_registry(csidl_name):
|
||||
"""This is a fallback technique at best. I'm not sure if using the
|
||||
registry for this guarantees us the correct answer for all CSIDL_*
|
||||
names.
|
||||
"""
|
||||
if PY3:
|
||||
import winreg as _winreg
|
||||
else:
|
||||
import _winreg
|
||||
|
||||
shell_folder_name = {
|
||||
"CSIDL_APPDATA": "AppData",
|
||||
"CSIDL_COMMON_APPDATA": "Common AppData",
|
||||
"CSIDL_LOCAL_APPDATA": "Local AppData",
|
||||
}[csidl_name]
|
||||
|
||||
key = _winreg.OpenKey(
|
||||
_winreg.HKEY_CURRENT_USER,
|
||||
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
|
||||
)
|
||||
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
|
||||
return dir
|
||||
|
||||
|
||||
def _get_win_folder_with_pywin32(csidl_name):
    """Resolve a CSIDL_* special-folder path via the pywin32 shell API."""
    from win32com.shell import shellcon, shell
    dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
    # Try to make this a unicode path because SHGetFolderPath does
    # not return unicode strings when there is unicode data in the
    # path.
    # NOTE(review): `unicode` exists only on Python 2; on Python 3 this
    # raises NameError, which this except clause does not catch — confirm
    # this helper is only reachable on interpreters where that is safe.
    try:
        dir = unicode(dir)

        # Downgrade to short path name if have highbit chars. See
        # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
        has_high_char = False
        for c in dir:
            if ord(c) > 255:
                has_high_char = True
                break
        if has_high_char:
            try:
                import win32api
                dir = win32api.GetShortPathName(dir)
            except ImportError:
                # win32api missing: keep the long (possibly high-bit) path.
                pass
    except UnicodeError:
        pass
    return dir
|
||||
|
||||
|
||||
def _get_win_folder_with_ctypes(csidl_name):
    """Resolve a CSIDL_* special-folder path via ctypes' SHGetFolderPathW."""
    import ctypes

    # Numeric CSIDL constants for the supported folder names; unsupported
    # names raise KeyError.
    csidl_const = {
        "CSIDL_APPDATA": 26,
        "CSIDL_COMMON_APPDATA": 35,
        "CSIDL_LOCAL_APPDATA": 28,
    }[csidl_name]

    buf = ctypes.create_unicode_buffer(1024)
    ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    if any(ord(ch) > 255 for ch in buf):
        short_buf = ctypes.create_unicode_buffer(1024)
        if ctypes.windll.kernel32.GetShortPathNameW(buf.value, short_buf, 1024):
            buf = short_buf

    return buf.value
|
||||
|
||||
def _get_win_folder_with_jna(csidl_name):
    """Resolve a CSIDL_* special-folder path via JNA (Jython on Windows)."""
    import array
    from com.sun import jna
    from com.sun.jna.platform import win32

    buf_size = win32.WinDef.MAX_PATH * 2
    buf = array.zeros('c', buf_size)
    shell = win32.Shell32.INSTANCE
    shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
    # SHGetFolderPath fills a NUL-padded char buffer; strip the padding.
    dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    # Downgrade to short path name if have highbit chars. See
    # <http://bugs.activestate.com/show_bug.cgi?id=85099>.
    has_high_char = False
    for c in dir:
        if ord(c) > 255:
            has_high_char = True
            break
    if has_high_char:
        buf = array.zeros('c', buf_size)
        kernel = win32.Kernel32.INSTANCE
        if kernel.GetShortPathName(dir, buf, buf_size):
            dir = jna.Native.toString(buf.tostring()).rstrip("\0")

    return dir
|
||||
|
||||
if system == "win32":
    # Pick the best available Windows folder resolver at import time,
    # preferring richer backends: pywin32 -> ctypes -> JNA (Jython) ->
    # raw registry fallback.
    try:
        import win32com.shell
        _get_win_folder = _get_win_folder_with_pywin32
    except ImportError:
        try:
            from ctypes import windll
            _get_win_folder = _get_win_folder_with_ctypes
        except ImportError:
            try:
                import com.sun.jna
                _get_win_folder = _get_win_folder_with_jna
            except ImportError:
                _get_win_folder = _get_win_folder_from_registry
|
||||
|
||||
|
||||
#---- self test code
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: print every directory property for a sample app
    # under several AppDirs configurations.
    appname = "MyApp"
    appauthor = "MyCompany"

    props = ("user_data_dir",
             "user_config_dir",
             "user_cache_dir",
             "user_state_dir",
             "user_log_dir",
             "site_data_dir",
             "site_config_dir")

    print("-- app dirs %s --" % __version__)

    print("-- app dirs (with optional 'version')")
    dirs = AppDirs(appname, appauthor, version="1.0")
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (without optional 'version')")
    dirs = AppDirs(appname, appauthor)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (without optional 'appauthor')")
    dirs = AppDirs(appname)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))

    print("\n-- app dirs (with disabled 'appauthor')")
    dirs = AppDirs(appname, appauthor=False)
    for prop in props:
        print("%s: %s" % (prop, getattr(dirs, prop)))
|
||||
@@ -0,0 +1 @@
|
||||
importlib_resources
|
||||
36
lib/pkg_resources/_vendor/importlib_resources/__init__.py
Normal file
36
lib/pkg_resources/_vendor/importlib_resources/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Read resources contained within a package."""
|
||||
|
||||
from ._common import (
|
||||
as_file,
|
||||
files,
|
||||
Package,
|
||||
)
|
||||
|
||||
from ._legacy import (
|
||||
contents,
|
||||
open_binary,
|
||||
read_binary,
|
||||
open_text,
|
||||
read_text,
|
||||
is_resource,
|
||||
path,
|
||||
Resource,
|
||||
)
|
||||
|
||||
from .abc import ResourceReader
|
||||
|
||||
|
||||
__all__ = [
|
||||
'Package',
|
||||
'Resource',
|
||||
'ResourceReader',
|
||||
'as_file',
|
||||
'contents',
|
||||
'files',
|
||||
'is_resource',
|
||||
'open_binary',
|
||||
'open_text',
|
||||
'path',
|
||||
'read_binary',
|
||||
'read_text',
|
||||
]
|
||||
170
lib/pkg_resources/_vendor/importlib_resources/_adapters.py
Normal file
170
lib/pkg_resources/_vendor/importlib_resources/_adapters.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from contextlib import suppress
|
||||
from io import TextIOWrapper
|
||||
|
||||
from . import abc
|
||||
|
||||
|
||||
class SpecLoaderAdapter:
    """
    Wrap a module spec, deriving ``loader`` through ``adapter`` while
    delegating every other attribute lookup to the wrapped spec.
    """

    def __init__(self, spec, adapter=lambda spec: spec.loader):
        self.spec = spec
        # By default the adapter simply exposes the spec's own loader.
        self.loader = adapter(spec)

    def __getattr__(self, name):
        # Anything not set on the adapter itself falls through to the spec.
        return getattr(self.spec, name)
|
||||
|
||||
|
||||
class TraversableResourcesLoader:
    """
    Adapt a loader to provide TraversableResources.
    """

    def __init__(self, spec):
        self.spec = spec

    def get_resource_reader(self, name):
        # Prefer the spec's native reader when it supports files();
        # otherwise CompatibilityFiles itself acts as the reader.
        return CompatibilityFiles(self.spec)._native()
|
||||
|
||||
|
||||
def _io_wrapper(file, mode='r', *args, **kwargs):
|
||||
if mode == 'r':
|
||||
return TextIOWrapper(file, *args, **kwargs)
|
||||
elif mode == 'rb':
|
||||
return file
|
||||
raise ValueError(
|
||||
"Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode)
|
||||
)
|
||||
|
||||
|
||||
class CompatibilityFiles:
    """
    Adapter for an existing or non-existent resource reader
    to provide a compatibility .files().
    """

    class SpecPath(abc.Traversable):
        """
        Path tied to a module spec.
        Can be read and exposes the resource reader children.
        """

        def __init__(self, spec, reader):
            self._spec = spec
            self._reader = reader

        def iterdir(self):
            # Without a reader there is nothing to enumerate.
            if not self._reader:
                return iter(())
            return iter(
                CompatibilityFiles.ChildPath(self._reader, path)
                for path in self._reader.contents()
            )

        def is_file(self):
            return False

        # A spec path is never a file and never a directory either.
        is_dir = is_file

        def joinpath(self, other):
            if not self._reader:
                return CompatibilityFiles.OrphanPath(other)
            return CompatibilityFiles.ChildPath(self._reader, other)

        @property
        def name(self):
            return self._spec.name

        def open(self, mode='r', *args, **kwargs):
            return _io_wrapper(self._reader.open_resource(None), mode, *args, **kwargs)

    class ChildPath(abc.Traversable):
        """
        Path tied to a resource reader child.
        Can be read but doesn't expose any meaningful children.
        """

        def __init__(self, reader, name):
            self._reader = reader
            self._name = name

        def iterdir(self):
            return iter(())

        def is_file(self):
            return self._reader.is_resource(self.name)

        def is_dir(self):
            return not self.is_file()

        def joinpath(self, other):
            # Children of a child cannot be resolved through the reader.
            return CompatibilityFiles.OrphanPath(self.name, other)

        @property
        def name(self):
            return self._name

        def open(self, mode='r', *args, **kwargs):
            return _io_wrapper(
                self._reader.open_resource(self.name), mode, *args, **kwargs
            )

    class OrphanPath(abc.Traversable):
        """
        Orphan path, not tied to a module spec or resource reader.
        Can't be read and doesn't expose any meaningful children.
        """

        def __init__(self, *path_parts):
            if len(path_parts) < 1:
                raise ValueError('Need at least one path part to construct a path')
            self._path = path_parts

        def iterdir(self):
            return iter(())

        def is_file(self):
            return False

        is_dir = is_file

        def joinpath(self, other):
            return CompatibilityFiles.OrphanPath(*self._path, other)

        @property
        def name(self):
            return self._path[-1]

        def open(self, mode='r', *args, **kwargs):
            raise FileNotFoundError("Can't open orphan path")

    def __init__(self, spec):
        self.spec = spec

    @property
    def _reader(self):
        # Returns None (via suppressed AttributeError) when the loader has
        # no get_resource_reader.
        with suppress(AttributeError):
            return self.spec.loader.get_resource_reader(self.spec.name)

    def _native(self):
        """
        Return the native reader if it supports files().
        """
        reader = self._reader
        return reader if hasattr(reader, 'files') else self

    def __getattr__(self, attr):
        # Delegate reader protocol methods to the underlying reader.
        return getattr(self._reader, attr)

    def files(self):
        return CompatibilityFiles.SpecPath(self.spec, self._reader)
|
||||
|
||||
|
||||
def wrap_spec(package):
    """
    Construct a package spec with traversable compatibility
    on the spec/loader/reader.
    """
    # The adapted spec exposes a loader whose get_resource_reader always
    # yields something supporting files().
    return SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)
|
||||
207
lib/pkg_resources/_vendor/importlib_resources/_common.py
Normal file
207
lib/pkg_resources/_vendor/importlib_resources/_common.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import os
|
||||
import pathlib
|
||||
import tempfile
|
||||
import functools
|
||||
import contextlib
|
||||
import types
|
||||
import importlib
|
||||
import inspect
|
||||
import warnings
|
||||
import itertools
|
||||
|
||||
from typing import Union, Optional, cast
|
||||
from .abc import ResourceReader, Traversable
|
||||
|
||||
from ._compat import wrap_spec
|
||||
|
||||
Package = Union[types.ModuleType, str]
|
||||
Anchor = Package
|
||||
|
||||
|
||||
def package_to_anchor(func):
    """
    Replace 'package' parameter as 'anchor' and warn about the change.

    Other errors should fall through.

    >>> files('a', 'b')
    Traceback (most recent call last):
    TypeError: files() takes from 0 to 1 positional arguments but 2 were given
    """
    _sentinel = object()

    @functools.wraps(func)
    def wrapper(anchor=_sentinel, package=_sentinel):
        if package is _sentinel:
            # No legacy keyword: plain zero- or one-argument call.
            return func() if anchor is _sentinel else func(anchor)
        if anchor is not _sentinel:
            # Both supplied: let func raise its own TypeError.
            return func(anchor, package)
        warnings.warn(
            "First parameter to files is renamed to 'anchor'",
            DeprecationWarning,
            stacklevel=2,
        )
        return func(package)

    return wrapper
|
||||
|
||||
|
||||
@package_to_anchor
def files(anchor: Optional[Anchor] = None) -> Traversable:
    """
    Get a Traversable resource for an anchor.

    With no anchor, the caller's own module (inferred from the stack)
    is used.
    """
    return from_package(resolve(anchor))
|
||||
|
||||
|
||||
def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
    """
    Return the package's loader if it's a ResourceReader.

    Returns None when the loader exposes no get_resource_reader hook.
    """
    # We can't use
    # a issubclass() check here because apparently abc.'s __subclasscheck__()
    # hook wants to create a weak reference to the object, but
    # zipimport.zipimporter does not support weak references, resulting in a
    # TypeError. That seems terrible.
    spec = package.__spec__
    reader = getattr(spec.loader, 'get_resource_reader', None)  # type: ignore
    if reader is None:
        return None
    return reader(spec.name)  # type: ignore
|
||||
|
||||
|
||||
@functools.singledispatch
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
    # Default: assume an already-imported module object was passed.
    return cast(types.ModuleType, cand)


@resolve.register
def _(cand: str) -> types.ModuleType:
    # A dotted module name: import (or fetch the cached) module.
    return importlib.import_module(cand)


@resolve.register
def _(cand: None) -> types.ModuleType:
    # No anchor given: resolve relative to the calling module.
    return resolve(_infer_caller().f_globals['__name__'])
|
||||
|
||||
|
||||
def _infer_caller():
    """
    Walk the stack and find the frame of the first caller not in this module.
    """

    def is_this_file(frame_info):
        return frame_info.filename == __file__

    def is_wrapper(frame_info):
        return frame_info.function == 'wrapper'

    not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
    # also exclude 'wrapper' due to singledispatch in the call stack
    callers = itertools.filterfalse(is_wrapper, not_this_file)
    # First surviving entry is the external caller's frame.
    return next(callers).frame
|
||||
|
||||
|
||||
def from_package(package: types.ModuleType):
    """
    Return a Traversable object for the given package.

    The spec is wrapped first so the reader is guaranteed to support
    files().
    """
    spec = wrap_spec(package)
    reader = spec.loader.get_resource_reader(spec.name)
    return reader.files()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _tempfile(
|
||||
reader,
|
||||
suffix='',
|
||||
# gh-93353: Keep a reference to call os.remove() in late Python
|
||||
# finalization.
|
||||
*,
|
||||
_os_remove=os.remove,
|
||||
):
|
||||
# Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
|
||||
# blocks due to the need to close the temporary file to work on Windows
|
||||
# properly.
|
||||
fd, raw_path = tempfile.mkstemp(suffix=suffix)
|
||||
try:
|
||||
try:
|
||||
os.write(fd, reader())
|
||||
finally:
|
||||
os.close(fd)
|
||||
del reader
|
||||
yield pathlib.Path(raw_path)
|
||||
finally:
|
||||
try:
|
||||
_os_remove(raw_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
def _temp_file(path):
    # Materialize a single Traversable file as a real temp file named
    # after the original.
    return _tempfile(path.read_bytes, suffix=path.name)
|
||||
|
||||
|
||||
def _is_present_dir(path: Traversable) -> bool:
    """
    Some Traversables implement ``is_dir()`` to raise an
    exception (i.e. ``FileNotFoundError``) when the
    directory doesn't exist. This function wraps that call
    to always return a boolean and only return True
    if there's a dir and it exists.
    """
    with contextlib.suppress(FileNotFoundError):
        return path.is_dir()
    # is_dir() raised: the path does not exist.
    return False
|
||||
|
||||
|
||||
@functools.singledispatch
def as_file(path):
    """
    Given a Traversable object, return that object as a
    path on the local file system in a context manager.

    Directories are replicated recursively; files are copied once.
    """
    return _temp_dir(path) if _is_present_dir(path) else _temp_file(path)
|
||||
|
||||
|
||||
@as_file.register(pathlib.Path)
@contextlib.contextmanager
def _(path):
    """
    Degenerate behavior for pathlib.Path objects.

    Already on the file system, so no copy is made.
    """
    yield path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _temp_path(dir: tempfile.TemporaryDirectory):
|
||||
"""
|
||||
Wrap tempfile.TemporyDirectory to return a pathlib object.
|
||||
"""
|
||||
with dir as result:
|
||||
yield pathlib.Path(result)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _temp_dir(path):
    """
    Given a traversable dir, recursively replicate the whole tree
    to the file system in a context manager.
    """
    assert path.is_dir()
    with _temp_path(tempfile.TemporaryDirectory()) as temp_dir:
        # Yields the replicated child directory inside the temp root.
        yield _write_contents(temp_dir, path)
|
||||
|
||||
|
||||
def _write_contents(target, source):
|
||||
child = target.joinpath(source.name)
|
||||
if source.is_dir():
|
||||
child.mkdir()
|
||||
for item in source.iterdir():
|
||||
_write_contents(child, item)
|
||||
else:
|
||||
child.write_bytes(source.read_bytes())
|
||||
return child
|
||||
108
lib/pkg_resources/_vendor/importlib_resources/_compat.py
Normal file
108
lib/pkg_resources/_vendor/importlib_resources/_compat.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# flake8: noqa
|
||||
|
||||
import abc
|
||||
import os
|
||||
import sys
|
||||
import pathlib
|
||||
from contextlib import suppress
|
||||
from typing import Union
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from zipfile import Path as ZipPath # type: ignore
|
||||
else:
|
||||
from ..zipp import Path as ZipPath # type: ignore
|
||||
|
||||
|
||||
try:
|
||||
from typing import runtime_checkable # type: ignore
|
||||
except ImportError:
|
||||
|
||||
def runtime_checkable(cls): # type: ignore
|
||||
return cls
|
||||
|
||||
|
||||
try:
|
||||
from typing import Protocol # type: ignore
|
||||
except ImportError:
|
||||
Protocol = abc.ABC # type: ignore
|
||||
|
||||
|
||||
class TraversableResourcesLoader:
|
||||
"""
|
||||
Adapt loaders to provide TraversableResources and other
|
||||
compatibility.
|
||||
|
||||
Used primarily for Python 3.9 and earlier where the native
|
||||
loaders do not yet implement TraversableResources.
|
||||
"""
|
||||
|
||||
def __init__(self, spec):
|
||||
self.spec = spec
|
||||
|
||||
@property
|
||||
def path(self):
|
||||
return self.spec.origin
|
||||
|
||||
def get_resource_reader(self, name):
|
||||
from . import readers, _adapters
|
||||
|
||||
def _zip_reader(spec):
|
||||
with suppress(AttributeError):
|
||||
return readers.ZipReader(spec.loader, spec.name)
|
||||
|
||||
def _namespace_reader(spec):
|
||||
with suppress(AttributeError, ValueError):
|
||||
return readers.NamespaceReader(spec.submodule_search_locations)
|
||||
|
||||
def _available_reader(spec):
|
||||
with suppress(AttributeError):
|
||||
return spec.loader.get_resource_reader(spec.name)
|
||||
|
||||
def _native_reader(spec):
|
||||
reader = _available_reader(spec)
|
||||
return reader if hasattr(reader, 'files') else None
|
||||
|
||||
def _file_reader(spec):
|
||||
try:
|
||||
path = pathlib.Path(self.path)
|
||||
except TypeError:
|
||||
return None
|
||||
if path.exists():
|
||||
return readers.FileReader(self)
|
||||
|
||||
return (
|
||||
# native reader if it supplies 'files'
|
||||
_native_reader(self.spec)
|
||||
or
|
||||
# local ZipReader if a zip module
|
||||
_zip_reader(self.spec)
|
||||
or
|
||||
# local NamespaceReader if a namespace module
|
||||
_namespace_reader(self.spec)
|
||||
or
|
||||
# local FileReader
|
||||
_file_reader(self.spec)
|
||||
# fallback - adapt the spec ResourceReader to TraversableReader
|
||||
or _adapters.CompatibilityFiles(self.spec)
|
||||
)
|
||||
|
||||
|
||||
def wrap_spec(package):
|
||||
"""
|
||||
Construct a package spec with traversable compatibility
|
||||
on the spec/loader/reader.
|
||||
|
||||
Supersedes _adapters.wrap_spec to use TraversableResourcesLoader
|
||||
from above for older Python compatibility (<3.10).
|
||||
"""
|
||||
from . import _adapters
|
||||
|
||||
return _adapters.SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
StrPath = Union[str, os.PathLike[str]]
|
||||
else:
|
||||
# PathLike is only subscriptable at runtime in 3.9+
|
||||
StrPath = Union[str, "os.PathLike[str]"]
|
||||
35
lib/pkg_resources/_vendor/importlib_resources/_itertools.py
Normal file
35
lib/pkg_resources/_vendor/importlib_resources/_itertools.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from itertools import filterfalse
|
||||
|
||||
from typing import (
|
||||
Callable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Optional,
|
||||
Set,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
# Type and type variable definitions
|
||||
_T = TypeVar('_T')
|
||||
_U = TypeVar('_U')
|
||||
|
||||
|
||||
def unique_everseen(
    iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None
) -> Iterator[_T]:
    "List unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') --> A B C D
    # unique_everseen('ABBCcAD', str.lower) --> A B C D
    seen: Set[Union[_T, _U]] = set()
    # Bind the method once; this loop may be hot.
    seen_add = seen.add
    if key is None:
        # Fast path: dedupe on the elements themselves.
        for element in filterfalse(seen.__contains__, iterable):
            seen_add(element)
            yield element
    else:
        # Dedupe on key(element), yielding the first element per key.
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen_add(k)
                yield element
|
||||
120
lib/pkg_resources/_vendor/importlib_resources/_legacy.py
Normal file
120
lib/pkg_resources/_vendor/importlib_resources/_legacy.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import functools
|
||||
import os
|
||||
import pathlib
|
||||
import types
|
||||
import warnings
|
||||
|
||||
from typing import Union, Iterable, ContextManager, BinaryIO, TextIO, Any
|
||||
|
||||
from . import _common
|
||||
|
||||
Package = Union[types.ModuleType, str]
|
||||
Resource = str
|
||||
|
||||
|
||||
def deprecated(func):
    """Decorate *func* to emit a DeprecationWarning steering users to files()."""
    _message = (
        "{name} is deprecated. Use files() instead. "
        "Refer to https://importlib-resources.readthedocs.io"
        "/en/latest/using.html#migrating-from-legacy for migration advice."
    )

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            _message.format(name=func.__name__),
            DeprecationWarning,
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def normalize_path(path: Any) -> str:
    """Normalize a path by ensuring it is a string.

    If the resulting string contains path separators, an exception is raised.
    """
    text = str(path)
    parent, file_name = os.path.split(text)
    if parent:
        # A resource must be a bare file name, never a nested path.
        raise ValueError(f'{path!r} must be only a file name')
    return file_name
|
||||
|
||||
|
||||
@deprecated
|
||||
def open_binary(package: Package, resource: Resource) -> BinaryIO:
|
||||
"""Return a file-like object opened for binary reading of the resource."""
|
||||
return (_common.files(package) / normalize_path(resource)).open('rb')
|
||||
|
||||
|
||||
@deprecated
|
||||
def read_binary(package: Package, resource: Resource) -> bytes:
|
||||
"""Return the binary contents of the resource."""
|
||||
return (_common.files(package) / normalize_path(resource)).read_bytes()
|
||||
|
||||
|
||||
@deprecated
|
||||
def open_text(
|
||||
package: Package,
|
||||
resource: Resource,
|
||||
encoding: str = 'utf-8',
|
||||
errors: str = 'strict',
|
||||
) -> TextIO:
|
||||
"""Return a file-like object opened for text reading of the resource."""
|
||||
return (_common.files(package) / normalize_path(resource)).open(
|
||||
'r', encoding=encoding, errors=errors
|
||||
)
|
||||
|
||||
|
||||
@deprecated
|
||||
def read_text(
|
||||
package: Package,
|
||||
resource: Resource,
|
||||
encoding: str = 'utf-8',
|
||||
errors: str = 'strict',
|
||||
) -> str:
|
||||
"""Return the decoded string of the resource.
|
||||
|
||||
The decoding-related arguments have the same semantics as those of
|
||||
bytes.decode().
|
||||
"""
|
||||
with open_text(package, resource, encoding, errors) as fp:
|
||||
return fp.read()
|
||||
|
||||
|
||||
@deprecated
|
||||
def contents(package: Package) -> Iterable[str]:
|
||||
"""Return an iterable of entries in `package`.
|
||||
|
||||
Note that not all entries are resources. Specifically, directories are
|
||||
not considered resources. Use `is_resource()` on each entry returned here
|
||||
to check if it is a resource or not.
|
||||
"""
|
||||
return [path.name for path in _common.files(package).iterdir()]
|
||||
|
||||
|
||||
@deprecated
|
||||
def is_resource(package: Package, name: str) -> bool:
|
||||
"""True if `name` is a resource inside `package`.
|
||||
|
||||
Directories are *not* resources.
|
||||
"""
|
||||
resource = normalize_path(name)
|
||||
return any(
|
||||
traversable.name == resource and traversable.is_file()
|
||||
for traversable in _common.files(package).iterdir()
|
||||
)
|
||||
|
||||
|
||||
@deprecated
|
||||
def path(
|
||||
package: Package,
|
||||
resource: Resource,
|
||||
) -> ContextManager[pathlib.Path]:
|
||||
"""A context manager providing a file path object to the resource.
|
||||
|
||||
If the resource does not already exist on its own on the file system,
|
||||
a temporary file will be created. If the file was created, the file
|
||||
will be deleted upon exiting the context manager (no exception is
|
||||
raised if the file was deleted prior to the context manager
|
||||
exiting).
|
||||
"""
|
||||
return _common.as_file(_common.files(package) / normalize_path(resource))
|
||||
170
lib/pkg_resources/_vendor/importlib_resources/abc.py
Normal file
170
lib/pkg_resources/_vendor/importlib_resources/abc.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import abc
|
||||
import io
|
||||
import itertools
|
||||
import pathlib
|
||||
from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional
|
||||
|
||||
from ._compat import runtime_checkable, Protocol, StrPath
|
||||
|
||||
|
||||
__all__ = ["ResourceReader", "Traversable", "TraversableResources"]
|
||||
|
||||
|
||||
class ResourceReader(metaclass=abc.ABCMeta):
|
||||
"""Abstract base class for loaders to provide resource reading support."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def open_resource(self, resource: Text) -> BinaryIO:
|
||||
"""Return an opened, file-like object for binary reading.
|
||||
|
||||
The 'resource' argument is expected to represent only a file name.
|
||||
If the resource cannot be found, FileNotFoundError is raised.
|
||||
"""
|
||||
# This deliberately raises FileNotFoundError instead of
|
||||
# NotImplementedError so that if this method is accidentally called,
|
||||
# it'll still do the right thing.
|
||||
raise FileNotFoundError
|
||||
|
||||
@abc.abstractmethod
|
||||
def resource_path(self, resource: Text) -> Text:
|
||||
"""Return the file system path to the specified resource.
|
||||
|
||||
The 'resource' argument is expected to represent only a file name.
|
||||
If the resource does not exist on the file system, raise
|
||||
FileNotFoundError.
|
||||
"""
|
||||
# This deliberately raises FileNotFoundError instead of
|
||||
# NotImplementedError so that if this method is accidentally called,
|
||||
# it'll still do the right thing.
|
||||
raise FileNotFoundError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_resource(self, path: Text) -> bool:
|
||||
"""Return True if the named 'path' is a resource.
|
||||
|
||||
Files are resources, directories are not.
|
||||
"""
|
||||
raise FileNotFoundError
|
||||
|
||||
@abc.abstractmethod
|
||||
def contents(self) -> Iterable[str]:
|
||||
"""Return an iterable of entries in `package`."""
|
||||
raise FileNotFoundError
|
||||
|
||||
|
||||
class TraversalError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Traversable(Protocol):
|
||||
"""
|
||||
An object with a subset of pathlib.Path methods suitable for
|
||||
traversing directories and opening files.
|
||||
|
||||
Any exceptions that occur when accessing the backing resource
|
||||
may propagate unaltered.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def iterdir(self) -> Iterator["Traversable"]:
|
||||
"""
|
||||
Yield Traversable objects in self
|
||||
"""
|
||||
|
||||
def read_bytes(self) -> bytes:
|
||||
"""
|
||||
Read contents of self as bytes
|
||||
"""
|
||||
with self.open('rb') as strm:
|
||||
return strm.read()
|
||||
|
||||
def read_text(self, encoding: Optional[str] = None) -> str:
|
||||
"""
|
||||
Read contents of self as text
|
||||
"""
|
||||
with self.open(encoding=encoding) as strm:
|
||||
return strm.read()
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dir(self) -> bool:
|
||||
"""
|
||||
Return True if self is a directory
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_file(self) -> bool:
|
||||
"""
|
||||
Return True if self is a file
|
||||
"""
|
||||
|
||||
def joinpath(self, *descendants: StrPath) -> "Traversable":
|
||||
"""
|
||||
Return Traversable resolved with any descendants applied.
|
||||
|
||||
Each descendant should be a path segment relative to self
|
||||
and each may contain multiple levels separated by
|
||||
``posixpath.sep`` (``/``).
|
||||
"""
|
||||
if not descendants:
|
||||
return self
|
||||
names = itertools.chain.from_iterable(
|
||||
path.parts for path in map(pathlib.PurePosixPath, descendants)
|
||||
)
|
||||
target = next(names)
|
||||
matches = (
|
||||
traversable for traversable in self.iterdir() if traversable.name == target
|
||||
)
|
||||
try:
|
||||
match = next(matches)
|
||||
except StopIteration:
|
||||
raise TraversalError(
|
||||
"Target not found during traversal.", target, list(names)
|
||||
)
|
||||
return match.joinpath(*names)
|
||||
|
||||
def __truediv__(self, child: StrPath) -> "Traversable":
|
||||
"""
|
||||
Return Traversable child in self
|
||||
"""
|
||||
return self.joinpath(child)
|
||||
|
||||
@abc.abstractmethod
|
||||
def open(self, mode='r', *args, **kwargs):
|
||||
"""
|
||||
mode may be 'r' or 'rb' to open as text or binary. Return a handle
|
||||
suitable for reading (same as pathlib.Path.open).
|
||||
|
||||
When opening as text, accepts encoding parameters such as those
|
||||
accepted by io.TextIOWrapper.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
"""
|
||||
The base name of this object without any parent references.
|
||||
"""
|
||||
|
||||
|
||||
class TraversableResources(ResourceReader):
|
||||
"""
|
||||
The required interface for providing traversable
|
||||
resources.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def files(self) -> "Traversable":
|
||||
"""Return a Traversable object for the loaded package."""
|
||||
|
||||
def open_resource(self, resource: StrPath) -> io.BufferedReader:
|
||||
return self.files().joinpath(resource).open('rb')
|
||||
|
||||
def resource_path(self, resource: Any) -> NoReturn:
|
||||
raise FileNotFoundError(resource)
|
||||
|
||||
def is_resource(self, path: StrPath) -> bool:
|
||||
return self.files().joinpath(path).is_file()
|
||||
|
||||
def contents(self) -> Iterator[str]:
|
||||
return (item.name for item in self.files().iterdir())
|
||||
120
lib/pkg_resources/_vendor/importlib_resources/readers.py
Normal file
120
lib/pkg_resources/_vendor/importlib_resources/readers.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import collections
|
||||
import pathlib
|
||||
import operator
|
||||
|
||||
from . import abc
|
||||
|
||||
from ._itertools import unique_everseen
|
||||
from ._compat import ZipPath
|
||||
|
||||
|
||||
def remove_duplicates(items):
    """Yield *items* in order, keeping only the first occurrence of each."""
    # OrderedDict.fromkeys dedupes while preserving insertion order.
    return iter(collections.OrderedDict.fromkeys(items))
|
||||
|
||||
|
||||
class FileReader(abc.TraversableResources):
|
||||
def __init__(self, loader):
|
||||
self.path = pathlib.Path(loader.path).parent
|
||||
|
||||
def resource_path(self, resource):
|
||||
"""
|
||||
Return the file system path to prevent
|
||||
`resources.path()` from creating a temporary
|
||||
copy.
|
||||
"""
|
||||
return str(self.path.joinpath(resource))
|
||||
|
||||
def files(self):
|
||||
return self.path
|
||||
|
||||
|
||||
class ZipReader(abc.TraversableResources):
|
||||
def __init__(self, loader, module):
|
||||
_, _, name = module.rpartition('.')
|
||||
self.prefix = loader.prefix.replace('\\', '/') + name + '/'
|
||||
self.archive = loader.archive
|
||||
|
||||
def open_resource(self, resource):
|
||||
try:
|
||||
return super().open_resource(resource)
|
||||
except KeyError as exc:
|
||||
raise FileNotFoundError(exc.args[0])
|
||||
|
||||
def is_resource(self, path):
|
||||
# workaround for `zipfile.Path.is_file` returning true
|
||||
# for non-existent paths.
|
||||
target = self.files().joinpath(path)
|
||||
return target.is_file() and target.exists()
|
||||
|
||||
def files(self):
|
||||
return ZipPath(self.archive, self.prefix)
|
||||
|
||||
|
||||
class MultiplexedPath(abc.Traversable):
|
||||
"""
|
||||
Given a series of Traversable objects, implement a merged
|
||||
version of the interface across all objects. Useful for
|
||||
namespace packages which may be multihomed at a single
|
||||
name.
|
||||
"""
|
||||
|
||||
def __init__(self, *paths):
|
||||
self._paths = list(map(pathlib.Path, remove_duplicates(paths)))
|
||||
if not self._paths:
|
||||
message = 'MultiplexedPath must contain at least one path'
|
||||
raise FileNotFoundError(message)
|
||||
if not all(path.is_dir() for path in self._paths):
|
||||
raise NotADirectoryError('MultiplexedPath only supports directories')
|
||||
|
||||
def iterdir(self):
|
||||
files = (file for path in self._paths for file in path.iterdir())
|
||||
return unique_everseen(files, key=operator.attrgetter('name'))
|
||||
|
||||
def read_bytes(self):
|
||||
raise FileNotFoundError(f'{self} is not a file')
|
||||
|
||||
def read_text(self, *args, **kwargs):
|
||||
raise FileNotFoundError(f'{self} is not a file')
|
||||
|
||||
def is_dir(self):
|
||||
return True
|
||||
|
||||
def is_file(self):
|
||||
return False
|
||||
|
||||
def joinpath(self, *descendants):
|
||||
try:
|
||||
return super().joinpath(*descendants)
|
||||
except abc.TraversalError:
|
||||
# One of the paths did not resolve (a directory does not exist).
|
||||
# Just return something that will not exist.
|
||||
return self._paths[0].joinpath(*descendants)
|
||||
|
||||
def open(self, *args, **kwargs):
|
||||
raise FileNotFoundError(f'{self} is not a file')
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._paths[0].name
|
||||
|
||||
def __repr__(self):
|
||||
paths = ', '.join(f"'{path}'" for path in self._paths)
|
||||
return f'MultiplexedPath({paths})'
|
||||
|
||||
|
||||
class NamespaceReader(abc.TraversableResources):
|
||||
def __init__(self, namespace_path):
|
||||
if 'NamespacePath' not in str(namespace_path):
|
||||
raise ValueError('Invalid path')
|
||||
self.path = MultiplexedPath(*list(namespace_path))
|
||||
|
||||
def resource_path(self, resource):
|
||||
"""
|
||||
Return the file system path to prevent
|
||||
`resources.path()` from creating a temporary
|
||||
copy.
|
||||
"""
|
||||
return str(self.path.joinpath(resource))
|
||||
|
||||
def files(self):
|
||||
return self.path
|
||||
106
lib/pkg_resources/_vendor/importlib_resources/simple.py
Normal file
106
lib/pkg_resources/_vendor/importlib_resources/simple.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Interface adapters for low-level readers.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import io
|
||||
import itertools
|
||||
from typing import BinaryIO, List
|
||||
|
||||
from .abc import Traversable, TraversableResources
|
||||
|
||||
|
||||
class SimpleReader(abc.ABC):
|
||||
"""
|
||||
The minimum, low-level interface required from a resource
|
||||
provider.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def package(self) -> str:
|
||||
"""
|
||||
The name of the package for which this reader loads resources.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def children(self) -> List['SimpleReader']:
|
||||
"""
|
||||
Obtain an iterable of SimpleReader for available
|
||||
child containers (e.g. directories).
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def resources(self) -> List[str]:
|
||||
"""
|
||||
Obtain available named resources for this virtual package.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def open_binary(self, resource: str) -> BinaryIO:
|
||||
"""
|
||||
Obtain a File-like for a named resource.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self.package.split('.')[-1]
|
||||
|
||||
|
||||
class ResourceContainer(Traversable):
|
||||
"""
|
||||
Traversable container for a package's resources via its reader.
|
||||
"""
|
||||
|
||||
def __init__(self, reader: SimpleReader):
|
||||
self.reader = reader
|
||||
|
||||
def is_dir(self):
|
||||
return True
|
||||
|
||||
def is_file(self):
|
||||
return False
|
||||
|
||||
def iterdir(self):
|
||||
files = (ResourceHandle(self, name) for name in self.reader.resources)
|
||||
dirs = map(ResourceContainer, self.reader.children())
|
||||
return itertools.chain(files, dirs)
|
||||
|
||||
def open(self, *args, **kwargs):
|
||||
raise IsADirectoryError()
|
||||
|
||||
|
||||
class ResourceHandle(Traversable):
|
||||
"""
|
||||
Handle to a named resource in a ResourceReader.
|
||||
"""
|
||||
|
||||
def __init__(self, parent: ResourceContainer, name: str):
|
||||
self.parent = parent
|
||||
self.name = name # type: ignore
|
||||
|
||||
def is_file(self):
|
||||
return True
|
||||
|
||||
def is_dir(self):
|
||||
return False
|
||||
|
||||
def open(self, mode='r', *args, **kwargs):
|
||||
stream = self.parent.reader.open_binary(self.name)
|
||||
if 'b' not in mode:
|
||||
stream = io.TextIOWrapper(*args, **kwargs)
|
||||
return stream
|
||||
|
||||
def joinpath(self, name):
|
||||
raise RuntimeError("Cannot traverse into a resource")
|
||||
|
||||
|
||||
class TraversableReader(TraversableResources, SimpleReader):
|
||||
"""
|
||||
A TraversableResources based on SimpleReader. Resource providers
|
||||
may derive from this class to provide the TraversableResources
|
||||
interface by supplying the SimpleReader interface.
|
||||
"""
|
||||
|
||||
def files(self):
|
||||
return ResourceContainer(self)
|
||||
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
|
||||
|
||||
try:
|
||||
from test.support import import_helper # type: ignore
|
||||
except ImportError:
|
||||
# Python 3.9 and earlier
|
||||
class import_helper: # type: ignore
|
||||
from test.support import (
|
||||
modules_setup,
|
||||
modules_cleanup,
|
||||
DirsOnSysPath,
|
||||
CleanImport,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
from test.support import os_helper # type: ignore
|
||||
except ImportError:
|
||||
# Python 3.9 compat
|
||||
class os_helper: # type:ignore
|
||||
from test.support import temp_dir
|
||||
|
||||
|
||||
try:
|
||||
# Python 3.10
|
||||
from test.support.os_helper import unlink
|
||||
except ImportError:
|
||||
from test.support import unlink as _unlink
|
||||
|
||||
def unlink(target):
|
||||
return _unlink(os.fspath(target))
|
||||
50
lib/pkg_resources/_vendor/importlib_resources/tests/_path.py
Normal file
50
lib/pkg_resources/_vendor/importlib_resources/tests/_path.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pathlib
|
||||
import functools
|
||||
|
||||
|
||||
####
|
||||
# from jaraco.path 3.4
|
||||
|
||||
|
||||
def build(spec, prefix=pathlib.Path()):
|
||||
"""
|
||||
Build a set of files/directories, as described by the spec.
|
||||
|
||||
Each key represents a pathname, and the value represents
|
||||
the content. Content may be a nested directory.
|
||||
|
||||
>>> spec = {
|
||||
... 'README.txt': "A README file",
|
||||
... "foo": {
|
||||
... "__init__.py": "",
|
||||
... "bar": {
|
||||
... "__init__.py": "",
|
||||
... },
|
||||
... "baz.py": "# Some code",
|
||||
... }
|
||||
... }
|
||||
>>> tmpdir = getfixture('tmpdir')
|
||||
>>> build(spec, tmpdir)
|
||||
"""
|
||||
for name, contents in spec.items():
|
||||
create(contents, pathlib.Path(prefix) / name)
|
||||
|
||||
|
||||
@functools.singledispatch
|
||||
def create(content, path):
|
||||
path.mkdir(exist_ok=True)
|
||||
build(content, prefix=path) # type: ignore
|
||||
|
||||
|
||||
@create.register
|
||||
def _(content: bytes, path):
|
||||
path.write_bytes(content)
|
||||
|
||||
|
||||
@create.register
|
||||
def _(content: str, path):
|
||||
path.write_text(content)
|
||||
|
||||
|
||||
# end from jaraco.path
|
||||
####
|
||||
@@ -0,0 +1 @@
|
||||
one resource
|
||||
@@ -0,0 +1 @@
|
||||
two resource
|
||||
@@ -0,0 +1,102 @@
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import importlib_resources as resources
|
||||
|
||||
from importlib_resources._adapters import (
|
||||
CompatibilityFiles,
|
||||
wrap_spec,
|
||||
)
|
||||
|
||||
from . import util
|
||||
|
||||
|
||||
class CompatibilityFilesTests(unittest.TestCase):
|
||||
@property
|
||||
def package(self):
|
||||
bytes_data = io.BytesIO(b'Hello, world!')
|
||||
return util.create_package(
|
||||
file=bytes_data,
|
||||
path='some_path',
|
||||
contents=('a', 'b', 'c'),
|
||||
)
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
return resources.files(self.package)
|
||||
|
||||
def test_spec_path_iter(self):
|
||||
self.assertEqual(
|
||||
sorted(path.name for path in self.files.iterdir()),
|
||||
['a', 'b', 'c'],
|
||||
)
|
||||
|
||||
def test_child_path_iter(self):
|
||||
self.assertEqual(list((self.files / 'a').iterdir()), [])
|
||||
|
||||
def test_orphan_path_iter(self):
|
||||
self.assertEqual(list((self.files / 'a' / 'a').iterdir()), [])
|
||||
self.assertEqual(list((self.files / 'a' / 'a' / 'a').iterdir()), [])
|
||||
|
||||
def test_spec_path_is(self):
|
||||
self.assertFalse(self.files.is_file())
|
||||
self.assertFalse(self.files.is_dir())
|
||||
|
||||
def test_child_path_is(self):
|
||||
self.assertTrue((self.files / 'a').is_file())
|
||||
self.assertFalse((self.files / 'a').is_dir())
|
||||
|
||||
def test_orphan_path_is(self):
|
||||
self.assertFalse((self.files / 'a' / 'a').is_file())
|
||||
self.assertFalse((self.files / 'a' / 'a').is_dir())
|
||||
self.assertFalse((self.files / 'a' / 'a' / 'a').is_file())
|
||||
self.assertFalse((self.files / 'a' / 'a' / 'a').is_dir())
|
||||
|
||||
def test_spec_path_name(self):
|
||||
self.assertEqual(self.files.name, 'testingpackage')
|
||||
|
||||
def test_child_path_name(self):
|
||||
self.assertEqual((self.files / 'a').name, 'a')
|
||||
|
||||
def test_orphan_path_name(self):
|
||||
self.assertEqual((self.files / 'a' / 'b').name, 'b')
|
||||
self.assertEqual((self.files / 'a' / 'b' / 'c').name, 'c')
|
||||
|
||||
def test_spec_path_open(self):
|
||||
self.assertEqual(self.files.read_bytes(), b'Hello, world!')
|
||||
self.assertEqual(self.files.read_text(), 'Hello, world!')
|
||||
|
||||
def test_child_path_open(self):
|
||||
self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!')
|
||||
self.assertEqual((self.files / 'a').read_text(), 'Hello, world!')
|
||||
|
||||
def test_orphan_path_open(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
(self.files / 'a' / 'b').read_bytes()
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
(self.files / 'a' / 'b' / 'c').read_bytes()
|
||||
|
||||
def test_open_invalid_mode(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.files.open('0')
|
||||
|
||||
def test_orphan_path_invalid(self):
|
||||
with self.assertRaises(ValueError):
|
||||
CompatibilityFiles.OrphanPath()
|
||||
|
||||
def test_wrap_spec(self):
|
||||
spec = wrap_spec(self.package)
|
||||
self.assertIsInstance(spec.loader.get_resource_reader(None), CompatibilityFiles)
|
||||
|
||||
|
||||
class CompatibilityFilesNoReaderTests(unittest.TestCase):
|
||||
@property
|
||||
def package(self):
|
||||
return util.create_package_from_loader(None)
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
return resources.files(self.package)
|
||||
|
||||
def test_spec_path_joinpath(self):
|
||||
self.assertIsInstance(self.files / 'a', CompatibilityFiles.OrphanPath)
|
||||
@@ -0,0 +1,43 @@
|
||||
import unittest
|
||||
import importlib_resources as resources
|
||||
|
||||
from . import data01
|
||||
from . import util
|
||||
|
||||
|
||||
class ContentsTests:
|
||||
expected = {
|
||||
'__init__.py',
|
||||
'binary.file',
|
||||
'subdirectory',
|
||||
'utf-16.file',
|
||||
'utf-8.file',
|
||||
}
|
||||
|
||||
def test_contents(self):
|
||||
contents = {path.name for path in resources.files(self.data).iterdir()}
|
||||
assert self.expected <= contents
|
||||
|
||||
|
||||
class ContentsDiskTests(ContentsTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.data = data01
|
||||
|
||||
|
||||
class ContentsZipTests(ContentsTests, util.ZipSetup, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class ContentsNamespaceTests(ContentsTests, unittest.TestCase):
|
||||
expected = {
|
||||
# no __init__ because of namespace design
|
||||
# no subdirectory as incidental difference in fixture
|
||||
'binary.file',
|
||||
'utf-16.file',
|
||||
'utf-8.file',
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
from . import namespacedata01
|
||||
|
||||
self.data = namespacedata01
|
||||
@@ -0,0 +1,112 @@
|
||||
import typing
|
||||
import textwrap
|
||||
import unittest
|
||||
import warnings
|
||||
import importlib
|
||||
import contextlib
|
||||
|
||||
import importlib_resources as resources
|
||||
from ..abc import Traversable
|
||||
from . import data01
|
||||
from . import util
|
||||
from . import _path
|
||||
from ._compat import os_helper, import_helper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_known_deprecation():
|
||||
with warnings.catch_warnings(record=True) as ctx:
|
||||
warnings.simplefilter('default', category=DeprecationWarning)
|
||||
yield ctx
|
||||
|
||||
|
||||
class FilesTests:
|
||||
def test_read_bytes(self):
|
||||
files = resources.files(self.data)
|
||||
actual = files.joinpath('utf-8.file').read_bytes()
|
||||
assert actual == b'Hello, UTF-8 world!\n'
|
||||
|
||||
def test_read_text(self):
|
||||
files = resources.files(self.data)
|
||||
actual = files.joinpath('utf-8.file').read_text(encoding='utf-8')
|
||||
assert actual == 'Hello, UTF-8 world!\n'
|
||||
|
||||
@unittest.skipUnless(
|
||||
hasattr(typing, 'runtime_checkable'),
|
||||
"Only suitable when typing supports runtime_checkable",
|
||||
)
|
||||
def test_traversable(self):
|
||||
assert isinstance(resources.files(self.data), Traversable)
|
||||
|
||||
def test_old_parameter(self):
|
||||
"""
|
||||
Files used to take a 'package' parameter. Make sure anyone
|
||||
passing by name is still supported.
|
||||
"""
|
||||
with suppress_known_deprecation():
|
||||
resources.files(package=self.data)
|
||||
|
||||
|
||||
class OpenDiskTests(FilesTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.data = data01
|
||||
|
||||
|
||||
class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
class OpenNamespaceTests(FilesTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
from . import namespacedata01
|
||||
|
||||
self.data = namespacedata01
|
||||
|
||||
|
||||
class SiteDir:
|
||||
def setUp(self):
|
||||
self.fixtures = contextlib.ExitStack()
|
||||
self.addCleanup(self.fixtures.close)
|
||||
self.site_dir = self.fixtures.enter_context(os_helper.temp_dir())
|
||||
self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir))
|
||||
self.fixtures.enter_context(import_helper.CleanImport())
|
||||
|
||||
|
||||
class ModulesFilesTests(SiteDir, unittest.TestCase):
|
||||
def test_module_resources(self):
|
||||
"""
|
||||
A module can have resources found adjacent to the module.
|
||||
"""
|
||||
spec = {
|
||||
'mod.py': '',
|
||||
'res.txt': 'resources are the best',
|
||||
}
|
||||
_path.build(spec, self.site_dir)
|
||||
import mod
|
||||
|
||||
actual = resources.files(mod).joinpath('res.txt').read_text()
|
||||
assert actual == spec['res.txt']
|
||||
|
||||
|
||||
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
|
||||
def test_implicit_files(self):
|
||||
"""
|
||||
Without any parameter, files() will infer the location as the caller.
|
||||
"""
|
||||
spec = {
|
||||
'somepkg': {
|
||||
'__init__.py': textwrap.dedent(
|
||||
"""
|
||||
import importlib_resources as res
|
||||
val = res.files().joinpath('res.txt').read_text()
|
||||
"""
|
||||
),
|
||||
'res.txt': 'resources are the best',
|
||||
},
|
||||
}
|
||||
_path.build(spec, self.site_dir)
|
||||
assert importlib.import_module('somepkg').val == 'resources are the best'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,81 @@
|
||||
import unittest
|
||||
|
||||
import importlib_resources as resources
|
||||
from . import data01
|
||||
from . import util
|
||||
|
||||
|
||||
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
|
||||
def execute(self, package, path):
|
||||
target = resources.files(package).joinpath(path)
|
||||
with target.open('rb'):
|
||||
pass
|
||||
|
||||
|
||||
class CommonTextTests(util.CommonTests, unittest.TestCase):
|
||||
def execute(self, package, path):
|
||||
target = resources.files(package).joinpath(path)
|
||||
with target.open():
|
||||
pass
|
||||
|
||||
|
||||
class OpenTests:
|
||||
def test_open_binary(self):
|
||||
target = resources.files(self.data) / 'binary.file'
|
||||
with target.open('rb') as fp:
|
||||
result = fp.read()
|
||||
self.assertEqual(result, b'\x00\x01\x02\x03')
|
||||
|
||||
def test_open_text_default_encoding(self):
|
||||
target = resources.files(self.data) / 'utf-8.file'
|
||||
with target.open() as fp:
|
||||
result = fp.read()
|
||||
self.assertEqual(result, 'Hello, UTF-8 world!\n')
|
||||
|
||||
def test_open_text_given_encoding(self):
|
||||
target = resources.files(self.data) / 'utf-16.file'
|
||||
with target.open(encoding='utf-16', errors='strict') as fp:
|
||||
result = fp.read()
|
||||
self.assertEqual(result, 'Hello, UTF-16 world!\n')
|
||||
|
||||
def test_open_text_with_errors(self):
|
||||
# Raises UnicodeError without the 'errors' argument.
|
||||
target = resources.files(self.data) / 'utf-16.file'
|
||||
with target.open(encoding='utf-8', errors='strict') as fp:
|
||||
self.assertRaises(UnicodeError, fp.read)
|
||||
with target.open(encoding='utf-8', errors='ignore') as fp:
|
||||
result = fp.read()
|
||||
self.assertEqual(
|
||||
result,
|
||||
'H\x00e\x00l\x00l\x00o\x00,\x00 '
|
||||
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
|
||||
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
|
||||
)
|
||||
|
||||
def test_open_binary_FileNotFoundError(self):
|
||||
target = resources.files(self.data) / 'does-not-exist'
|
||||
self.assertRaises(FileNotFoundError, target.open, 'rb')
|
||||
|
||||
def test_open_text_FileNotFoundError(self):
|
||||
target = resources.files(self.data) / 'does-not-exist'
|
||||
self.assertRaises(FileNotFoundError, target.open)
|
||||
|
||||
|
||||
class OpenDiskTests(OpenTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.data = data01
|
||||
|
||||
|
||||
class OpenDiskNamespaceTests(OpenTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
from . import namespacedata01
|
||||
|
||||
self.data = namespacedata01
|
||||
|
||||
|
||||
class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,64 @@
|
||||
import io
|
||||
import unittest
|
||||
|
||||
import importlib_resources as resources
|
||||
from . import data01
|
||||
from . import util
|
||||
|
||||
|
||||
class CommonTests(util.CommonTests, unittest.TestCase):
|
||||
def execute(self, package, path):
|
||||
with resources.as_file(resources.files(package).joinpath(path)):
|
||||
pass
|
||||
|
||||
|
||||
class PathTests:
|
||||
def test_reading(self):
|
||||
# Path should be readable.
|
||||
# Test also implicitly verifies the returned object is a pathlib.Path
|
||||
# instance.
|
||||
target = resources.files(self.data) / 'utf-8.file'
|
||||
with resources.as_file(target) as path:
|
||||
self.assertTrue(path.name.endswith("utf-8.file"), repr(path))
|
||||
# pathlib.Path.read_text() was introduced in Python 3.5.
|
||||
with path.open('r', encoding='utf-8') as file:
|
||||
text = file.read()
|
||||
self.assertEqual('Hello, UTF-8 world!\n', text)
|
||||
|
||||
|
||||
class PathDiskTests(PathTests, unittest.TestCase):
|
||||
data = data01
|
||||
|
||||
def test_natural_path(self):
|
||||
"""
|
||||
Guarantee the internal implementation detail that
|
||||
file-system-backed resources do not get the tempdir
|
||||
treatment.
|
||||
"""
|
||||
target = resources.files(self.data) / 'utf-8.file'
|
||||
with resources.as_file(target) as path:
|
||||
assert 'data' in str(path)
|
||||
|
||||
|
||||
class PathMemoryTests(PathTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
file = io.BytesIO(b'Hello, UTF-8 world!\n')
|
||||
self.addCleanup(file.close)
|
||||
self.data = util.create_package(
|
||||
file=file, path=FileNotFoundError("package exists only in memory")
|
||||
)
|
||||
self.data.__spec__.origin = None
|
||||
self.data.__spec__.has_location = False
|
||||
|
||||
|
||||
class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase):
|
||||
def test_remove_in_context_manager(self):
|
||||
# It is not an error if the file that was temporarily stashed on the
|
||||
# file system is removed inside the `with` stanza.
|
||||
target = resources.files(self.data) / 'utf-8.file'
|
||||
with resources.as_file(target) as path:
|
||||
path.unlink()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,76 @@
|
||||
import unittest
|
||||
import importlib_resources as resources
|
||||
|
||||
from . import data01
|
||||
from . import util
|
||||
from importlib import import_module
|
||||
|
||||
|
||||
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
|
||||
def execute(self, package, path):
|
||||
resources.files(package).joinpath(path).read_bytes()
|
||||
|
||||
|
||||
class CommonTextTests(util.CommonTests, unittest.TestCase):
|
||||
def execute(self, package, path):
|
||||
resources.files(package).joinpath(path).read_text()
|
||||
|
||||
|
||||
class ReadTests:
|
||||
def test_read_bytes(self):
|
||||
result = resources.files(self.data).joinpath('binary.file').read_bytes()
|
||||
self.assertEqual(result, b'\0\1\2\3')
|
||||
|
||||
def test_read_text_default_encoding(self):
|
||||
result = resources.files(self.data).joinpath('utf-8.file').read_text()
|
||||
self.assertEqual(result, 'Hello, UTF-8 world!\n')
|
||||
|
||||
def test_read_text_given_encoding(self):
|
||||
result = (
|
||||
resources.files(self.data)
|
||||
.joinpath('utf-16.file')
|
||||
.read_text(encoding='utf-16')
|
||||
)
|
||||
self.assertEqual(result, 'Hello, UTF-16 world!\n')
|
||||
|
||||
def test_read_text_with_errors(self):
|
||||
# Raises UnicodeError without the 'errors' argument.
|
||||
target = resources.files(self.data) / 'utf-16.file'
|
||||
self.assertRaises(UnicodeError, target.read_text, encoding='utf-8')
|
||||
result = target.read_text(encoding='utf-8', errors='ignore')
|
||||
self.assertEqual(
|
||||
result,
|
||||
'H\x00e\x00l\x00l\x00o\x00,\x00 '
|
||||
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
|
||||
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
|
||||
)
|
||||
|
||||
|
||||
class ReadDiskTests(ReadTests, unittest.TestCase):
|
||||
data = data01
|
||||
|
||||
|
||||
class ReadZipTests(ReadTests, util.ZipSetup, unittest.TestCase):
|
||||
def test_read_submodule_resource(self):
|
||||
submodule = import_module('ziptestdata.subdirectory')
|
||||
result = resources.files(submodule).joinpath('binary.file').read_bytes()
|
||||
self.assertEqual(result, b'\0\1\2\3')
|
||||
|
||||
def test_read_submodule_resource_by_name(self):
|
||||
result = (
|
||||
resources.files('ziptestdata.subdirectory')
|
||||
.joinpath('binary.file')
|
||||
.read_bytes()
|
||||
)
|
||||
self.assertEqual(result, b'\0\1\2\3')
|
||||
|
||||
|
||||
class ReadNamespaceTests(ReadTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
from . import namespacedata01
|
||||
|
||||
self.data = namespacedata01
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,133 @@
|
||||
import os.path
|
||||
import sys
|
||||
import pathlib
|
||||
import unittest
|
||||
|
||||
from importlib import import_module
|
||||
from importlib_resources.readers import MultiplexedPath, NamespaceReader
|
||||
|
||||
|
||||
class MultiplexedPathTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
path = pathlib.Path(__file__).parent / 'namespacedata01'
|
||||
cls.folder = str(path)
|
||||
|
||||
def test_init_no_paths(self):
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
MultiplexedPath()
|
||||
|
||||
def test_init_file(self):
|
||||
with self.assertRaises(NotADirectoryError):
|
||||
MultiplexedPath(os.path.join(self.folder, 'binary.file'))
|
||||
|
||||
def test_iterdir(self):
|
||||
contents = {path.name for path in MultiplexedPath(self.folder).iterdir()}
|
||||
try:
|
||||
contents.remove('__pycache__')
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
self.assertEqual(contents, {'binary.file', 'utf-16.file', 'utf-8.file'})
|
||||
|
||||
def test_iterdir_duplicate(self):
|
||||
data01 = os.path.abspath(os.path.join(__file__, '..', 'data01'))
|
||||
contents = {
|
||||
path.name for path in MultiplexedPath(self.folder, data01).iterdir()
|
||||
}
|
||||
for remove in ('__pycache__', '__init__.pyc'):
|
||||
try:
|
||||
contents.remove(remove)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
self.assertEqual(
|
||||
contents,
|
||||
{'__init__.py', 'binary.file', 'subdirectory', 'utf-16.file', 'utf-8.file'},
|
||||
)
|
||||
|
||||
def test_is_dir(self):
|
||||
self.assertEqual(MultiplexedPath(self.folder).is_dir(), True)
|
||||
|
||||
def test_is_file(self):
|
||||
self.assertEqual(MultiplexedPath(self.folder).is_file(), False)
|
||||
|
||||
def test_open_file(self):
|
||||
path = MultiplexedPath(self.folder)
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
path.read_bytes()
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
path.read_text()
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
path.open()
|
||||
|
||||
def test_join_path(self):
|
||||
prefix = os.path.abspath(os.path.join(__file__, '..'))
|
||||
data01 = os.path.join(prefix, 'data01')
|
||||
path = MultiplexedPath(self.folder, data01)
|
||||
self.assertEqual(
|
||||
str(path.joinpath('binary.file'))[len(prefix) + 1 :],
|
||||
os.path.join('namespacedata01', 'binary.file'),
|
||||
)
|
||||
self.assertEqual(
|
||||
str(path.joinpath('subdirectory'))[len(prefix) + 1 :],
|
||||
os.path.join('data01', 'subdirectory'),
|
||||
)
|
||||
self.assertEqual(
|
||||
str(path.joinpath('imaginary'))[len(prefix) + 1 :],
|
||||
os.path.join('namespacedata01', 'imaginary'),
|
||||
)
|
||||
self.assertEqual(path.joinpath(), path)
|
||||
|
||||
def test_join_path_compound(self):
|
||||
path = MultiplexedPath(self.folder)
|
||||
assert not path.joinpath('imaginary/foo.py').exists()
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(
|
||||
repr(MultiplexedPath(self.folder)),
|
||||
f"MultiplexedPath('{self.folder}')",
|
||||
)
|
||||
|
||||
def test_name(self):
|
||||
self.assertEqual(
|
||||
MultiplexedPath(self.folder).name,
|
||||
os.path.basename(self.folder),
|
||||
)
|
||||
|
||||
|
||||
class NamespaceReaderTest(unittest.TestCase):
|
||||
site_dir = str(pathlib.Path(__file__).parent)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
sys.path.append(cls.site_dir)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
sys.path.remove(cls.site_dir)
|
||||
|
||||
def test_init_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
NamespaceReader(['path1', 'path2'])
|
||||
|
||||
def test_resource_path(self):
|
||||
namespacedata01 = import_module('namespacedata01')
|
||||
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
|
||||
|
||||
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
|
||||
self.assertEqual(
|
||||
reader.resource_path('binary.file'), os.path.join(root, 'binary.file')
|
||||
)
|
||||
self.assertEqual(
|
||||
reader.resource_path('imaginary'), os.path.join(root, 'imaginary')
|
||||
)
|
||||
|
||||
def test_files(self):
|
||||
namespacedata01 = import_module('namespacedata01')
|
||||
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
|
||||
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
|
||||
self.assertIsInstance(reader.files(), MultiplexedPath)
|
||||
self.assertEqual(repr(reader.files()), f"MultiplexedPath('{root}')")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,260 @@
|
||||
import sys
|
||||
import unittest
|
||||
import importlib_resources as resources
|
||||
import uuid
|
||||
import pathlib
|
||||
|
||||
from . import data01
|
||||
from . import zipdata01, zipdata02
|
||||
from . import util
|
||||
from importlib import import_module
|
||||
from ._compat import import_helper, unlink
|
||||
|
||||
|
||||
class ResourceTests:
|
||||
# Subclasses are expected to set the `data` attribute.
|
||||
|
||||
def test_is_file_exists(self):
|
||||
target = resources.files(self.data) / 'binary.file'
|
||||
self.assertTrue(target.is_file())
|
||||
|
||||
def test_is_file_missing(self):
|
||||
target = resources.files(self.data) / 'not-a-file'
|
||||
self.assertFalse(target.is_file())
|
||||
|
||||
def test_is_dir(self):
|
||||
target = resources.files(self.data) / 'subdirectory'
|
||||
self.assertFalse(target.is_file())
|
||||
self.assertTrue(target.is_dir())
|
||||
|
||||
|
||||
class ResourceDiskTests(ResourceTests, unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.data = data01
|
||||
|
||||
|
||||
class ResourceZipTests(ResourceTests, util.ZipSetup, unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
def names(traversable):
|
||||
return {item.name for item in traversable.iterdir()}
|
||||
|
||||
|
||||
class ResourceLoaderTests(unittest.TestCase):
|
||||
def test_resource_contents(self):
|
||||
package = util.create_package(
|
||||
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
|
||||
)
|
||||
self.assertEqual(names(resources.files(package)), {'A', 'B', 'C'})
|
||||
|
||||
def test_is_file(self):
|
||||
package = util.create_package(
|
||||
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
|
||||
)
|
||||
self.assertTrue(resources.files(package).joinpath('B').is_file())
|
||||
|
||||
def test_is_dir(self):
|
||||
package = util.create_package(
|
||||
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
|
||||
)
|
||||
self.assertTrue(resources.files(package).joinpath('D').is_dir())
|
||||
|
||||
def test_resource_missing(self):
|
||||
package = util.create_package(
|
||||
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
|
||||
)
|
||||
self.assertFalse(resources.files(package).joinpath('Z').is_file())
|
||||
|
||||
|
||||
class ResourceCornerCaseTests(unittest.TestCase):
|
||||
def test_package_has_no_reader_fallback(self):
|
||||
# Test odd ball packages which:
|
||||
# 1. Do not have a ResourceReader as a loader
|
||||
# 2. Are not on the file system
|
||||
# 3. Are not in a zip file
|
||||
module = util.create_package(
|
||||
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
|
||||
)
|
||||
# Give the module a dummy loader.
|
||||
module.__loader__ = object()
|
||||
# Give the module a dummy origin.
|
||||
module.__file__ = '/path/which/shall/not/be/named'
|
||||
module.__spec__.loader = module.__loader__
|
||||
module.__spec__.origin = module.__file__
|
||||
self.assertFalse(resources.files(module).joinpath('A').is_file())
|
||||
|
||||
|
||||
class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase):
    """Resource access against the single-package fixture zip (zipdata01)."""

    ZIP_MODULE = zipdata01  # type: ignore

    def test_is_submodule_resource(self):
        submodule = import_module('ziptestdata.subdirectory')
        self.assertTrue(resources.files(submodule).joinpath('binary.file').is_file())

    def test_read_submodule_resource_by_name(self):
        # Same lookup as above, but anchored by dotted name instead of module.
        self.assertTrue(
            resources.files('ziptestdata.subdirectory')
            .joinpath('binary.file')
            .is_file()
        )

    def test_submodule_contents(self):
        submodule = import_module('ziptestdata.subdirectory')
        self.assertEqual(
            names(resources.files(submodule)), {'__init__.py', 'binary.file'}
        )

    def test_submodule_contents_by_name(self):
        self.assertEqual(
            names(resources.files('ziptestdata.subdirectory')),
            {'__init__.py', 'binary.file'},
        )

    def test_as_file_directory(self):
        # as_file() on a zip-backed directory materializes the whole tree.
        with resources.as_file(resources.files('ziptestdata')) as data:
            assert data.name == 'ziptestdata'
            assert data.is_dir()
            assert data.joinpath('subdirectory').is_dir()
            assert len(list(data.iterdir()))
            # The extraction root is private: nothing above it is exposed.
            assert not data.parent.exists()
|
||||
|
||||
|
||||
class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
    """Resource access against the two-package fixture zip (zipdata02)."""

    ZIP_MODULE = zipdata02  # type: ignore

    def test_unrelated_contents(self):
        """
        Test that a zip with two unrelated subpackages returns
        distinct resources. Ref python/importlib_resources#44.
        """
        self.assertEqual(
            names(resources.files('ziptestdata.one')),
            {'__init__.py', 'resource1.txt'},
        )
        self.assertEqual(
            names(resources.files('ziptestdata.two')),
            {'__init__.py', 'resource2.txt'},
        )
|
||||
|
||||
|
||||
class DeletingZipsTest(unittest.TestCase):
    """Having accessed resources in a zip file should not keep an open
    reference to the zip.
    """

    ZIP_MODULE = zipdata01

    def setUp(self):
        modules = import_helper.modules_setup()
        self.addCleanup(import_helper.modules_cleanup, *modules)

        # Work on a private, uniquely-named copy of the fixture zip so that
        # deleting it inside a test cannot affect other tests.
        data_path = pathlib.Path(self.ZIP_MODULE.__file__)
        data_dir = data_path.parent
        self.source_zip_path = data_dir / 'ziptestdata.zip'
        self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute()
        self.zip_path.write_bytes(self.source_zip_path.read_bytes())
        sys.path.append(str(self.zip_path))
        self.data = import_module('ziptestdata')

    def tearDown(self):
        # Each cleanup step is wrapped separately so a failure in one does
        # not prevent the others from running.
        try:
            sys.path.remove(str(self.zip_path))
        except ValueError:
            pass

        try:
            del sys.path_importer_cache[str(self.zip_path)]
            del sys.modules[self.data.__name__]
        except KeyError:
            pass

        try:
            unlink(self.zip_path)
        except OSError:
            # If the test fails, this will probably fail too
            pass

    def test_iterdir_does_not_keep_open(self):
        # unlink() succeeds only if no handle to the zip remains open.
        c = [item.name for item in resources.files('ziptestdata').iterdir()]
        self.zip_path.unlink()
        del c

    def test_is_file_does_not_keep_open(self):
        c = resources.files('ziptestdata').joinpath('binary.file').is_file()
        self.zip_path.unlink()
        del c

    def test_is_file_failure_does_not_keep_open(self):
        c = resources.files('ziptestdata').joinpath('not-present').is_file()
        self.zip_path.unlink()
        del c

    @unittest.skip("Desired but not supported.")
    def test_as_file_does_not_keep_open(self):  # pragma: no cover
        c = resources.as_file(resources.files('ziptestdata') / 'binary.file')
        self.zip_path.unlink()
        del c

    def test_entered_path_does_not_keep_open(self):
        # This is what certifi does on import to make its bundle
        # available for the process duration.
        c = resources.as_file(
            resources.files('ziptestdata') / 'binary.file'
        ).__enter__()
        self.zip_path.unlink()
        del c

    def test_read_binary_does_not_keep_open(self):
        c = resources.files('ziptestdata').joinpath('binary.file').read_bytes()
        self.zip_path.unlink()
        del c

    def test_read_text_does_not_keep_open(self):
        c = resources.files('ziptestdata').joinpath('utf-8.file').read_text()
        self.zip_path.unlink()
        del c
|
||||
|
||||
|
||||
class ResourceFromNamespaceTest01(unittest.TestCase):
    """Resource access anchored on a namespace package (namespacedata01)."""

    # The directory containing this test file doubles as the site dir that
    # makes 'namespacedata01' importable.
    site_dir = str(pathlib.Path(__file__).parent)

    @classmethod
    def setUpClass(cls):
        sys.path.append(cls.site_dir)

    @classmethod
    def tearDownClass(cls):
        sys.path.remove(cls.site_dir)

    def test_is_submodule_resource(self):
        self.assertTrue(
            resources.files(import_module('namespacedata01'))
            .joinpath('binary.file')
            .is_file()
        )

    def test_read_submodule_resource_by_name(self):
        self.assertTrue(
            resources.files('namespacedata01').joinpath('binary.file').is_file()
        )

    def test_submodule_contents(self):
        contents = names(resources.files(import_module('namespacedata01')))
        # A __pycache__ dir may or may not exist depending on prior runs;
        # drop it so the comparison below is stable.
        try:
            contents.remove('__pycache__')
        except KeyError:
            pass
        self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})

    def test_submodule_contents_by_name(self):
        contents = names(resources.files('namespacedata01'))
        try:
            contents.remove('__pycache__')
        except KeyError:
            pass
        self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Generate the zip test data files.
|
||||
|
||||
Run to build the tests/zipdataNN/ziptestdata.zip files from
|
||||
files in tests/dataNN.
|
||||
|
||||
Replaces the file with the working copy, but does not commit anything
to the source repo.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import pathlib
|
||||
import zipfile
|
||||
|
||||
|
||||
def main():
    """
    >>> from unittest import mock
    >>> monkeypatch = getfixture('monkeypatch')
    >>> monkeypatch.setattr(zipfile, 'ZipFile', mock.MagicMock())
    >>> print(); main()  # print workaround for bpo-32509
    <BLANKLINE>
    ...data01... -> ziptestdata/...
    ...
    ...data02... -> ziptestdata/...
    ...
    """
    # Rebuild the fixture archive for each dataNN tree.
    suffixes = '01', '02'
    # map() is lazy; tuple() forces generate() to run for every suffix.
    tuple(map(generate, suffixes))
|
||||
|
||||
|
||||
def generate(suffix):
    """Zip the tests/data{suffix} tree into zipdata{suffix}/ziptestdata.zip,
    rooting every archive member under a top-level ``ziptestdata/`` folder.
    """
    here = pathlib.Path(__file__).parent.relative_to(os.getcwd())
    archive = here / f'zipdata{suffix}/ziptestdata.zip'
    with zipfile.ZipFile(archive, 'w') as zipped:
        for source, relative in walk(here / f'data{suffix}'):
            # Force POSIX separators so archive member names are portable.
            target = 'ziptestdata' / pathlib.PurePosixPath(relative.as_posix())
            print(source, '->', target)
            zipped.write(source, target)
|
||||
|
||||
|
||||
def walk(datapath):
    """Yield (source_path, path_relative_to_datapath) pairs for every file
    under *datapath*, never descending into ``__pycache__`` directories.
    """
    for parent, subdirs, files in os.walk(datapath):
        # Prune __pycache__ in place so os.walk skips it entirely.
        if '__pycache__' in subdirs:
            subdirs.remove('__pycache__')
        parent_path = pathlib.Path(parent)
        for name in files:
            full = parent_path / name
            yield full, full.relative_to(datapath)
|
||||
|
||||
|
||||
__name__ == '__main__' and main()
|
||||
167
lib/pkg_resources/_vendor/importlib_resources/tests/util.py
Normal file
167
lib/pkg_resources/_vendor/importlib_resources/tests/util.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import abc
|
||||
import importlib
|
||||
import io
|
||||
import sys
|
||||
import types
|
||||
import pathlib
|
||||
|
||||
from . import data01
|
||||
from . import zipdata01
|
||||
from ..abc import ResourceReader
|
||||
from ._compat import import_helper
|
||||
|
||||
|
||||
from importlib.machinery import ModuleSpec
|
||||
|
||||
|
||||
class Reader(ResourceReader):
    """Configurable fake ResourceReader for tests.

    Attributes (``file``, ``path``, ``_contents``) are injected via keyword
    arguments; any injected value that is an Exception instance is raised
    instead of returned when the corresponding accessor runs. The last path
    queried is recorded on ``self._path``.
    """

    def __init__(self, **kwargs):
        # Stash every keyword directly onto the instance.
        self.__dict__.update(kwargs)

    def get_resource_reader(self, package):
        # The reader doubles as its own loader hook.
        return self

    def open_resource(self, path):
        self._path = path
        if isinstance(self.file, Exception):
            raise self.file
        return self.file

    def resource_path(self, path_):
        self._path = path_
        if isinstance(self.path, Exception):
            raise self.path
        return self.path

    def is_resource(self, path_):
        self._path = path_
        if isinstance(self.path, Exception):
            raise self.path
        # Only top-level entries (those without a '/') count as resources.
        return any(entry.split('/') == [path_] for entry in self._contents)

    def contents(self):
        if isinstance(self.path, Exception):
            raise self.path
        yield from self._contents
|
||||
|
||||
|
||||
def create_package_from_loader(loader, is_package=True):
    """Return a synthetic module named ``testingpackage`` wired to *loader*
    via both ``__loader__`` and a matching ModuleSpec.
    """
    name = 'testingpackage'
    spec = ModuleSpec(name, loader, origin='does-not-exist', is_package=is_package)
    module = types.ModuleType(name)
    module.__spec__ = spec
    module.__loader__ = loader
    return module
|
||||
|
||||
|
||||
def create_package(file=None, path=None, is_package=True, contents=()):
    """Build a testing package whose loader is a Reader seeded with the
    given file object, path, and contents listing.
    """
    reader = Reader(file=file, path=path, _contents=contents)
    return create_package_from_loader(reader, is_package)
|
||||
|
||||
|
||||
class CommonTests(metaclass=abc.ABCMeta):
    """
    Tests shared by test_open, test_path, and test_read.
    """

    @abc.abstractmethod
    def execute(self, package, path):
        """
        Call the pertinent legacy API function (e.g. open_text, path)
        on package and path.
        """

    def test_package_name(self):
        # Passing in the package name should succeed.
        self.execute(data01.__name__, 'utf-8.file')

    def test_package_object(self):
        # Passing in the package itself should succeed.
        self.execute(data01, 'utf-8.file')

    def test_string_path(self):
        # Passing in a string for the path should succeed.
        path = 'utf-8.file'
        self.execute(data01, path)

    def test_pathlib_path(self):
        # Passing in a pathlib.PurePath object for the path should succeed.
        path = pathlib.PurePath('utf-8.file')
        self.execute(data01, path)

    def test_importing_module_as_side_effect(self):
        # The anchor package can already be imported.
        del sys.modules[data01.__name__]
        self.execute(data01.__name__, 'utf-8.file')

    def test_missing_path(self):
        # Attempting to open or read or request the path for a
        # non-existent path should succeed if open_resource
        # can return a viable data stream.
        bytes_data = io.BytesIO(b'Hello, world!')
        package = create_package(file=bytes_data, path=FileNotFoundError())
        self.execute(package, 'utf-8.file')
        # Reader records the last path queried; verify it was consulted.
        self.assertEqual(package.__loader__._path, 'utf-8.file')

    def test_extant_path(self):
        # Attempting to open or read or request the path when the
        # path does exist should still succeed. Does not assert
        # anything about the result.
        bytes_data = io.BytesIO(b'Hello, world!')
        # any path that exists
        path = __file__
        package = create_package(file=bytes_data, path=path)
        self.execute(package, 'utf-8.file')
        self.assertEqual(package.__loader__._path, 'utf-8.file')

    def test_useless_loader(self):
        # When both file and path raise, the error must propagate.
        package = create_package(file=FileNotFoundError(), path=FileNotFoundError())
        with self.assertRaises(FileNotFoundError):
            self.execute(package, 'utf-8.file')
|
||||
|
||||
|
||||
class ZipSetupBase:
    """Mixin that puts the fixture zip on sys.path so 'ziptestdata' imports."""

    # Subclasses set this to the zipdataNN fixture module.
    ZIP_MODULE = None

    @classmethod
    def setUpClass(cls):
        data_path = pathlib.Path(cls.ZIP_MODULE.__file__)
        data_dir = data_path.parent
        cls._zip_path = str(data_dir / 'ziptestdata.zip')
        sys.path.append(cls._zip_path)
        cls.data = importlib.import_module('ziptestdata')

    @classmethod
    def tearDownClass(cls):
        # Each cleanup step is wrapped separately so a failure in one does
        # not prevent the others from running.
        try:
            sys.path.remove(cls._zip_path)
        except ValueError:
            pass

        try:
            del sys.path_importer_cache[cls._zip_path]
            del sys.modules[cls.data.__name__]
        except KeyError:
            pass

        try:
            del cls.data
            del cls._zip_path
        except AttributeError:
            pass

    def setUp(self):
        modules = import_helper.modules_setup()
        self.addCleanup(import_helper.modules_cleanup, *modules)
|
||||
|
||||
|
||||
class ZipSetup(ZipSetupBase):
    # Concrete setup bound to the single-package fixture zip.
    ZIP_MODULE = zipdata01  # type: ignore
|
||||
@@ -0,0 +1 @@
|
||||
jaraco
|
||||
@@ -0,0 +1 @@
|
||||
jaraco
|
||||
@@ -0,0 +1 @@
|
||||
jaraco
|
||||
0
lib/pkg_resources/_vendor/jaraco/__init__.py
Normal file
0
lib/pkg_resources/_vendor/jaraco/__init__.py
Normal file
288
lib/pkg_resources/_vendor/jaraco/context.py
Normal file
288
lib/pkg_resources/_vendor/jaraco/context.py
Normal file
@@ -0,0 +1,288 @@
|
||||
import os
|
||||
import subprocess
|
||||
import contextlib
|
||||
import functools
|
||||
import tempfile
|
||||
import shutil
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
|
||||
@contextlib.contextmanager
def pushd(dir):
    """
    Temporarily change the working directory to *dir*, restoring the
    previous directory on exit (even on error).

    >>> tmp_path = getfixture('tmp_path')
    >>> with pushd(tmp_path):
    ...     assert os.getcwd() == os.fspath(tmp_path)
    >>> assert os.getcwd() != os.fspath(tmp_path)
    """
    previous = os.getcwd()
    os.chdir(dir)
    try:
        yield dir
    finally:
        # Always restore, even if the body raised.
        os.chdir(previous)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
    """
    Get a tarball, extract it, change to that directory, yield, then
    clean up.
    `runner` is the function to invoke commands.
    `pushd` is a context manager for changing the directory.
    """
    if target_dir is None:
        # Derive the directory name from the tarball's filename.
        target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
    if runner is None:
        runner = functools.partial(subprocess.check_call, shell=True)
    else:
        warnings.warn("runner parameter is deprecated", DeprecationWarning)
    # NOTE(review): the commands below interpolate `url` and `target_dir`
    # into a shell string (shell=True) — command-injection risk if either
    # value is untrusted; confirm callers pass only trusted inputs.
    # In the tar command, use --strip-components=1 to strip the first path and
    # then
    # use -C to cause the files to be extracted to {target_dir}. This ensures
    # that we always know where the files were extracted.
    runner('mkdir {target_dir}'.format(**vars()))
    try:
        getter = 'wget {url} -O -'
        extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
        cmd = ' | '.join((getter, extract))
        runner(cmd.format(compression=infer_compression(url), **vars()))
        with pushd(target_dir):
            yield target_dir
    finally:
        # Best-effort removal of the extraction directory.
        runner('rm -Rf {target_dir}'.format(**vars()))
|
||||
|
||||
|
||||
def infer_compression(url):
    """
    Given a URL or filename, infer the compression code for tar.

    >>> infer_compression('http://foo/bar.tar.gz')
    'z'
    >>> infer_compression('http://foo/bar.tgz')
    'z'
    >>> infer_compression('file.bz')
    'j'
    >>> infer_compression('file.xz')
    'J'
    """
    # Cheap heuristic: the final two characters identify the codec;
    # anything unrecognized is assumed to be gzip.
    codes = {'gz': 'z', 'bz': 'j', 'xz': 'J'}
    return codes.get(url[-2:], 'z')
|
||||
|
||||
|
||||
@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
    """
    Create a temporary directory context. Pass a custom remover
    to override the removal behavior.

    >>> import pathlib
    >>> with temp_dir() as the_dir:
    ...     assert os.path.isdir(the_dir)
    ...     _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
    >>> assert not os.path.exists(the_dir)
    """
    created = tempfile.mkdtemp()
    try:
        yield created
    finally:
        # The remover runs even if the body raised.
        remover(created)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
    """
    Check out the repo indicated by url.

    If dest_ctx is supplied, it should be a context manager
    to yield the target directory for the check out.

    Yields the directory containing the fresh clone.
    """
    # Crude VCS detection: anything mentioning 'git' is cloned with git,
    # otherwise fall back to Mercurial.
    exe = 'git' if 'git' in url else 'hg'
    with dest_ctx() as repo_dir:
        cmd = [exe, 'clone', url, repo_dir]
        if branch:
            cmd.extend(['--branch', branch])
        # Fix: the previous open(os.path.devnull, 'w') handle was never
        # closed (resource leak); subprocess.DEVNULL needs no cleanup.
        stdout = subprocess.DEVNULL if quiet else None
        subprocess.check_call(cmd, stdout=stdout)
        yield repo_dir
|
||||
|
||||
|
||||
@contextlib.contextmanager
def null():
    """
    A null context suitable to stand in for a meaningful context.

    >>> with null() as value:
    ...     assert value is None
    """
    yield None
|
||||
|
||||
|
||||
class ExceptionTrap:
    """
    A context manager that will catch certain exceptions and provide an
    indication they occurred.

    >>> with ExceptionTrap() as trap:
    ...     raise Exception()
    >>> bool(trap)
    True

    >>> with ExceptionTrap() as trap:
    ...     pass
    >>> bool(trap)
    False

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise ValueError("1 + 1 is not 3")
    >>> bool(trap)
    True
    >>> trap.value
    ValueError('1 + 1 is not 3')
    >>> trap.tb
    <traceback object at ...>

    >>> with ExceptionTrap(ValueError) as trap:
    ...     raise Exception()
    Traceback (most recent call last):
    ...
    Exception

    >>> bool(trap)
    False
    """

    # (type, value, traceback) of the trapped exception; the class-level
    # default means "nothing trapped yet".
    exc_info = None, None, None

    def __init__(self, exceptions=(Exception,)):
        self.exceptions = exceptions

    def __enter__(self):
        return self

    @property
    def type(self):
        return self.exc_info[0]

    @property
    def value(self):
        return self.exc_info[1]

    @property
    def tb(self):
        return self.exc_info[2]

    def __exit__(self, *exc_info):
        # Local deliberately mirrors the property name; shadows the builtin.
        type = exc_info[0]
        matches = type and issubclass(type, self.exceptions)
        if matches:
            self.exc_info = exc_info
        # Returning a truthy value suppresses the exception, so matched
        # exceptions are swallowed; others propagate.
        return matches

    def __bool__(self):
        return bool(self.type)

    def raises(self, func, *, _test=bool):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if an exception occurred).

        First, give the decorator an alias to support Python 3.8
        Syntax.

        >>> raises = ExceptionTrap(ValueError).raises

        Now decorate a function that always fails.

        >>> @raises
        ... def fail():
        ...     raise ValueError('failed')
        >>> fail()
        True
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A fresh trap per call keeps invocations independent.
            with ExceptionTrap(self.exceptions) as trap:
                func(*args, **kwargs)
            return _test(trap)

        return wrapper

    def passes(self, func):
        """
        Wrap func and replace the result with the truth
        value of the trap (True if no exception).

        First, give the decorator an alias to support Python 3.8
        Syntax.

        >>> passes = ExceptionTrap(ValueError).passes

        Now decorate a function that always fails.

        >>> @passes
        ... def fail():
        ...     raise ValueError('failed')

        >>> fail()
        False
        """
        # passes() is raises() with the truth test inverted.
        return self.raises(func, _test=operator.not_)
|
||||
|
||||
|
||||
class suppress(contextlib.suppress, contextlib.ContextDecorator):
    """
    A version of contextlib.suppress with decorator support.

    >>> @suppress(KeyError)
    ... def key_error():
    ...     {}['']
    >>> key_error()
    """

    # No body needed: contextlib.suppress supplies the exception handling
    # and ContextDecorator supplies __call__ for decorator usage.
|
||||
|
||||
|
||||
class on_interrupt(contextlib.ContextDecorator):
    """
    Replace a KeyboardInterrupt with SystemExit(1)

    >>> def do_interrupt():
    ...     raise KeyboardInterrupt()
    >>> on_interrupt('error')(do_interrupt)()
    Traceback (most recent call last):
    ...
    SystemExit: 1
    >>> on_interrupt('error', code=255)(do_interrupt)()
    Traceback (most recent call last):
    ...
    SystemExit: 255
    >>> on_interrupt('suppress')(do_interrupt)()
    >>> with __import__('pytest').raises(KeyboardInterrupt):
    ...     on_interrupt('ignore')(do_interrupt)()
    """

    def __init__(
        self,
        action='error',
        # py3.7 compat
        # /,
        code=1,
    ):
        # action: 'error' (raise SystemExit(code)), 'suppress' (swallow),
        # or 'ignore' (let the KeyboardInterrupt propagate).
        self.action = action
        self.code = code

    def __enter__(self):
        return self

    def __exit__(self, exctype, excinst, exctb):
        # Non-interrupt exceptions and 'ignore' mode both propagate
        # (returning None does not suppress).
        if exctype is not KeyboardInterrupt or self.action == 'ignore':
            return
        elif self.action == 'error':
            raise SystemExit(self.code) from excinst
        # Returning True here suppresses the KeyboardInterrupt.
        return self.action == 'suppress'
|
||||
556
lib/pkg_resources/_vendor/jaraco/functools.py
Normal file
556
lib/pkg_resources/_vendor/jaraco/functools.py
Normal file
@@ -0,0 +1,556 @@
|
||||
import functools
|
||||
import time
|
||||
import inspect
|
||||
import collections
|
||||
import types
|
||||
import itertools
|
||||
import warnings
|
||||
|
||||
import pkg_resources.extern.more_itertools
|
||||
|
||||
from typing import Callable, TypeVar
|
||||
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., object])
|
||||
|
||||
|
||||
def compose(*funcs):
    """
    Compose any number of unary functions into a single unary function.

    The rightmost function may take arbitrary arguments; each function to
    its left receives the previous result:
    ``compose(f, g)(x) == f(g(x))``.

    >>> import textwrap
    >>> expected = str.strip(textwrap.dedent(compose.__doc__))
    >>> round_three = lambda x: round(x, ndigits=3)
    >>> f = compose(round_three, int.__truediv__)
    >>> [f(3*x, x+1) for x in range(1,10)]
    [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
    """

    def pair(outer, inner):
        return lambda *args, **kwargs: outer(inner(*args, **kwargs))

    return functools.reduce(pair, funcs)
|
||||
|
||||
|
||||
def method_caller(method_name, *args, **kwargs):
    """
    Return a callable that invokes *method_name* on its argument with the
    given positional and keyword arguments.

    >>> lower = method_caller('lower')
    >>> lower('MyString')
    'mystring'
    """

    def invoke_method(target):
        bound = getattr(target, method_name)
        return bound(*args, **kwargs)

    return invoke_method
|
||||
|
||||
|
||||
def once(func):
    """
    Decorate *func* so only the first call actually runs; every later call
    returns the cached first result, regardless of arguments.

    The cached value lives on the wrapper as ``saved_result``; clear it with
    ``del wrapped.saved_result`` or ``wrapped.reset()`` to force the next
    call to re-invoke *func*.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # EAFP: a cache hit is the common case after the first call.
        try:
            return wrapper.saved_result
        except AttributeError:
            wrapper.saved_result = func(*args, **kwargs)
            return wrapper.saved_result

    # pop() without a default preserves the KeyError on an empty cache.
    wrapper.reset = lambda: wrapper.__dict__.pop('saved_result')
    return wrapper
|
||||
|
||||
|
||||
def method_cache(
    method: CallableT,
    cache_wrapper: Callable[
        [CallableT], CallableT
    ] = functools.lru_cache(),  # type: ignore[assignment]
) -> CallableT:
    """
    Wrap lru_cache to support storing the cache data in the object instances.

    Abstracts the common paradigm where the method explicitly saves an
    underscore-prefixed protected property on first call and returns that
    subsequently.

    >>> class MyClass:
    ...     calls = 0
    ...
    ...     @method_cache
    ...     def method(self, value):
    ...         self.calls += 1
    ...         return value

    >>> a = MyClass()
    >>> a.method(3)
    3
    >>> for x in range(75):
    ...     res = a.method(x)
    >>> a.calls
    75

    Note that the apparent behavior will be exactly like that of lru_cache
    except that the cache is stored on each instance, so values in one
    instance will not flush values from another, and when an instance is
    deleted, so are the cached values for that instance.

    >>> b = MyClass()
    >>> for x in range(35):
    ...     res = b.method(x)
    >>> b.calls
    35
    >>> a.method(0)
    0
    >>> a.calls
    75

    Note that if method had been decorated with ``functools.lru_cache()``,
    a.calls would have been 76 (due to the cached value of 0 having been
    flushed by the 'b' instance).

    Clear the cache with ``.cache_clear()``

    >>> a.method.cache_clear()

    Same for a method that hasn't yet been called.

    >>> c = MyClass()
    >>> c.method.cache_clear()

    Another cache wrapper may be supplied:

    >>> cache = functools.lru_cache(maxsize=2)
    >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache)
    >>> a = MyClass()
    >>> a.method2()
    3

    Caution - do not subsequently wrap the method with another decorator, such
    as ``@property``, which changes the semantics of the function.

    See also
    http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
    for another implementation and additional justification.
    """

    def wrapper(self: object, *args: object, **kwargs: object) -> object:
        # it's the first call, replace the method with a cached, bound method
        # (the instance attribute shadows this class-level wrapper from now on)
        bound_method: CallableT = types.MethodType(  # type: ignore[assignment]
            method, self
        )
        cached_method = cache_wrapper(bound_method)
        setattr(self, method.__name__, cached_method)
        return cached_method(*args, **kwargs)

    # Support cache clear even before cache has been created.
    wrapper.cache_clear = lambda: None  # type: ignore[attr-defined]

    # Special methods bypass instance attributes, so they need the proxy
    # strategy from _special_method_cache instead.
    return (  # type: ignore[return-value]
        _special_method_cache(method, cache_wrapper) or wrapper
    )
|
||||
|
||||
|
||||
def _special_method_cache(method, cache_wrapper):
    """
    Because Python treats special methods differently, it's not
    possible to use instance attributes to implement the cached
    methods.

    Instead, install the wrapper method under a different name
    and return a simple proxy to that wrapper.

    https://github.com/jaraco/jaraco.functools/issues/5
    """
    name = method.__name__
    special_names = '__getattr__', '__getitem__'
    # Returns None (so method_cache falls back to its normal wrapper) for
    # anything other than the supported special methods.
    if name not in special_names:
        return

    wrapper_name = '__cached' + name

    def proxy(self, *args, **kwargs):
        # Lazily build and stash the per-instance cache on first use.
        if wrapper_name not in vars(self):
            bound = types.MethodType(method, self)
            cache = cache_wrapper(bound)
            setattr(self, wrapper_name, cache)
        else:
            cache = getattr(self, wrapper_name)
        return cache(*args, **kwargs)

    return proxy
|
||||
|
||||
|
||||
def apply(transform):
    """
    Decorator factory: run *transform* over the decorated function's
    return value, preserving the function's metadata.

    >>> @apply(reversed)
    ... def get_numbers(start):
    ...     "doc for get_numbers"
    ...     return range(start, start+3)
    >>> list(get_numbers(4))
    [6, 5, 4]
    >>> get_numbers.__doc__
    'doc for get_numbers'
    """

    def wrap(func):
        @functools.wraps(func)
        def transformed(*args, **kwargs):
            return transform(func(*args, **kwargs))

        return transformed

    return wrap
|
||||
|
||||
|
||||
def result_invoke(action):
    r"""
    Decorator factory: after the decorated function returns, call *action*
    on the result (for its side effect), then pass the original result
    through unchanged.

    >>> @result_invoke(print)
    ... def add_two(a, b):
    ...     return a + b
    >>> x = add_two(2, 3)
    5
    >>> x
    5
    """

    def wrap(func):
        @functools.wraps(func)
        def observed(*args, **kwargs):
            outcome = func(*args, **kwargs)
            action(outcome)
            return outcome

        return observed

    return wrap
|
||||
|
||||
|
||||
def invoke(f, *args, **kwargs):
    """
    Call *f* immediately with the given arguments (for its side effect)
    and return *f* itself.

    Used as a decorator, it makes "define this function and call it right
    away" explicit, without repeating the function's name:

    >>> @invoke
    ... def func(): print("called")
    called
    >>> func()
    called

    Use functools.partial to pass parameters to the initial call

    >>> @functools.partial(invoke, name='bingo')
    ... def func(name): print("called with", name)
    called with bingo
    """
    f(*args, **kwargs)
    return f
|
||||
|
||||
|
||||
def call_aside(*args, **kwargs):
    """
    Deprecated name for invoke.
    """
    # Warn on every call; behavior otherwise delegates entirely to invoke.
    warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning)
    return invoke(*args, **kwargs)
|
||||
|
||||
|
||||
class Throttler:
    """
    Rate-limit a function (or other callable) to at most *max_rate*
    invocations per second.
    """

    def __init__(self, func, max_rate=float('Inf')):
        # Unwrap an already-throttled callable so limits do not stack.
        if isinstance(func, Throttler):
            func = func.func
        self.func = func
        self.max_rate = max_rate
        self.reset()

    def reset(self):
        # Pretend the last call happened at the epoch so the first real
        # call is never delayed.
        self.last_called = 0

    def __call__(self, *args, **kwargs):
        self._wait()
        return self.func(*args, **kwargs)

    def _wait(self):
        "ensure at least 1/max_rate seconds from last call"
        since_last = time.time() - self.last_called
        remaining = 1 / self.max_rate - since_last
        time.sleep(max(0, remaining))
        self.last_called = time.time()

    def __get__(self, obj, type=None):
        # Descriptor support: throttle bound-method access by prepending
        # the wait to the method bound to *obj*.
        return first_invoke(self._wait, functools.partial(self.func, obj))
|
||||
|
||||
|
||||
def first_invoke(func1, func2):
    """
    Return a callable that first calls *func1* with no arguments (for its
    side effect), then calls *func2* with the supplied arguments and
    returns its result.
    """

    def chained(*args, **kwargs):
        func1()
        return func2(*args, **kwargs)

    return chained
|
||||
|
||||
|
||||
def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
    """
    Call *func*; if it raises one of the exceptions in *trap*, run
    *cleanup* and try again, for up to *retries* guarded attempts.
    The final attempt is unguarded, so any exception propagates.
    """
    # float('inf') retries means keep trying forever; otherwise make
    # exactly `retries` trapped attempts before the final bare call.
    if retries == float('inf'):
        guarded_attempts = itertools.count()
    else:
        guarded_attempts = range(retries)
    for _ in guarded_attempts:
        try:
            return func()
        except trap:
            cleanup()

    return func()
|
||||
|
||||
|
||||
def retry(*r_args, **r_kwargs):
    """
    Decorator wrapper for retry_call: accepts retry_call's arguments
    (everything except ``func``) and returns a decorator that applies
    those retry semantics to the decorated function, preserving its
    metadata via ``functools.wraps``.
    """

    def decorate(func):
        @functools.wraps(func)
        def call_with_retries(*args, **kwargs):
            attempt = functools.partial(func, *args, **kwargs)
            return retry_call(attempt, *r_args, **r_kwargs)

        return call_with_retries

    return decorate
|
||||
|
||||
|
||||
def print_yielded(func):
    """
    Convert a generator function into a function that prints every
    yielded element (and returns None).
    """

    @functools.wraps(func)
    def printer(*args, **kwargs):
        # Fully consume the generator, printing each item as it is
        # produced.
        for item in func(*args, **kwargs):
            print(item)

    return printer
|
||||
|
||||
|
||||
def pass_none(func):
    """
    Wrap *func* so it is skipped (implicitly returning None) whenever
    its first positional parameter is None.
    """

    @functools.wraps(func)
    def guarded(param, *args, **kwargs):
        if param is None:
            return None
        return func(param, *args, **kwargs)

    return guarded
|
||||
|
||||
|
||||
def assign_params(func, namespace):
    """
    Bind, from *namespace*, values for any parameters that *func*
    declares, returning a partial with those keyword arguments applied.

    Names in *namespace* that *func* does not solicit are ignored.
    Required parameters missing from *namespace* surface as the usual
    TypeError when the returned partial is invoked. Works on bound
    methods as well as plain functions.
    """
    wanted = inspect.signature(func).parameters
    selected = {name: namespace[name] for name in wanted if name in namespace}
    return functools.partial(func, **selected)
|
||||
|
||||
|
||||
def save_method_args(method):
    """
    Decorate a method so that each invocation records its positional
    and keyword arguments on the instance, under the attribute
    ``_saved_<method name>`` (a namedtuple with ``args`` and ``kwargs``
    fields). Because the record lives on the instance, different
    instances keep independent saved arguments.
    """
    # Keep the historical typename so reprs are unchanged.
    record = collections.namedtuple('args_and_kwargs', 'args kwargs')

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        setattr(self, '_saved_' + method.__name__, record(args, kwargs))
        return method(self, *args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
def except_(*exceptions, replace=None, use=None):
    """
    Replace the indicated exceptions, if raised, with the indicated
    literal replacement or evaluated expression (if present).

    >>> safe_int = except_(ValueError)(int)
    >>> safe_int('five')
    >>> safe_int('5')
    5

    Specify a literal replacement with ``replace``.

    >>> safe_int_r = except_(ValueError, replace=0)(int)
    >>> safe_int_r('five')
    0

    Provide an expression to ``use`` to pass through particular parameters.

    >>> safe_int_pt = except_(ValueError, use='args[0]')(int)
    >>> safe_int_pt('five')
    'five'

    """

    def decorate(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except exceptions:
                try:
                    # `use` is evaluated with eval() in the wrapper's
                    # scope, so expressions such as 'args[0]' refer to
                    # the wrapped call's own arguments.
                    # When `use` is None (the default), eval(None)
                    # raises TypeError and the literal `replace` is
                    # returned instead — that fallthrough is the
                    # intended control flow, not an accident.
                    # SECURITY: never pass untrusted strings as `use`.
                    return eval(use)
                except TypeError:
                    return replace

        return wrapper

    return decorate
|
||||
2
lib/pkg_resources/_vendor/jaraco/text/Lorem ipsum.txt
Normal file
2
lib/pkg_resources/_vendor/jaraco/text/Lorem ipsum.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
|
||||
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst.
|
||||
599
lib/pkg_resources/_vendor/jaraco/text/__init__.py
Normal file
599
lib/pkg_resources/_vendor/jaraco/text/__init__.py
Normal file
@@ -0,0 +1,599 @@
|
||||
import re
|
||||
import itertools
|
||||
import textwrap
|
||||
import functools
|
||||
|
||||
try:
|
||||
from importlib.resources import files # type: ignore
|
||||
except ImportError: # pragma: nocover
|
||||
from pkg_resources.extern.importlib_resources import files # type: ignore
|
||||
|
||||
from pkg_resources.extern.jaraco.functools import compose, method_cache
|
||||
from pkg_resources.extern.jaraco.context import ExceptionTrap
|
||||
|
||||
|
||||
def substitution(old, new):
    """
    Return a function that replaces every occurrence of *old* with
    *new* in the string it is given.
    """

    def replace_in(text):
        return text.replace(old, new)

    return replace_in
|
||||
|
||||
|
||||
def multi_substitution(*substitutions):
    """
    Take a sequence of (old, new) pairs and create a function that
    applies those substitutions, in order, to a string.

    >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo')
    'baz'
    """

    def apply_all(text):
        # Substitutions are applied first-to-last, so a later pair may
        # rewrite the output of an earlier one (as in the doctest).
        for old, new in substitutions:
            text = text.replace(old, new)
        return text

    return apply_all
|
||||
|
||||
|
||||
class FoldedCase(str):
    """
    A case insensitive string class; behaves just like str
    except compares equal when the only variation is case.

    >>> s = FoldedCase('hello world')

    >>> s == 'Hello World'
    True

    >>> 'Hello World' == s
    True

    >>> s != 'Hello World'
    False

    >>> s.index('O')
    4

    >>> s.split('O')
    ['hell', ' w', 'rld']

    >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
    ['alpha', 'Beta', 'GAMMA']

    Sequence membership is straightforward.

    >>> "Hello World" in [s]
    True
    >>> s in ["Hello World"]
    True

    You may test for set inclusion, but candidate and elements
    must both be folded.

    >>> FoldedCase("Hello World") in {s}
    True
    >>> s in {FoldedCase("Hello World")}
    True

    String inclusion works as long as the FoldedCase object
    is on the right.

    >>> "hello" in FoldedCase("Hello World")
    True

    But not if the FoldedCase object is on the left:

    >>> FoldedCase('hello') in 'Hello World'
    False

    In that case, use ``in_``:

    >>> FoldedCase('hello').in_('Hello World')
    True

    >>> FoldedCase('hello') > FoldedCase('Hello')
    False
    """

    # All comparisons delegate to the lowercased forms, so ordering
    # and equality ignore case.
    def __lt__(self, other):
        return self.lower() < other.lower()

    def __gt__(self, other):
        return self.lower() > other.lower()

    def __eq__(self, other):
        return self.lower() == other.lower()

    def __ne__(self, other):
        return self.lower() != other.lower()

    # Hash must agree with the case-insensitive __eq__ above, so it is
    # computed from the lowered form.
    def __hash__(self):
        return hash(self.lower())

    def __contains__(self, other):
        return super().lower().__contains__(other.lower())

    def in_(self, other):
        "Does self appear in other?"
        return self in FoldedCase(other)

    # cache lower since it's likely to be called frequently.
    @method_cache
    def lower(self):
        return super().lower()

    # Case-insensitive index: searches the lowered text for the
    # lowered needle.
    def index(self, sub):
        return self.lower().index(sub.lower())

    # Case-insensitive split via a regex. NOTE: maxsplit=0 means
    # "no limit" for re.split (unlike str.split, where -1 is the
    # unlimited default).
    def split(self, splitter=' ', maxsplit=0):
        pattern = re.compile(re.escape(splitter), re.I)
        return pattern.split(self, maxsplit)
|
||||
|
||||
|
||||
# Python 3.8 compatibility
# Trap UnicodeDecodeError so that `is_decodable` below reports a
# boolean instead of raising; `ExceptionTrap.passes` turns
# "no exception raised" into True and a trapped exception into False.
_unicode_trap = ExceptionTrap(UnicodeDecodeError)


@_unicode_trap.passes
def is_decodable(value):
    r"""
    Return True if the supplied value is decodable (using the default
    encoding).

    >>> is_decodable(b'\xff')
    False
    >>> is_decodable(b'\x32')
    True
    """
    # The decode result is discarded; only whether it raises matters
    # (see the `passes` decorator above).
    value.decode()


def is_binary(value):
    r"""
    Return True if the value appears to be binary (that is, it's a byte
    string and isn't decodable).

    >>> is_binary(b'\xff')
    True
    >>> is_binary('\xff')
    False
    """
    return isinstance(value, bytes) and not is_decodable(value)
|
||||
|
||||
|
||||
def trim(s):
    r"""
    Normalize a docstring-like block of text: remove the indentation
    common to its lines, then strip surrounding whitespace.

    >>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
    'foo = bar\n\tbar = baz'
    """
    dedented = textwrap.dedent(s)
    return dedented.strip()
|
||||
|
||||
|
||||
def wrap(s):
    """
    Re-wrap each line of *s* at the default textwrap width, treating
    existing newlines as paragraph markers; the wrapped paragraphs are
    joined with blank lines.
    """
    rewrapped = (
        '\n'.join(textwrap.wrap(paragraph))
        for paragraph in s.splitlines()
    )
    return '\n\n'.join(rewrapped)
|
||||
|
||||
|
||||
def unwrap(s):
    r"""
    Undo line-wrapping: each blank-line-separated paragraph collapses
    to a single line, and the resulting lines are joined with single
    newlines.
    """
    paragraphs = re.split(r'\n\n+', s)
    flattened = [paragraph.replace('\n', ' ') for paragraph in paragraphs]
    return '\n'.join(flattened)
|
||||
|
||||
|
||||
|
||||
|
||||
class Splitter:
    """
    Callable that performs ``str.split`` with a fixed set of arguments.

    Construct it with whatever arguments ``str.split`` takes; each call
    then splits the given string the same way.

    >>> comma_split = Splitter(',')
    >>> comma_split('a,b')
    ['a', 'b']
    """

    def __init__(self, *args):
        self.args = args

    def __call__(self, s):
        return s.split(*self.args)
|
||||
|
||||
|
||||
def indent(string, prefix=' ' * 4):
    """
    Prepend *prefix* (four spaces by default) to *string*.

    >>> indent('foo')
    '    foo'
    """
    return ''.join((prefix, string))
|
||||
|
||||
|
||||
class WordSet(tuple):
    """
    Given an identifier, return the words that identifier represents,
    whether in camel case, underscore-separated, etc.

    >>> WordSet.parse("camelCase")
    ('camel', 'Case')

    >>> WordSet.parse("under_sep")
    ('under', 'sep')

    Acronyms should be retained

    >>> WordSet.parse("firstSNL")
    ('first', 'SNL')

    >>> WordSet.parse("you_and_I")
    ('you', 'and', 'I')

    >>> WordSet.parse("A simple test")
    ('A', 'simple', 'test')

    Multiple caps should not interfere with the first cap of another word.

    >>> WordSet.parse("myABCClass")
    ('my', 'ABC', 'Class')

    The result is a WordSet, so you can get the form you need.

    >>> WordSet.parse("myABCClass").underscore_separated()
    'my_ABC_Class'

    >>> WordSet.parse('a-command').camel_case()
    'ACommand'

    >>> WordSet.parse('someIdentifier').lowered().space_separated()
    'some identifier'

    Slices of the result should return another WordSet.

    >>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
    'out_of_context'

    >>> WordSet.from_class_name(WordSet()).lowered().space_separated()
    'word set'

    >>> example = WordSet.parse('figured it out')
    >>> example.headless_camel_case()
    'figuredItOut'
    >>> example.dash_separated()
    'figured-it-out'

    """

    # Matches either a capitalized/lowercase word run, or a run of
    # capitals that is NOT followed by a lowercase letter (so the last
    # capital of an acronym can start the next word, e.g. myABCClass).
    _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')

    def capitalized(self):
        # New WordSet with each word capitalized.
        return WordSet(word.capitalize() for word in self)

    def lowered(self):
        # New WordSet with each word lowercased.
        return WordSet(word.lower() for word in self)

    def camel_case(self):
        return ''.join(self.capitalized())

    def headless_camel_case(self):
        # Like camel_case, but the first word is forced lowercase.
        words = iter(self)
        first = next(words).lower()
        new_words = itertools.chain((first,), WordSet(words).camel_case())
        return ''.join(new_words)

    def underscore_separated(self):
        return '_'.join(self)

    def dash_separated(self):
        return '-'.join(self)

    def space_separated(self):
        return ' '.join(self)

    def trim_right(self, item):
        """
        Remove the item from the end of the set.

        >>> WordSet.parse('foo bar').trim_right('foo')
        ('foo', 'bar')
        >>> WordSet.parse('foo bar').trim_right('bar')
        ('foo',)
        >>> WordSet.parse('').trim_right('bar')
        ()
        """
        return self[:-1] if self and self[-1] == item else self

    def trim_left(self, item):
        """
        Remove the item from the beginning of the set.

        >>> WordSet.parse('foo bar').trim_left('foo')
        ('bar',)
        >>> WordSet.parse('foo bar').trim_left('bar')
        ('foo', 'bar')
        >>> WordSet.parse('').trim_left('bar')
        ()
        """
        return self[1:] if self and self[0] == item else self

    def trim(self, item):
        """
        >>> WordSet.parse('foo bar').trim('foo')
        ('bar',)
        """
        return self.trim_left(item).trim_right(item)

    def __getitem__(self, item):
        # Preserve the WordSet type for slices (plain tuple indexing
        # would otherwise return a bare tuple).
        result = super(WordSet, self).__getitem__(item)
        if isinstance(item, slice):
            result = WordSet(result)
        return result

    @classmethod
    def parse(cls, identifier):
        # Split an identifier into words using the class pattern above.
        matches = cls._pattern.finditer(identifier)
        return WordSet(match.group(0) for match in matches)

    @classmethod
    def from_class_name(cls, subject):
        # Parse the words out of subject's class name.
        return cls.parse(subject.__class__.__name__)
|
||||
|
||||
|
||||
# Backward-compatible alias: older callers import `words` directly
# instead of using WordSet.parse.
words = WordSet.parse
|
||||
|
||||
|
||||
def simple_html_strip(s):
    r"""
    Remove HTML tags and comments from `s`, keeping only the text
    content between them.

    >>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
    A stormy day in paradise

    >>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
    Somebody  tell the truth.
    """
    # Tokenize into comments (group 1), tags (group 2) and text runs
    # (group 3); only the text runs are kept.
    tokenizer = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
    fragments = [match.group(3) or '' for match in tokenizer.finditer(s)]
    return ''.join(fragments)
|
||||
|
||||
|
||||
class SeparatedValues(str):
    """
    A string of values joined by ``separator``. Iterating yields the
    individual values, whitespace-stripped, with empty entries
    discarded.

    >>> list(SeparatedValues('a,b,c'))
    ['a', 'b', 'c']
    """

    separator = ','

    def __iter__(self):
        for piece in self.split(self.separator):
            value = piece.strip()
            if value:
                yield value
|
||||
|
||||
|
||||
class Stripper:
    r"""
    Given a series of lines, find the common prefix and strip it from them.

    >>> lines = [
    ...     'abcdefg\n',
    ...     'abc\n',
    ...     'abcde\n',
    ... ]
    >>> res = Stripper.strip_prefix(lines)
    >>> res.prefix
    'abc'
    >>> list(res.lines)
    ['defg\n', '\n', 'de\n']

    If no prefix is common, nothing should be stripped.

    >>> lines = [
    ...     'abcd\n',
    ...     '1234\n',
    ... ]
    >>> res = Stripper.strip_prefix(lines)
    >>> res.prefix
    ''
    >>> list(res.lines)
    ['abcd\n', '1234\n']
    """

    def __init__(self, prefix, lines):
        self.prefix = prefix
        # Lazily strip each line as it is consumed (self is callable).
        self.lines = map(self, lines)

    @classmethod
    def strip_prefix(cls, lines):
        # tee the iterable: one copy to compute the common prefix, the
        # other to be stripped lazily by the new instance.
        prefix_lines, lines = itertools.tee(lines)
        prefix = functools.reduce(cls.common_prefix, prefix_lines)
        return cls(prefix, lines)

    def __call__(self, line):
        # With an empty prefix, pass lines through untouched.
        if not self.prefix:
            return line
        null, prefix, rest = line.partition(self.prefix)
        return rest

    @staticmethod
    def common_prefix(s1, s2):
        """
        Return the common prefix of two lines.
        """
        # Shrink the candidate length until both prefixes agree; an
        # empty string always agrees, so the loop terminates.
        index = min(len(s1), len(s2))
        while s1[:index] != s2[:index]:
            index -= 1
        return s1[:index]
|
||||
|
||||
|
||||
def remove_prefix(text, prefix):
    """
    Remove *prefix* from the start of *text* if present; otherwise
    return *text* unchanged.

    >>> remove_prefix('underwhelming performance', 'underwhelming ')
    'performance'

    >>> remove_prefix('something special', 'sample')
    'something special'

    Only a prefix at the very start is removed:

    >>> remove_prefix('abcabc', 'abc')
    'abc'
    """
    # Fix: the previous rpartition-based implementation anchored on the
    # *last* occurrence of `prefix`, so inputs where the prefix also
    # appears later (e.g. 'abcabc' / 'abc') were mangled to ''.
    # An empty prefix now simply returns text unchanged rather than
    # raising ValueError from rpartition('').
    if prefix and text.startswith(prefix):
        return text[len(prefix):]
    return text
|
||||
|
||||
|
||||
def remove_suffix(text, suffix):
    """
    Remove *suffix* from the end of *text* if present; otherwise return
    *text* unchanged.

    >>> remove_suffix('name.git', '.git')
    'name'

    >>> remove_suffix('something special', 'sample')
    'something special'

    Only a suffix at the very end is removed:

    >>> remove_suffix('a.git.x', '.git')
    'a.git.x'
    """
    # Fix: the previous partition-based implementation anchored on the
    # *first* occurrence of `suffix`, so inputs where the suffix also
    # appears earlier (e.g. 'a.git.x' / '.git') were wrongly truncated.
    # An empty suffix now simply returns text unchanged rather than
    # raising ValueError from partition('').
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
|
||||
|
||||
|
||||
def normalize_newlines(text):
    r"""
    Replace any alternate newline sequence (CRLF, CR, NEL, line
    separator, paragraph separator) with the canonical '\n'.

    >>> normalize_newlines('Lorem Ipsum\r\n')
    'Lorem Ipsum\n'
    """
    # '\r\n' must appear before '\r' in the alternation so CRLF
    # collapses to a single newline instead of two.
    alternates = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029']
    return re.sub('|'.join(alternates), '\n', text)
|
||||
|
||||
|
||||
def _nonblank(line):
    # Keep lines that are non-empty and not full-line comments.
    return line and not line.startswith('#')


@functools.singledispatch
def yield_lines(iterable):
    r"""
    Yield valid lines of a string or iterable: each string is split
    into stripped lines, blank lines and '#' comment lines are
    discarded, and nested iterables are flattened.

    >>> list(yield_lines(''))
    []
    >>> list(yield_lines(['foo', 'bar']))
    ['foo', 'bar']
    >>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
    ['foo', 'baz #comment']
    """
    # Generic case: flatten by dispatching recursively on each element.
    for item in iterable:
        yield from yield_lines(item)


@yield_lines.register(str)
def _(text):
    # String case: strip each physical line, then drop blanks/comments.
    stripped = (line.strip() for line in text.splitlines())
    return (line for line in stripped if _nonblank(line))
|
||||
|
||||
|
||||
def drop_comment(line):
    """
    Strip a trailing ``' #'``-style comment from *line*.

    A '#' without a preceding space is retained, since it may be part
    of a URL fragment.

    >>> drop_comment('foo # bar')
    'foo'
    >>> drop_comment('http://example.com/foo#bar')
    'http://example.com/foo#bar'
    """
    head, _sep, _comment = line.partition(' #')
    return head
|
||||
|
||||
|
||||
def join_continuation(lines):
    r"""
    Join lines that end in a trailing backslash onto the following line.

    >>> list(join_continuation(['foo \\', 'bar', 'baz']))
    ['foobar', 'baz']

    Two historical quirks are deliberately preserved:

    * the character immediately preceding the backslash is also elided
      (the truncated head is stripped before joining):

      >>> list(join_continuation(['goo\\', 'dly']))
      ['godly']

    * a continuation with no following line suppresses the pending
      item entirely:

      >>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
      ['foo']
    """
    pending = iter(lines)
    for item in pending:
        while item.endswith('\\'):
            try:
                follower = next(pending)
            except StopIteration:
                # Dangling continuation at end of input: drop the item.
                return
            item = item[:-2].strip() + follower
        yield item
|
||||
6
lib/pkg_resources/_vendor/more_itertools/__init__.py
Normal file
6
lib/pkg_resources/_vendor/more_itertools/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""More routines for operating on iterables, beyond itertools"""
|
||||
|
||||
from .more import * # noqa
|
||||
from .recipes import * # noqa
|
||||
|
||||
__version__ = '9.1.0'
|
||||
2
lib/pkg_resources/_vendor/more_itertools/__init__.pyi
Normal file
2
lib/pkg_resources/_vendor/more_itertools/__init__.pyi
Normal file
@@ -0,0 +1,2 @@
|
||||
from .more import *
|
||||
from .recipes import *
|
||||
4391
lib/pkg_resources/_vendor/more_itertools/more.py
Executable file
4391
lib/pkg_resources/_vendor/more_itertools/more.py
Executable file
File diff suppressed because it is too large
Load Diff
666
lib/pkg_resources/_vendor/more_itertools/more.pyi
Normal file
666
lib/pkg_resources/_vendor/more_itertools/more.pyi
Normal file
@@ -0,0 +1,666 @@
|
||||
"""Stubs for more_itertools.more"""
|
||||
from __future__ import annotations
|
||||
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
ContextManager,
|
||||
Generic,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
overload,
|
||||
Reversible,
|
||||
Sequence,
|
||||
Sized,
|
||||
Type,
|
||||
TypeVar,
|
||||
type_check_only,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
|
||||
# Type and type variable definitions
|
||||
_T = TypeVar('_T')
|
||||
_T1 = TypeVar('_T1')
|
||||
_T2 = TypeVar('_T2')
|
||||
_U = TypeVar('_U')
|
||||
_V = TypeVar('_V')
|
||||
_W = TypeVar('_W')
|
||||
_T_co = TypeVar('_T_co', covariant=True)
|
||||
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
|
||||
_Raisable = BaseException | Type[BaseException]
|
||||
|
||||
@type_check_only
|
||||
class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
|
||||
|
||||
@type_check_only
|
||||
class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
|
||||
|
||||
@type_check_only
|
||||
class _SupportsSlicing(Protocol[_T_co]):
|
||||
def __getitem__(self, __k: slice) -> _T_co: ...
|
||||
|
||||
def chunked(
|
||||
iterable: Iterable[_T], n: int | None, strict: bool = ...
|
||||
) -> Iterator[list[_T]]: ...
|
||||
@overload
|
||||
def first(iterable: Iterable[_T]) -> _T: ...
|
||||
@overload
|
||||
def first(iterable: Iterable[_T], default: _U) -> _T | _U: ...
|
||||
@overload
|
||||
def last(iterable: Iterable[_T]) -> _T: ...
|
||||
@overload
|
||||
def last(iterable: Iterable[_T], default: _U) -> _T | _U: ...
|
||||
@overload
|
||||
def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
|
||||
@overload
|
||||
def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
|
||||
|
||||
class peekable(Generic[_T], Iterator[_T]):
|
||||
def __init__(self, iterable: Iterable[_T]) -> None: ...
|
||||
def __iter__(self) -> peekable[_T]: ...
|
||||
def __bool__(self) -> bool: ...
|
||||
@overload
|
||||
def peek(self) -> _T: ...
|
||||
@overload
|
||||
def peek(self, default: _U) -> _T | _U: ...
|
||||
def prepend(self, *items: _T) -> None: ...
|
||||
def __next__(self) -> _T: ...
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> _T: ...
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> list[_T]: ...
|
||||
|
||||
def consumer(func: _GenFn) -> _GenFn: ...
|
||||
def ilen(iterable: Iterable[object]) -> int: ...
|
||||
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
|
||||
def with_iter(
|
||||
context_manager: ContextManager[Iterable[_T]],
|
||||
) -> Iterator[_T]: ...
|
||||
def one(
|
||||
iterable: Iterable[_T],
|
||||
too_short: _Raisable | None = ...,
|
||||
too_long: _Raisable | None = ...,
|
||||
) -> _T: ...
|
||||
def raise_(exception: _Raisable, *args: Any) -> None: ...
|
||||
def strictly_n(
|
||||
iterable: Iterable[_T],
|
||||
n: int,
|
||||
too_short: _GenFn | None = ...,
|
||||
too_long: _GenFn | None = ...,
|
||||
) -> list[_T]: ...
|
||||
def distinct_permutations(
|
||||
iterable: Iterable[_T], r: int | None = ...
|
||||
) -> Iterator[tuple[_T, ...]]: ...
|
||||
def intersperse(
|
||||
e: _U, iterable: Iterable[_T], n: int = ...
|
||||
) -> Iterator[_T | _U]: ...
|
||||
def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ...
|
||||
@overload
|
||||
def windowed(
|
||||
seq: Iterable[_T], n: int, *, step: int = ...
|
||||
) -> Iterator[tuple[_T | None, ...]]: ...
|
||||
@overload
|
||||
def windowed(
|
||||
seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
|
||||
) -> Iterator[tuple[_T | _U, ...]]: ...
|
||||
def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
||||
def substrings_indexes(
|
||||
seq: Sequence[_T], reverse: bool = ...
|
||||
) -> Iterator[tuple[Sequence[_T], int, int]]: ...
|
||||
|
||||
class bucket(Generic[_T, _U], Container[_U]):
|
||||
def __init__(
|
||||
self,
|
||||
iterable: Iterable[_T],
|
||||
key: Callable[[_T], _U],
|
||||
validator: Callable[[object], object] | None = ...,
|
||||
) -> None: ...
|
||||
def __contains__(self, value: object) -> bool: ...
|
||||
def __iter__(self) -> Iterator[_U]: ...
|
||||
def __getitem__(self, value: object) -> Iterator[_T]: ...
|
||||
|
||||
def spy(
|
||||
iterable: Iterable[_T], n: int = ...
|
||||
) -> tuple[list[_T], Iterator[_T]]: ...
|
||||
# Type stubs for the interleaving / splitting / padding helpers.
def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
def interleave_evenly(
    iterables: list[Iterable[_T]], lengths: list[int] | None = ...
) -> Iterator[_T]: ...
def collapse(
    iterable: Iterable[Any],
    base_type: type | None = ...,
    levels: int | None = ...,
) -> Iterator[Any]: ...
@overload
def side_effect(
    func: Callable[[_T], object],
    iterable: Iterable[_T],
    chunk_size: None = ...,
    before: Callable[[], object] | None = ...,
    after: Callable[[], object] | None = ...,
) -> Iterator[_T]: ...
@overload
def side_effect(
    func: Callable[[list[_T]], object],
    iterable: Iterable[_T],
    chunk_size: int,
    before: Callable[[], object] | None = ...,
    after: Callable[[], object] | None = ...,
) -> Iterator[_T]: ...
def sliced(
    seq: _SupportsSlicing[_T], n: int, strict: bool = ...
) -> Iterator[_T]: ...
def split_at(
    iterable: Iterable[_T],
    pred: Callable[[_T], object],
    maxsplit: int = ...,
    keep_separator: bool = ...,
) -> Iterator[list[_T]]: ...
def split_before(
    iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
) -> Iterator[list[_T]]: ...
def split_after(
    iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
) -> Iterator[list[_T]]: ...
def split_when(
    iterable: Iterable[_T],
    pred: Callable[[_T, _T], object],
    maxsplit: int = ...,
) -> Iterator[list[_T]]: ...
def split_into(
    iterable: Iterable[_T], sizes: Iterable[int | None]
) -> Iterator[list[_T]]: ...
@overload
def padded(
    iterable: Iterable[_T],
    *,
    n: int | None = ...,
    next_multiple: bool = ...,
) -> Iterator[_T | None]: ...
@overload
def padded(
    iterable: Iterable[_T],
    fillvalue: _U,
    n: int | None = ...,
    next_multiple: bool = ...,
) -> Iterator[_T | _U]: ...
@overload
def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
@overload
def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ...
def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
@overload
def stagger(
    iterable: Iterable[_T],
    offsets: _SizedIterable[int] = ...,
    longest: bool = ...,
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def stagger(
    iterable: Iterable[_T],
    offsets: _SizedIterable[int] = ...,
    longest: bool = ...,
    fillvalue: _U = ...,
) -> Iterator[tuple[_T | _U, ...]]: ...
||||
# Type stubs for equal-length zipping, offset zipping, and groupby helpers.
class UnequalIterablesError(ValueError):
    def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...

@overload
def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
@overload
def zip_equal(
    __iter1: Iterable[_T1], __iter2: Iterable[_T2]
) -> Iterator[tuple[_T1, _T2]]: ...
@overload
def zip_equal(
    __iter1: Iterable[_T],
    __iter2: Iterable[_T],
    __iter3: Iterable[_T],
    *iterables: Iterable[_T],
) -> Iterator[tuple[_T, ...]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T1],
    *,
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: None = None,
) -> Iterator[tuple[_T1 | None]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T1],
    __iter2: Iterable[_T2],
    *,
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: None = None,
) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T],
    __iter2: Iterable[_T],
    __iter3: Iterable[_T],
    *iterables: Iterable[_T],
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: None = None,
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T1],
    *,
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: _U,
) -> Iterator[tuple[_T1 | _U]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T1],
    __iter2: Iterable[_T2],
    *,
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: _U,
) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
@overload
def zip_offset(
    __iter1: Iterable[_T],
    __iter2: Iterable[_T],
    __iter3: Iterable[_T],
    *iterables: Iterable[_T],
    offsets: _SizedIterable[int],
    longest: bool = ...,
    fillvalue: _U,
) -> Iterator[tuple[_T | _U, ...]]: ...
def sort_together(
    iterables: Iterable[Iterable[_T]],
    key_list: Iterable[int] = ...,
    key: Callable[..., Any] | None = ...,
    reverse: bool = ...,
) -> list[tuple[_T, ...]]: ...
def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
def always_iterable(
    obj: object,
    base_type: type | tuple[type | tuple[Any, ...], ...] | None = ...,
) -> Iterator[Any]: ...
def adjacent(
    predicate: Callable[[_T], bool],
    iterable: Iterable[_T],
    distance: int = ...,
) -> Iterator[tuple[bool, _T]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: None = None,
    valuefunc: None = None,
    reducefunc: None = None,
) -> Iterator[tuple[_T, Iterator[_T]]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: None,
    reducefunc: None,
) -> Iterator[tuple[_U, Iterator[_T]]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: None,
    valuefunc: Callable[[_T], _V],
    reducefunc: None,
) -> Iterable[tuple[_T, Iterable[_V]]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: Callable[[_T], _V],
    reducefunc: None,
) -> Iterable[tuple[_U, Iterator[_V]]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: None,
    valuefunc: None,
    reducefunc: Callable[[Iterator[_T]], _W],
) -> Iterable[tuple[_T, _W]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: None,
    reducefunc: Callable[[Iterator[_T]], _W],
) -> Iterable[tuple[_U, _W]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: None,
    valuefunc: Callable[[_T], _V],
    reducefunc: Callable[[Iterable[_V]], _W],
) -> Iterable[tuple[_T, _W]]: ...
@overload
def groupby_transform(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: Callable[[_T], _V],
    reducefunc: Callable[[Iterable[_V]], _W],
) -> Iterable[tuple[_U, _W]]: ...
|
||||
|
||||
# Stub for the lazy numeric range; generic over element and step types.
class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
    @overload
    def __init__(self, __stop: _T) -> None: ...
    @overload
    def __init__(self, __start: _T, __stop: _T) -> None: ...
    @overload
    def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
    def __bool__(self) -> bool: ...
    def __contains__(self, elem: object) -> bool: ...
    def __eq__(self, other: object) -> bool: ...
    @overload
    def __getitem__(self, key: int) -> _T: ...
    @overload
    def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
    def __hash__(self) -> int: ...
    def __iter__(self) -> Iterator[_T]: ...
    def __len__(self) -> int: ...
    def __reduce__(
        self,
    ) -> tuple[Type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ...
    def __repr__(self) -> str: ...
    def __reversed__(self) -> Iterator[_T]: ...
    def count(self, value: _T) -> int: ...
    def index(self, value: _T) -> int: ...  # type: ignore
|
||||
|
||||
# Stubs for counting, stripping, and difference helpers.
def count_cycle(
    iterable: Iterable[_T], n: int | None = ...
) -> Iterable[tuple[int, _T]]: ...
def mark_ends(
    iterable: Iterable[_T],
) -> Iterable[tuple[bool, bool, _T]]: ...
def locate(
    iterable: Iterable[object],
    pred: Callable[..., Any] = ...,
    window_size: int | None = ...,
) -> Iterator[int]: ...
def lstrip(
    iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...
def rstrip(
    iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...
def strip(
    iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...

class islice_extended(Generic[_T], Iterator[_T]):
    def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ...
    def __iter__(self) -> islice_extended[_T]: ...
    def __next__(self) -> _T: ...
    def __getitem__(self, index: slice) -> islice_extended[_T]: ...

def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
def consecutive_groups(
    iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
) -> Iterator[Iterator[_T]]: ...
@overload
def difference(
    iterable: Iterable[_T],
    func: Callable[[_T, _T], _U] = ...,
    *,
    initial: None = ...,
) -> Iterator[_T | _U]: ...
@overload
def difference(
    iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
) -> Iterator[_U]: ...
|
||||
|
||||
# Stubs for the view / seekable-iterator / run-length helper classes.
class SequenceView(Generic[_T], Sequence[_T]):
    def __init__(self, target: Sequence[_T]) -> None: ...
    @overload
    def __getitem__(self, index: int) -> _T: ...
    @overload
    def __getitem__(self, index: slice) -> Sequence[_T]: ...
    def __len__(self) -> int: ...

class seekable(Generic[_T], Iterator[_T]):
    def __init__(
        self, iterable: Iterable[_T], maxlen: int | None = ...
    ) -> None: ...
    def __iter__(self) -> seekable[_T]: ...
    def __next__(self) -> _T: ...
    def __bool__(self) -> bool: ...
    @overload
    def peek(self) -> _T: ...
    @overload
    def peek(self, default: _U) -> _T | _U: ...
    def elements(self) -> SequenceView[_T]: ...
    def seek(self, index: int) -> None: ...

class run_length:
    @staticmethod
    def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ...
    @staticmethod
    def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ...
|
||||
|
||||
# Stubs for predicate-counting, map/reduce grouping, and replacement helpers.
def exactly_n(
    iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
) -> bool: ...
def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ...
def make_decorator(
    wrapping_func: Callable[..., _U], result_index: int = ...
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
@overload
def map_reduce(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: None = ...,
    reducefunc: None = ...,
) -> dict[_U, list[_T]]: ...
@overload
def map_reduce(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: Callable[[_T], _V],
    reducefunc: None = ...,
) -> dict[_U, list[_V]]: ...
@overload
def map_reduce(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: None = ...,
    reducefunc: Callable[[list[_T]], _W] = ...,
) -> dict[_U, _W]: ...
@overload
def map_reduce(
    iterable: Iterable[_T],
    keyfunc: Callable[[_T], _U],
    valuefunc: Callable[[_T], _V],
    reducefunc: Callable[[list[_V]], _W],
) -> dict[_U, _W]: ...
def rlocate(
    iterable: Iterable[_T],
    pred: Callable[..., object] = ...,
    window_size: int | None = ...,
) -> Iterator[int]: ...
def replace(
    iterable: Iterable[_T],
    pred: Callable[..., object],
    substitutes: Iterable[_U],
    count: int | None = ...,
    window_size: int = ...,
) -> Iterator[_T | _U]: ...
def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
def set_partitions(
    iterable: Iterable[_T], k: int | None = ...
) -> Iterator[list[list[_T]]]: ...

class time_limited(Generic[_T], Iterator[_T]):
    def __init__(
        self, limit_seconds: float, iterable: Iterable[_T]
    ) -> None: ...
    def __iter__(self) -> islice_extended[_T]: ...
    def __next__(self) -> _T: ...
|
||||
|
||||
# Stubs for single-item extraction, chunking, filtered mapping, and the
# threaded callback iterator.
@overload
def only(
    iterable: Iterable[_T], *, too_long: _Raisable | None = ...
) -> _T | None: ...
@overload
def only(
    iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ...
) -> _T | _U: ...
def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
def distinct_combinations(
    iterable: Iterable[_T], r: int
) -> Iterator[tuple[_T, ...]]: ...
def filter_except(
    validator: Callable[[Any], object],
    iterable: Iterable[_T],
    *exceptions: Type[BaseException],
) -> Iterator[_T]: ...
def map_except(
    function: Callable[[Any], _U],
    iterable: Iterable[_T],
    *exceptions: Type[BaseException],
) -> Iterator[_U]: ...
def map_if(
    iterable: Iterable[Any],
    pred: Callable[[Any], bool],
    func: Callable[[Any], Any],
    func_else: Callable[[Any], Any] | None = ...,
) -> Iterator[Any]: ...
def sample(
    iterable: Iterable[_T],
    k: int,
    weights: Iterable[float] | None = ...,
) -> list[_T]: ...
def is_sorted(
    iterable: Iterable[_T],
    key: Callable[[_T], _U] | None = ...,
    reverse: bool = False,
    strict: bool = False,
) -> bool: ...

class AbortThread(BaseException):
    pass

class callback_iter(Generic[_T], Iterator[_T]):
    def __init__(
        self,
        func: Callable[..., Any],
        callback_kwd: str = ...,
        wait_seconds: float = ...,
    ) -> None: ...
    def __enter__(self) -> callback_iter[_T]: ...
    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None: ...
    def __iter__(self) -> callback_iter[_T]: ...
    def __next__(self) -> _T: ...
    def _reader(self) -> Iterator[_T]: ...
    @property
    def done(self) -> bool: ...
    @property
    def result(self) -> Any: ...
|
||||
|
||||
# Stubs for combinatoric index helpers, de-duplication, and the
# less-than protocol used by minmax.
def windowed_complete(
    iterable: Iterable[_T], n: int
) -> Iterator[tuple[_T, ...]]: ...
def all_unique(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> bool: ...
def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
def nth_permutation(
    iterable: Iterable[_T], r: int, index: int
) -> tuple[_T, ...]: ...
def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ...
def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
def combination_index(
    element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
def permutation_index(
    element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...

class countable(Generic[_T], Iterator[_T]):
    def __init__(self, iterable: Iterable[_T]) -> None: ...
    def __iter__(self) -> countable[_T]: ...
    def __next__(self) -> _T: ...

def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
def zip_broadcast(
    *objects: _T | Iterable[_T],
    scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ...,
    strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
def unique_in_window(
    iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_everseen(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_justseen(
    iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...

class _SupportsLessThan(Protocol):
    def __lt__(self, __other: Any) -> bool: ...

_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
||||
|
||||
# Stubs for minmax and the remaining sequence utilities.
@overload
def minmax(
    iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
    iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
) -> tuple[_T, _T]: ...
@overload
def minmax(
    iterable_or_value: Iterable[_SupportsLessThanT],
    *,
    key: None = None,
    default: _U,
) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
    iterable_or_value: Iterable[_T],
    *,
    key: Callable[[_T], _SupportsLessThan],
    default: _U,
) -> _U | tuple[_T, _T]: ...
@overload
def minmax(
    iterable_or_value: _SupportsLessThanT,
    __other: _SupportsLessThanT,
    *others: _SupportsLessThanT,
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
    iterable_or_value: _T,
    __other: _T,
    *others: _T,
    key: Callable[[_T], _SupportsLessThan],
) -> tuple[_T, _T]: ...
def longest_common_prefix(
    iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[object]) -> bool: ...
def constrained_batches(
    iterable: Iterable[object],
    max_size: int,
    max_count: int | None = ...,
    get_len: Callable[[_T], object] = ...,
    strict: bool = ...,
) -> Iterator[tuple[_T]]: ...
def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
||||
0
lib/pkg_resources/_vendor/more_itertools/py.typed
Normal file
0
lib/pkg_resources/_vendor/more_itertools/py.typed
Normal file
930
lib/pkg_resources/_vendor/more_itertools/recipes.py
Normal file
930
lib/pkg_resources/_vendor/more_itertools/recipes.py
Normal file
@@ -0,0 +1,930 @@
|
||||
"""Imported from the recipes section of the itertools documentation.
|
||||
|
||||
All functions taken from the recipes section of the itertools library docs
|
||||
[1]_.
|
||||
Some backward-compatible usability improvements have been made.
|
||||
|
||||
.. [1] http://docs.python.org/library/itertools.html#recipes
|
||||
|
||||
"""
|
||||
import math
|
||||
import operator
|
||||
import warnings
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Sized
|
||||
from functools import reduce
|
||||
from itertools import (
|
||||
chain,
|
||||
combinations,
|
||||
compress,
|
||||
count,
|
||||
cycle,
|
||||
groupby,
|
||||
islice,
|
||||
product,
|
||||
repeat,
|
||||
starmap,
|
||||
tee,
|
||||
zip_longest,
|
||||
)
|
||||
from random import randrange, sample, choice
|
||||
from sys import hexversion
|
||||
|
||||
# Public API of this module; mirrors the itertools documentation recipes.
__all__ = [
    'all_equal', 'batched', 'before_and_after', 'consume', 'convolve',
    'dotproduct', 'first_true', 'factor', 'flatten', 'grouper',
    'iter_except', 'iter_index', 'matmul', 'ncycles', 'nth',
    'nth_combination', 'padnone', 'pad_none', 'pairwise', 'partition',
    'polynomial_from_roots', 'powerset', 'prepend', 'quantify',
    'random_combination_with_replacement', 'random_combination',
    'random_permutation', 'random_product', 'repeatfunc', 'roundrobin',
    'sieve', 'sliding_window', 'subslices', 'tabulate', 'tail', 'take',
    'transpose', 'triplewise', 'unique_everseen', 'unique_justseen',
]

# Unique sentinel object: identity comparison detects "no value supplied".
_marker = object()
|
||||
|
||||
|
||||
def take(n, iterable):
    """Return the first *n* items of *iterable* as a list.

    >>> take(3, range(10))
    [0, 1, 2]

    When *iterable* holds fewer than *n* items, everything it has is
    returned:

    >>> take(10, range(3))
    [0, 1, 2]

    """
    # islice stops at the shorter of n and the iterable's length.
    return [*islice(iterable, n)]
|
||||
|
||||
|
||||
def tabulate(function, start=0):
    """Yield ``function(start)``, ``function(start + 1)``, and so on,
    indefinitely.

    *function* should accept a single integer argument. *start* defaults
    to 0 and is incremented on each advance of the iterator.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    for i in count(start):
        yield function(i)
|
||||
|
||||
|
||||
def tail(n, iterable):
    """Yield the last *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # A sized container lets us jump straight to the final elements with
    # islice. Otherwise a bounded deque keeps just the last n items as the
    # input is consumed. Either call raises TypeError for non-iterables,
    # so no explicit Iterable check is needed.
    if isinstance(iterable, Sized):
        start = max(0, len(iterable) - n)
        yield from islice(iterable, start, None)
    else:
        yield from deque(iterable, maxlen=n)
|
||||
|
||||
|
||||
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps, or exhaust it when *n* is ``None``.

    Efficiently discards values without building a result. With no second
    argument the whole iterator is drained.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is simply exhausted.

    """
    # Both branches push the iteration loop into C.
    if n is not None:
        # Advance to the empty slice that starts at position n.
        next(islice(iterator, n, n), None)
    else:
        # A zero-length deque swallows the entire iterator.
        deque(iterator, maxlen=0)
|
||||
|
||||
|
||||
def nth(iterable, n, default=None):
    """Return the *n*-th item of *iterable*, or *default* if it is too short.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    remainder = islice(iterable, n, None)
    return next(remainder, default)
|
||||
|
||||
|
||||
def all_equal(iterable):
    """Return ``True`` when every element of *iterable* equals the others.

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    An empty iterable is vacuously all-equal.
    """
    # groupby collapses equal runs; more than one group means a mismatch.
    groups = groupby(iterable)
    next(groups, None)
    return next(groups, None) is None
|
||||
|
||||
|
||||
def quantify(iterable, pred=bool):
    """Return the number of items for which *pred* is true.

    >>> quantify([True, False, True])
    2

    """
    # Summing the predicate's results counts truthy outcomes (bool is int).
    return sum(pred(item) for item in iterable)
|
||||
|
||||
|
||||
def pad_none(iterable):
    """Yield the elements of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map`
    function. See also :func:`padded`.

    """
    return chain(iterable, repeat(None))


# Older spelling kept so existing callers continue to work.
padnone = pad_none
|
||||
|
||||
|
||||
def ncycles(iterable, n):
    """Yield the elements of *iterable* repeated *n* times.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so the input may be a one-shot iterator.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
|
||||
|
||||
|
||||
def dotproduct(vec1, vec2):
    """Return the dot product of the two iterables.

    >>> dotproduct([10, 10], [20, 20])
    400

    Iteration stops at the shorter of the two inputs.
    """
    return sum(x * y for x, y in zip(vec1, vec2))
|
||||
|
||||
|
||||
def flatten(listOfLists):
    """Flatten one level of nesting from an iterable of iterables.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of
    nesting.

    """
    for sublist in listOfLists:
        yield from sublist
|
||||
|
||||
|
||||
def repeatfunc(func, times=None, *args):
    """Call ``func(*args)`` repeatedly, yielding each result.

    With *times* given, the iterable ends after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    With ``times=None`` the iterable never terminates:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    # repeat() supplies the same argument tuple for each starmap call.
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
|
||||
|
||||
|
||||
def _pairwise(iterable):
    """Yield overlapping pairs of consecutive items from *iterable*.

    >>> take(4, pairwise(count()))
    [(0, 1), (1, 2), (2, 3), (3, 4)]

    On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.

    """
    first, second = tee(iterable)
    next(second, None)
    yield from zip(first, second)


# Prefer the C implementation when the running Python provides it
# (itertools.pairwise exists on 3.10+).
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:
    pairwise = _pairwise
else:

    def pairwise(iterable):
        yield from itertools_pairwise(iterable)

    pairwise.__doc__ = _pairwise.__doc__
|
||||
|
||||
|
||||
class UnequalIterablesError(ValueError):
    """Raised when iterables that must match in length do not.

    *details*, when given, is a ``(first_length, index, other_length)``
    tuple identifying the mismatching iterable.
    """

    def __init__(self, details=None):
        message = 'Iterables have different lengths'
        if details is not None:
            first_len, index, other_len = details
            message += (
                ': index 0 has length {}; index {} has length {}'
            ).format(first_len, index, other_len)
        super().__init__(message)
|
||||
|
||||
|
||||
def _zip_equal_generator(iterables):
    """Zip *iterables*, raising UnequalIterablesError on a length mismatch.

    Fallback for :func:`_zip_equal` when an input has no ``len()``:
    zip with the module-level ``_marker`` sentinel as fillvalue, and
    treat any appearance of the sentinel as a mismatch.
    """
    for combo in zip_longest(*iterables, fillvalue=_marker):
        if any(value is _marker for value in combo):
            raise UnequalIterablesError()
        yield combo
|
||||
|
||||
|
||||
def _zip_equal(*iterables):
    """Zip *iterables*, raising UnequalIterablesError if lengths differ.

    Fast path: when every input supports ``len()``, compare lengths up
    front and delegate to the built-in ``zip``. Otherwise fall back to
    the sentinel-based generator, which detects mismatches lazily.
    """
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                # ValueError subclass — not swallowed by the TypeError
                # handler below.
                raise UnequalIterablesError(details=(first_size, i, size))
        return zip(*iterables)
    except TypeError:
        # At least one input has no len(); read them in lockstep instead.
        return _zip_equal_generator(iterables)
|
||||
|
||||
|
||||
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Group elements from *iterable* into fixed-length groups of length *n*.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    The keyword arguments *incomplete* and *fillvalue* control what happens
    when the length is not a multiple of *n*.

    ``incomplete='fill'`` pads the final group with *fillvalue*:

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    ``incomplete='ignore'`` drops the final short group:

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    ``incomplete='strict'`` raises a subclass of `ValueError`:

    >>> it = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    """
    # n references to ONE iterator: each zip step pulls n successive items.
    iterators = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*iterators, fillvalue=fillvalue)
    elif incomplete == 'strict':
        return _zip_equal(*iterators)
    elif incomplete == 'ignore':
        return zip(*iterators)
    raise ValueError('Expected fill, strict, or ignore')
|
||||
|
||||
|
||||
def roundrobin(*iterables):
    """Yield an item from each iterable in turn, alternating between them.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`,
    but may perform better for some inputs (in particular when the number
    of iterables is small).

    """
    # Recipe credited to George Sakkis.
    # Fix: the loop variable used to be named `next`, shadowing the
    # builtin of the same name inside this function.
    pending = len(iterables)
    nexts = cycle(iter(it).__next__ for it in iterables)
    while pending:
        try:
            for advance in nexts:
                yield advance()
        except StopIteration:
            # One input ran dry; rebuild the cycle without it.
            pending -= 1
            nexts = cycle(islice(nexts, pending))
|
||||
|
||||
|
||||
def partition(pred, iterable):
    """Split *iterable* into two iterables based on *pred*.

    Returns a 2-tuple: the first iterable yields items where
    ``pred(item)`` is false, the second where it is true.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool

    # Evaluate pred once per element, then tee so each output stream can
    # filter on the recorded result.
    flagged = ((pred(item), item) for item in iterable)
    falsy_stream, truthy_stream = tee(flagged)
    falsy = (item for (flag, item) in falsy_stream if not flag)
    truthy = (item for (flag, item) in truthy_stream if flag)
    return (falsy, truthy)
|
||||
|
||||
|
||||
def powerset(iterable):
    """Yield all possible subsets of *iterable*, smallest first.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` operates on arbitrary iterables, not just sets, so
    repeated input elements produce repeated output subsets. Pipe the
    input through :func:`unique_everseen` to avoid that:

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
    >>> from more_itertools import unique_everseen
    >>> list(powerset(unique_everseen(seq)))
    [(), (1,), (0,), (1, 0)]

    """
    pool = list(iterable)
    subset_sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, r) for r in subset_sizes)
|
||||
|
||||
|
||||
def unique_everseen(iterable, key=None):
    """Yield unique elements of *iterable*, preserving first-seen order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Sequences mixing hashable and unhashable items are supported, but the
    function degrades to `O(n^2)` for unhashable items.

    Because ``list`` objects are unhashable, consider supplying *key* to
    map them to a hashable type (e.g. ``tuple``) and avoid the slowdown:

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Similarly, ``key=frozenset`` works for ``set`` objects and
    ``key=lambda x: frozenset(x.items())`` for ``dict`` objects.

    """
    hashable_seen = set()
    unhashable_seen = []
    use_key = key is not None

    for element in iterable:
        marker = key(element) if use_key else element
        try:
            # Fast path: O(1) set membership for hashable markers.
            if marker not in hashable_seen:
                hashable_seen.add(marker)
                yield element
        except TypeError:
            # Unhashable marker: fall back to a linear scan.
            if marker not in unhashable_seen:
                unhashable_seen.append(marker)
                yield element
|
||||
|
||||
|
||||
def unique_justseen(iterable, key=None):
    """Yield elements in order, skipping consecutive duplicates.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    # groupby() collapses each run of elements with equal keys; yielding the
    # first element of every run drops the serial duplicates.
    for _, run in groupby(iterable, key):
        yield next(run)


def iter_except(func, exception, first=None):
    """Yield results from calling *func* repeatedly until *exception* is
    raised.

    Converts a call-until-exception interface to an iterator interface.
    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
    to end the loop.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    # The optional *first* callable runs exactly once before the main loop;
    # it is governed by the same stopping exception as *func*.
    try:
        if first is not None:
            yield first()
    except exception:
        return
    try:
        while True:
            yield func()
    except exception:
        return


def first_true(iterable, default=None, pred=None):
    """Return the first true value in the iterable.

    If no true value is found, returns *default*.

    If *pred* is not None, returns the first item for which
    ``pred(item) == True``.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    # No predicate means plain truthiness, exactly like filter(None, ...).
    if pred is None:
        pred = bool
    for item in iterable:
        if pred(item):
            return item
    return default


def random_product(*args, repeat=1):
    """Draw an item at random from each of the input iterables.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    If *repeat* is provided as a keyword argument, that many items will be
    drawn from each iterable.

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, **kwarg)``.

    """
    # Materialize each iterable once, then pick one element per pool.
    pools = list(map(tuple, args)) * repeat
    return tuple(map(choice, pools))


def random_permutation(iterable, r=None):
    """Return a random *r* length permutation of the elements in *iterable*.

    If *r* is not specified or is ``None``, then *r* defaults to the length of
    *iterable*.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    # random.sample draws r distinct elements in random order.
    return tuple(sample(pool, r))


def random_combination(iterable, r):
    """Return a random *r* length subsequence of the elements in *iterable*.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Sample r distinct positions, then sort so the result preserves the
    # original ordering of the pool (as combinations() would).
    positions = sorted(sample(range(len(pool)), r))
    return tuple(pool[idx] for idx in positions)


def random_combination_with_replacement(iterable, r):
    """Return a random *r* length subsequence of elements in *iterable*,
    allowing individual elements to be repeated.

    >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    # Draw r positions independently (with replacement) and sort them so
    # the result follows the pool's ordering.
    picks = sorted(randrange(size) for _ in range(r))
    return tuple(pool[idx] for idx in picks)


def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError

    # Compute c = C(n, r), the total number of combinations, using the
    # symmetric k = min(r, n - r) to keep the loop short.
    c = 1
    k = min(r, n - r)
    for i in range(1, k + 1):
        c = c * (n - k + i) // i

    # Negative indexes count from the end, like sequence indexing.
    if index < 0:
        index += c

    if (index < 0) or (index >= c):
        raise IndexError

    # Walk the pool from the front: at each step decide which element comes
    # next by skipping whole blocks of C(remaining, r) combinations until
    # *index* falls inside the current block.
    result = []
    while r:
        c, n, r = c * r // n, n - 1, r - 1
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)


def prepend(value, iterator):
    """Yield *value*, followed by the elements in *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    return chain((value,), iterator)


def convolve(signal, kernel):
    """Convolve the iterable *signal* with the iterable *kernel*.

    >>> signal = (1, 2, 3, 4, 5)
    >>> kernel = [3, 2, 1]
    >>> list(convolve(signal, kernel))
    [3, 8, 14, 20, 26, 14, 5]

    Note: the input arguments are not interchangeable, as the *kernel*
    is immediately consumed and stored.

    """
    # Reverse the kernel once so each output is a plain elementwise product
    # with the sliding window.
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    # A zero-filled window of the kernel's length; old values fall off as
    # new signal values are appended.
    window = deque([0] * n, maxlen=n)
    # Pad the tail with n - 1 zeros so the kernel slides fully past the end.
    for value in chain(signal, repeat(0, n - 1)):
        window.append(value)
        yield sum(coeff * item for coeff, item in zip(kernel, window))


def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder) # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    it = iter(it)
    # Holds the single element that failed the predicate, so the remainder
    # iterator can replay it instead of dropping it (unlike takewhile).
    crossover = []

    def head():
        for item in it:
            if not predicate(item):
                crossover.append(item)
                return
            yield item

    # Note: this is different from itertools recipes to allow nesting
    # before_and_after remainders into before_and_after again. See tests
    # for an example.
    return head(), chain(crossover, it)


def triplewise(iterable):
|
||||
"""Return overlapping triplets from *iterable*.
|
||||
|
||||
>>> list(triplewise('ABCDE'))
|
||||
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
|
||||
|
||||
"""
|
||||
for (a, _), (b, c) in pairwise(pairwise(iterable)):
|
||||
yield a, b, c
|
||||
|
||||
|
||||
def sliding_window(iterable, n):
    """Return a sliding window of width *n* over *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    If *iterable* has fewer than *n* items, then nothing is yielded:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    iterator = iter(iterable)
    # Prefill the window with the first n items; yield it only if the
    # iterable actually contained that many.
    window = deque(islice(iterator, n), maxlen=n)
    if len(window) == n:
        yield tuple(window)
    # Each further item pushes the oldest one out of the fixed-size deque.
    for item in iterator:
        window.append(item)
        yield tuple(window)


def subslices(iterable):
    """Return all contiguous non-empty subslices of *iterable*.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every pair of boundaries 0 <= start < stop <= len(seq) names one
    # non-empty slice; combinations() enumerates them lexicographically.
    boundaries = combinations(range(len(seq) + 1), 2)
    return (seq[start:stop] for start, stop in boundaries)


def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

    >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
    [1, -4, -17, 60]
    """
    # math.prod only exists on Python 3.8+; fall back to a reduce().
    prod = getattr(math, 'prod', None)
    if prod is None:
        def prod(values):
            return reduce(operator.mul, values, 1)
    negated = [-root for root in roots]
    # Vieta's formulas: the coefficient of x^(n-k) is the sum of all k-fold
    # products of the negated roots (k = 0 gives the leading 1).
    return [
        sum(map(prod, combinations(negated, k)))
        for k in range(len(negated) + 1)
    ]


def iter_index(iterable, value, start=0):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start*.

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    """
    try:
        seq_index = iterable.index
    except AttributeError:
        # Slow path for general iterables: scan element by element,
        # matching by identity first (handles NaN-like values) then equality.
        for pos, item in enumerate(islice(iterable, start, None), start):
            if item is value or item == value:
                yield pos
    else:
        # Fast path for sequences: let the C-level index() do the searching.
        pos = start - 1
        try:
            while True:
                pos = seq_index(value, pos + 1)
                yield pos
        except ValueError:
            # index() raises ValueError once no further occurrence exists.
            pass


def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    """
    # math.isqrt only exists on Python 3.8+; fall back to math.sqrt.
    isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
    # Flag array over range(n) with odd indices preset to 1 (candidate prime)
    # and even indices to 0; only odd numbers can be prime (2 is special-cased
    # below).
    data = bytearray((0, 1)) * (n // 2)
    # Clear 0, 1, and 2 (2 is restored after striking composites).
    data[:3] = 0, 0, 0
    # Only strike multiples of primes up to sqrt(n).
    limit = isqrt(n) + 1
    for p in compress(range(limit), data):
        # Zero out p*p, p*p + 2p, ... — stepping by 2p skips even multiples.
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
    # Re-mark 2, the only even prime.
    data[2] = 1
    # Indices still flagged 1 are the primes; n <= 2 has none.
    return iter_index(data, 1) if n > 2 else iter([])


def batched(iterable, n):
    """Batch data into lists of length *n*. The last batch may be shorter.

    >>> list(batched('ABCDEFG', 3))
    [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]

    This recipe is from the ``itertools`` docs. This library also provides
    :func:`chunked`, which has a different implementation.
    """
    # itertools.batched ships with Python 3.12; steer users toward it there.
    if hexversion >= 0x30C00A0:  # Python 3.12.0a0
        warnings.warn(
            (
                'batched will be removed in a future version of '
                'more-itertools. Use the standard library '
                'itertools.batched function instead'
            ),
            DeprecationWarning,
        )

    iterator = iter(iterable)
    chunk = list(islice(iterator, n))
    while chunk:
        yield chunk
        chunk = list(islice(iterator, n))


def transpose(it):
    """Swap the rows and columns of the input.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible.
    """
    # TODO: when 3.9 goes end-of-life, add strict=True to this.
    return zip(*it)


def matmul(m1, m2):
    """Multiply two matrices.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [[49, 80], [41, 60]]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.
    """
    # Each output entry is the dot product of a row of m1 with a column of
    # m2 (the rows of transpose(m2)); grouping the flat stream of entries
    # into rows of len(m2[0]) rebuilds the result matrix.
    width = len(m2[0])
    entries = starmap(dotproduct, product(m1, transpose(m2)))
    return batched(entries, width)


def factor(n):
    """Yield the prime factors of n.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]
    """
    # math.isqrt only exists on Python 3.8+; fall back to math.sqrt.
    isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
    # Trial division: any composite n has a prime factor <= sqrt(n).
    for prime in sieve(isqrt(n) + 1):
        while n % prime == 0:
            yield prime
            n //= prime
        if n == 1:
            return
    # Whatever survives division by all small primes is itself prime.
    if n >= 2:
        yield n
# (web-scrape residue removed)