Merge branch 'develop'

This commit is contained in:
rembo10
2024-05-26 09:52:28 +05:30
164 changed files with 19833 additions and 10271 deletions

View File

@@ -1,7 +1,17 @@
# Changelog
## v0.6.2
Released 26 May 2024
Highlights:
* Added soulseek support
* Added bandcamp support
* Changes and dependency updates to work with Python >= 3.12
The full list of commits can be found [here](https://github.com/rembo10/headphones/compare/v0.6.1...v0.6.2).
## v0.6.1
Released 26 November 2023
Highlights:
* Dependency updates to work with Python >= 3.11

View File

@@ -310,6 +310,16 @@
<input type="text" name="usenet_retention" value="${config['usenet_retention']}" size="5">
</div>
</fieldset>
<fieldset title="Method for downloading Bandcamp.com files.">
<legend>Bandcamp</legend>
<div class="row">
<label title="Path to folder where Headphones can store raw downloads from Bandcamp.com.">
Bandcamp Directory
</label>
<input type="text" name="bandcamp_dir" value="${config['bandcamp_dir']}" size="50">
<small>Full path where raw MP3s will be stored, e.g. /Users/name/Downloads/bandcamp</small>
</div>
</fieldset>
</td>
<td>
<fieldset title="Method for downloading torrent files.">
@@ -317,7 +327,7 @@
<input type="radio" name="torrent_downloader" id="torrent_downloader_blackhole" value="0" ${config['torrent_downloader_blackhole']}> Black Hole
<input type="radio" name="torrent_downloader" id="torrent_downloader_transmission" value="1" ${config['torrent_downloader_transmission']}> Transmission
<input type="radio" name="torrent_downloader" id="torrent_downloader_utorrent" value="2" ${config['torrent_downloader_utorrent']}> uTorrent (Beta)
<input type="radio" name="torrent_downloader" id="torrent_downloader_deluge" value="3" ${config['torrent_downloader_deluge']}> Deluge (Beta)
<input type="radio" name="torrent_downloader" id="torrent_downloader_deluge" value="3" ${config['torrent_downloader_deluge']}> Deluge
<input type="radio" name="torrent_downloader" id="torrent_downloader_qbittorrent" value="4" ${config['torrent_downloader_qbittorrent']}> QBitTorrent
</fieldset>
<fieldset id="torrent_blackhole_options">
@@ -438,6 +448,11 @@
<input type="text" name="deluge_label" value="${config['deluge_label']}" size="30">
<small>Labels shouldn't contain spaces (requires Label plugin)</small>
</div>
<div class="row">
<label>Download Directory</label>
<input type="text" name="deluge_download_directory" value="${config['deluge_download_directory']}" size="30">
<small>Directory where Deluge should download to</small>
</div>
<div class="row">
<label>Move When Completed</label>
<input type="text" name="deluge_done_directory" value="${config['deluge_done_directory']}" size="30">
@@ -467,7 +482,33 @@
<label>Prefer</label>
<input type="radio" name="prefer_torrents" id="prefer_torrents_0" value="0" ${config['prefer_torrents_0']}>NZBs
<input type="radio" name="prefer_torrents" id="prefer_torrents_1" value="1" ${config['prefer_torrents_1']}>Torrents
<input type="radio" name="prefer_torrents" id="prefer_torrents_2" value="2" ${config['prefer_torrents_2']}>No Preference
<input type="radio" name="prefer_torrents" id="prefer_torrents_2" value="2" ${config['prefer_torrents_2']}>Soulseek
<input type="radio" name="prefer_torrents" id="prefer_torrents_3" value="3" ${config['prefer_torrents_3']}>No Preference
</div>
</fieldset>
</td>
<td>
<fieldset>
<legend>Soulseek</legend>
<div class="row">
<label>Soulseek API URL</label>
<input type="text" name="soulseek_api_url" value="${config['soulseek_api_url']}" size="50">
</div>
<div class="row">
<label>Soulseek API Key</label>
<input type="text" name="soulseek_api_key" value="${config['soulseek_api_key']}" size="20">
</div>
<div class="row">
<label title="Path to folder where Headphones can find the downloads.">
Soulseek Download Dir
</label>
<input type="text" name="soulseek_download_dir" value="${config['soulseek_download_dir']}" size="50">
</div>
<div class="row">
<label title="Path to folder where Headphones can find the downloads.">
Soulseek Incomplete Download Dir:
</label>
<input type="text" name="soulseek_incomplete_download_dir" value="${config['soulseek_incomplete_download_dir']}" size="50">
</div>
</fieldset>
</td>
@@ -579,6 +620,19 @@
</div>
</div>
</fieldset>
<fieldset>
<legend>Other</legend>
<fieldset>
<div class="row checkbox left">
<input id="use_bandcamp" type="checkbox" class="bigcheck" name="use_bandcamp" value="1" ${config['use_bandcamp']} /><label for="use_bandcamp"><span class="option">Bandcamp</span></label>
</div>
</fieldset>
<fieldset>
<div class="row checkbox left">
<input id="use_soulseek" type="checkbox" class="bigcheck" name="use_soulseek" value="1" ${config['use_soulseek']} /><label for="use_soulseek"><span class="option">Soulseek</span></label>
</div>
</fieldset>
</fieldset>
</td>
<td>
<fieldset>

View File

@@ -56,6 +56,8 @@
fileid = 'torrent'
if item['URL'].find('codeshy') != -1:
fileid = 'nzb'
if item['URL'].find('bandcamp') != -1:
fileid = 'bandcamp'
folder = 'Folder: ' + item['FolderName']

headphones/bandcamp.py Normal file
View File

@@ -0,0 +1,166 @@
# This file is part of Headphones.
#
# Headphones is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Headphones is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Headphones. If not, see <http://www.gnu.org/licenses/>
import headphones
import json
import os
import re
from headphones import logger, helpers, metadata, request
from headphones.common import USER_AGENT
from headphones.types import Result
from mediafile import MediaFile, UnreadableFileError
from bs4 import BeautifulSoup
from bs4 import FeatureNotFound
def search(album, albumlength=None, page=1, resultlist=None):
dic = {'...': '', ' & ': ' ', ' = ': ' ', '?': '', '$': 's', ' + ': ' ',
'"': '', ',': '', '*': '', '.': '', ':': ''}
if resultlist is None:
resultlist = []
cleanalbum = helpers.latinToAscii(
helpers.replace_all(album['AlbumTitle'], dic)
).strip()
cleanartist = helpers.latinToAscii(
helpers.replace_all(album['ArtistName'], dic)
).strip()
headers = {'User-Agent': USER_AGENT}
params = {
"page": page,
"q": cleanalbum,
}
logger.info("Looking up https://bandcamp.com/search with {}".format(
params))
content = request.request_content(
url='https://bandcamp.com/search',
params=params,
headers=headers
).decode('utf8')
try:
soup = BeautifulSoup(content, "html5lib")
except FeatureNotFound:
soup = BeautifulSoup(content, "html.parser")
for item in soup.find_all("li", class_="searchresult"):
item_type = item.find('div', class_='itemtype').text.strip().lower()
if item_type == "album":
data = parse_album(item)
cleanartist_found = helpers.latinToAscii(data['artist'])
cleanalbum_found = helpers.latinToAscii(data['album'])
logger.debug(u"{} - {}".format(data['album'], cleanalbum_found))
logger.debug("Comparing {} to {}".format(
cleanalbum, cleanalbum_found))
if (cleanartist.lower() == cleanartist_found.lower() and
cleanalbum.lower() == cleanalbum_found.lower()):
resultlist.append(Result(
data['title'], data['size'], data['url'],
'bandcamp', 'bandcamp', True))
else:
continue
if soup.find('a', class_='next'):
page += 1
logger.debug("Calling next page ({})".format(page))
search(album, albumlength=albumlength,
page=page, resultlist=resultlist)
return resultlist
def download(album, bestqual):
html = request.request_content(url=bestqual.url).decode('utf-8')
trackinfo = []
try:
trackinfo = json.loads(
re.search(r"trackinfo&quot;:(\[.*?\]),", html)
.group(1)
.replace('&quot;', '"'))
except ValueError as e:
logger.warn("Couldn't load json: {}".format(e))
directory = os.path.join(
headphones.CONFIG.BANDCAMP_DIR,
u'{} - {}'.format(
album['ArtistName'].replace('/', '_'),
album['AlbumTitle'].replace('/', '_')))
directory = helpers.latinToAscii(directory)
if not os.path.exists(directory):
try:
os.makedirs(directory)
except Exception as e:
logger.warn("Could not create directory ({})".format(e))
index = 1
for track in trackinfo:
filename = helpers.replace_illegal_chars(
u'{:02d} - {}.mp3'.format(index, track['title']))
fullname = os.path.join(directory.encode('utf-8'),
filename.encode('utf-8'))
logger.debug("Downloading to {}".format(fullname))
if 'file' in track and track['file'] is not None and 'mp3-128' in track['file']:
content = request.request_content(track['file']['mp3-128'])
open(fullname, 'wb').write(content)
try:
f = MediaFile(fullname)
date, year = metadata._date_year(album)
f.update({
'artist': album['ArtistName'].encode('utf-8'),
'album': album['AlbumTitle'].encode('utf-8'),
'title': track['title'].encode('utf-8'),
'track': track['track_num'],
'tracktotal': len(trackinfo),
'year': year,
})
f.save()
except UnreadableFileError as ex:
logger.warn("MediaFile couldn't parse: %s (%s)",
fullname,
str(ex))
index += 1
return directory
def parse_album(item):
album = item.find('div', class_='heading').text.strip()
artist = item.find('div', class_='subhead').text.strip().replace("by ", "")
released = item.find('div', class_='released').text.strip().replace(
"released ", "")
year = re.search(r"(\d{4})", released).group(1)
url = item.find('div', class_='heading').find('a')['href'].split("?")[0]
length = item.find('div', class_='length').text.strip()
tracks, minutes = length.split(",")
tracks = tracks.replace(" tracks", "").replace(" track", "").strip()
minutes = minutes.replace(" minutes", "").strip()
# bandcamp offers 128 kbps mp3, i.e. 16 KB/s, which is 960 KB per minute (983040 bytes)
size = int(minutes) * 983040
data = {"title": u'{} - {} [{}]'.format(artist, album, year),
"artist": artist, "album": album,
"url": url, "size": size}
return data
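A minimal usage sketch of the new module, assuming an album dict shaped like the rows the rest of Headphones passes around (the keys mirror what search() and download() read above; values are illustrative):

from headphones import bandcamp

album = {
    'ArtistName': 'Example Artist',
    'AlbumTitle': 'Example Album',
    'ReleaseDate': '2024-01-01',  # assumption: consumed via metadata._date_year()
}

results = bandcamp.search(album)  # list of headphones.types.Result
if results:
    # download() fetches each track's mp3-128 stream, tags it via MediaFile,
    # and returns the directory it wrote into.
    folder = bandcamp.download(album, results[0])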

View File

@@ -102,36 +102,6 @@ class Quality:
return (anyQualities, bestQualities)
@staticmethod
def nameQuality(name):
def checkName(list, func):
return func([re.search(x, name, re.I) for x in list])
name = os.path.basename(name)
# if we have our exact text then assume we put it there
for x in Quality.qualityStrings:
if x == Quality.UNKNOWN:
continue
regex = '\W' + Quality.qualityStrings[x].replace(' ', '\W') + '\W'
regex_match = re.search(regex, name, re.I)
if regex_match:
return x
# TODO: fix quality checking here
if checkName(["mp3", "192"], any) and not checkName(["flac"], all):
return Quality.B192
elif checkName(["mp3", "256"], any) and not checkName(["flac"], all):
return Quality.B256
elif checkName(["mp3", "vbr"], any) and not checkName(["flac"], all):
return Quality.VBR
elif checkName(["mp3", "320"], any) and not checkName(["flac"], all):
return Quality.B320
else:
return Quality.UNKNOWN
@staticmethod
def assumeQuality(name):
if name.lower().endswith(".mp3"):
@@ -158,13 +128,6 @@ class Quality:
return (Quality.NONE, status)
@staticmethod
def statusFromName(name, assume=True):
quality = Quality.nameQuality(name)
if assume and quality == Quality.UNKNOWN:
quality = Quality.assumeQuality(name)
return Quality.compositeStatus(DOWNLOADED, quality)
DOWNLOADED = None
SNATCHED = None
SNATCHED_PROPER = None

View File

@@ -80,6 +80,7 @@ _CONFIG_DEFINITIONS = {
'DELUGE_PASSWORD': (str, 'Deluge', ''),
'DELUGE_LABEL': (str, 'Deluge', ''),
'DELUGE_DONE_DIRECTORY': (str, 'Deluge', ''),
'DELUGE_DOWNLOAD_DIRECTORY': (str, 'Deluge', ''),
'DELUGE_PAUSED': (int, 'Deluge', 0),
'DESTINATION_DIR': (str, 'General', ''),
'DETECT_BITRATE': (int, 'General', 0),
@@ -269,6 +270,11 @@ _CONFIG_DEFINITIONS = {
'SONGKICK_ENABLED': (int, 'Songkick', 1),
'SONGKICK_FILTER_ENABLED': (int, 'Songkick', 0),
'SONGKICK_LOCATION': (str, 'Songkick', ''),
'SOULSEEK_API_URL': (str, 'Soulseek', ''),
'SOULSEEK_API_KEY': (str, 'Soulseek', ''),
'SOULSEEK_DOWNLOAD_DIR': (str, 'Soulseek', ''),
'SOULSEEK_INCOMPLETE_DOWNLOAD_DIR': (str, 'Soulseek', ''),
'SOULSEEK': (int, 'Soulseek', 0),
'SUBSONIC_ENABLED': (int, 'Subsonic', 0),
'SUBSONIC_HOST': (str, 'Subsonic', ''),
'SUBSONIC_PASSWORD': (str, 'Subsonic', ''),
@@ -317,7 +323,9 @@ _CONFIG_DEFINITIONS = {
'XBMC_PASSWORD': (str, 'XBMC', ''),
'XBMC_UPDATE': (int, 'XBMC', 0),
'XBMC_USERNAME': (str, 'XBMC', ''),
'XLDPROFILE': (str, 'General', '')
'XLDPROFILE': (str, 'General', ''),
'BANDCAMP': (int, 'General', 0),
'BANDCAMP_DIR': (path, 'General', '')
}
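Each entry above is a (type, section, default) tuple. A rough sketch of how one such definition might be resolved against an ini file (hypothetical helper for illustration, not headphones' actual loader):

from configparser import ConfigParser

def resolve(key, definition, parser):
    # definition is a (type, section, default) tuple, as in _CONFIG_DEFINITIONS
    typ, section, default = definition
    if parser.has_option(section, key.lower()):
        return typ(parser.get(section, key.lower()))
    return default

parser = ConfigParser()
parser.read('config.ini')
api_url = resolve('SOULSEEK_API_URL', (str, 'Soulseek', ''), parser)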

View File

@@ -58,12 +58,12 @@ def _scrubber(text):
if scrub_logs:
try:
# URL parameter values
text = re.sub('=[0-9a-zA-Z]*', '=REMOVED', text)
text = re.sub(r'=[0-9a-zA-Z]*', r'=REMOVED', text)
# Local host with port
# text = re.sub('\:\/\/.*\:', '://REMOVED:', text) # just host
text = re.sub('\:\/\/.*\:[0-9]*', '://REMOVED:', text)
text = re.sub(r'\:\/\/.*\:[0-9]*', r'://REMOVED:', text)
# Session cookie
text = re.sub("_session_id'\: '.*'", "_session_id': 'REMOVED'", text)
text = re.sub(r"_session_id'\: '.*'", r"_session_id': 'REMOVED'", text)
# Local Windows user path
if text.lower().startswith('c:\\users\\'):
k = text.split('\\')
@@ -128,9 +128,9 @@ def addTorrent(link, data=None, name=None):
# Extract torrent name from .torrent
try:
logger.debug('Deluge: Getting torrent name length')
name_length = int(re.findall('name([0-9]*)\:.*?\:', str(torrentfile))[0])
name_length = int(re.findall(r'name([0-9]*)\:.*?\:', str(torrentfile))[0])
logger.debug('Deluge: Getting torrent name')
name = re.findall('name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
name = re.findall(r'name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
except Exception as e:
logger.debug('Deluge: Could not get torrent name, getting file name')
# get last part of link/path (name only)
@@ -160,9 +160,9 @@ def addTorrent(link, data=None, name=None):
# Extract torrent name from .torrent
try:
logger.debug('Deluge: Getting torrent name length')
name_length = int(re.findall('name([0-9]*)\:.*?\:', str(torrentfile))[0])
name_length = int(re.findall(r'name([0-9]*)\:.*?\:', str(torrentfile))[0])
logger.debug('Deluge: Getting torrent name')
name = re.findall('name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
name = re.findall(r'name[0-9]*\:(.*?)\:', str(torrentfile))[0][:name_length]
except Exception as e:
logger.debug('Deluge: Could not get torrent name, getting file name')
# get last part of link/path (name only)
@@ -466,19 +466,56 @@ def _add_torrent_url(result):
def _add_torrent_file(result):
logger.debug('Deluge: Adding file')
options = {}
if headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY:
options['download_location'] = headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
options['move_completed'] = 1
if headphones.CONFIG.DELUGE_DONE_DIRECTORY:
options['move_completed_path'] = headphones.CONFIG.DELUGE_DONE_DIRECTORY
else:
options['move_completed_path'] = headphones.CONFIG.DOWNLOAD_TORRENT_DIR
if headphones.CONFIG.DELUGE_PAUSED:
options['add_paused'] = headphones.CONFIG.DELUGE_PAUSED
if not any(delugeweb_auth):
_get_auth()
try:
# content is torrent file contents that needs to be encoded to base64
post_data = json.dumps({"method": "core.add_torrent_file",
"params": [result['name'] + '.torrent',
b64encode(result['content']).decode(), {}],
b64encode(result['content'].encode('utf8')).decode(),
options],
"id": 2})
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
verify=deluge_verify_cert, headers=headers)
result['hash'] = json.loads(response.text)['result']
logger.debug('Deluge: Response was %s' % str(json.loads(response.text)))
return json.loads(response.text)['result']
except UnicodeDecodeError:
try:
# content is torrent file contents that needs to be encoded to base64
# this time let's try leaving the encoding as is
logger.debug('Deluge: There was a decoding issue, let\'s try again')
post_data = json.dumps({"method": "core.add_torrent_file",
"params": [result['name'].decode('utf8') + '.torrent',
b64encode(result['content']).decode(),
options],
"id": 22})
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
verify=deluge_verify_cert, headers=headers)
result['hash'] = json.loads(response.text)['result']
logger.debug('Deluge: Response was %s' % str(json.loads(response.text)))
return json.loads(response.text)['result']
except Exception as e:
logger.error('Deluge: Adding torrent file failed after decode: %s' % str(e))
formatted_lines = traceback.format_exc().splitlines()
logger.error('; '.join(formatted_lines))
return False
except Exception as e:
logger.error('Deluge: Adding torrent file failed: %s' % str(e))
formatted_lines = traceback.format_exc().splitlines()
@@ -566,61 +603,3 @@ def setSeedRatio(result):
return None
def setTorrentPath(result):
logger.debug('Deluge: Setting download path')
if not any(delugeweb_auth):
_get_auth()
try:
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
post_data = json.dumps({"method": "core.set_torrent_move_completed",
"params": [result['hash'], True],
"id": 7})
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
verify=deluge_verify_cert, headers=headers)
if headphones.CONFIG.DELUGE_DONE_DIRECTORY:
move_to = headphones.CONFIG.DELUGE_DONE_DIRECTORY
else:
move_to = headphones.CONFIG.DOWNLOAD_TORRENT_DIR
if not os.path.exists(move_to):
logger.debug('Deluge: %s directory doesn\'t exist, let\'s create it' % move_to)
os.makedirs(move_to)
post_data = json.dumps({"method": "core.set_torrent_move_completed_path",
"params": [result['hash'], move_to],
"id": 8})
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
verify=deluge_verify_cert, headers=headers)
return not json.loads(response.text)['error']
return True
except Exception as e:
logger.error('Deluge: Setting torrent move-to directory failed: %s' % str(e))
formatted_lines = traceback.format_exc().splitlines()
logger.error('; '.join(formatted_lines))
return None
def setTorrentPause(result):
logger.debug('Deluge: Pausing torrent')
if not any(delugeweb_auth):
_get_auth()
try:
if headphones.CONFIG.DELUGE_PAUSED:
post_data = json.dumps({"method": "core.pause_torrent",
"params": [[result['hash']]],
"id": 9})
response = requests.post(delugeweb_url, data=post_data.encode('utf-8'), cookies=delugeweb_auth,
verify=deluge_verify_cert, headers=headers)
return not json.loads(response.text)['error']
return True
except Exception as e:
logger.error('Deluge: Setting torrent paused failed: %s' % str(e))
formatted_lines = traceback.format_exc().splitlines()
logger.error('; '.join(formatted_lines))
return None

View File

@@ -184,7 +184,7 @@ def bytes_to_mb(bytes):
def mb_to_bytes(mb_str):
result = re.search('^(\d+(?:\.\d+)?)\s?(?:mb)?', mb_str, flags=re.I)
result = re.search(r"^(\d+(?:\.\d+)?)\s?(?:mb)?", mb_str, flags=re.I)
if result:
return int(float(result.group(1)) * 1048576)
@@ -253,9 +253,9 @@ def replace_all(text, dic):
def replace_illegal_chars(string, type="file"):
if type == "file":
string = re.sub('[\?"*:|<>/]', '_', string)
string = re.sub(r"[\?\"*:|<>/]", "_", string)
if type == "folder":
string = re.sub('[:\?<>"|*]', '_', string)
string = re.sub(r"[:\?<>\"|*]", "_", string)
return string
@@ -386,7 +386,7 @@ def clean_musicbrainz_name(s, return_as_string=True):
def cleanTitle(title):
title = re.sub('[\.\-\/\_]', ' ', title).lower()
title = re.sub(r"[\.\-\/\_]", " ", title).lower()
# Strip out extra whitespace
title = ' '.join(title.split())
@@ -1050,3 +1050,10 @@ def have_pct_have_total(db_artist):
have_pct = have_tracks / total_tracks if total_tracks else 0
return (have_pct, total_tracks)
def has_token(title, token):
return bool(
re.search(rf'(?:\W|^)+{token}(?:\W|$)+',
title,
re.IGNORECASE | re.UNICODE)
)

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from .unittestcompat import TestCase
from headphones.helpers import clean_name, is_valid_date, age
from headphones.helpers import clean_name, is_valid_date, age, has_token
class HelpersTest(TestCase):
@@ -56,3 +56,18 @@ class HelpersTest(TestCase):
]
for input, expected, desc in test_cases:
self.assertEqual(is_valid_date(input), expected, desc)
def test_has_token(self):
"""helpers: has_token()"""
self.assertEqual(
has_token("a cat ran", "cat"),
True,
"return True if token is in string"
)
self.assertEqual(
has_token("acatran", "cat"),
False,
"return False if token is part of another word"
)

View File

@@ -27,7 +27,7 @@ from beets import config as beetsconfig
from beets import logging as beetslogging
from mediafile import MediaFile, FileTypeError, UnreadableFileError
from beetsplug import lyrics as beetslyrics
from headphones import notifiers, utorrent, transmission, deluge, qbittorrent
from headphones import notifiers, utorrent, transmission, deluge, qbittorrent, soulseek
from headphones import db, albumart, librarysync
from headphones import logger, helpers, mb, music_encoder
from headphones import metadata
@@ -36,18 +36,44 @@ postprocessor_lock = threading.Lock()
def checkFolder():
logger.debug("Checking download folder for completed downloads (only snatched ones).")
logger.info("Checking download folder for completed downloads (only snatched ones).")
with postprocessor_lock:
myDB = db.DBConnection()
snatched = myDB.select('SELECT * from snatched WHERE Status="Snatched"')
# If soulseek is used, this part gets the download status from the soulseek api and collects completed and errored albums
completed_albums, errored_albums = set(), set()
if any(album['Kind'] == 'soulseek' for album in snatched):
completed_albums, errored_albums = soulseek.download_completed()
for album in snatched:
if album['FolderName']:
folder_name = album['FolderName']
single = False
if album['Kind'] == 'nzb':
download_dir = headphones.CONFIG.DOWNLOAD_DIR
if album['Kind'] == 'soulseek':
if folder_name in errored_albums:
# If the album had any tracks with errors in it, the whole download is considered faulty. Status will be reset to wanted.
logger.info(f"Album with folder '{folder_name}' had errors during download. Setting status to 'Wanted'.")
myDB.action('UPDATE albums SET Status="Wanted" WHERE AlbumID=? AND Status="Snatched"', (album['AlbumID'],))
# The folder will be removed from both the configured complete and incomplete directories
complete_path = os.path.join(headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR, folder_name)
incomplete_path = os.path.join(headphones.CONFIG.SOULSEEK_INCOMPLETE_DOWNLOAD_DIR, folder_name)
for path in [complete_path, incomplete_path]:
try:
shutil.rmtree(path)
except Exception:
pass  # the folder may not exist in both locations
continue
elif folder_name in completed_albums:
download_dir = headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR
else:
continue
elif album['Kind'] == 'nzb':
download_dir = headphones.CONFIG.DOWNLOAD_DIR
elif album['Kind'] == 'bandcamp':
download_dir = headphones.CONFIG.BANDCAMP_DIR
else:
if headphones.CONFIG.DELUGE_DONE_DIRECTORY and headphones.CONFIG.TORRENT_DOWNLOADER == 3:
download_dir = headphones.CONFIG.DELUGE_DONE_DIRECTORY
@@ -289,7 +315,7 @@ def verify(albumid, albumpath, Kind=None, forced=False, keep_original_folder=Fal
logger.debug('Metadata check failed. Verifying filenames...')
for downloaded_track in downloaded_track_list:
track_name = os.path.splitext(downloaded_track)[0]
split_track_name = re.sub('[\.\-\_]', ' ', track_name).lower()
split_track_name = re.sub(r'[\.\-\_]', r' ', track_name).lower()
for track in tracks:
if not track['TrackTitle']:
@@ -1170,8 +1196,14 @@ def forcePostProcess(dir=None, expand_subfolders=True, album_dir=None, keep_orig
download_dirs.append(dir)
if headphones.CONFIG.DOWNLOAD_DIR and not dir:
download_dirs.append(headphones.CONFIG.DOWNLOAD_DIR)
if headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR and not dir:
download_dirs.append(headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR)
if headphones.CONFIG.DOWNLOAD_TORRENT_DIR and not dir:
download_dirs.append(headphones.CONFIG.DOWNLOAD_TORRENT_DIR)
download_dirs.append(
headphones.CONFIG.DOWNLOAD_TORRENT_DIR.encode(headphones.SYS_ENCODING, 'replace'))
if headphones.CONFIG.BANDCAMP and not dir:
download_dirs.append(
headphones.CONFIG.BANDCAMP_DIR.encode(headphones.SYS_ENCODING, 'replace'))
# If DOWNLOAD_DIR and DOWNLOAD_TORRENT_DIR are the same, remove the duplicate to prevent us from trying to process the same folder twice.
download_dirs = list(set(download_dirs))

View File

@@ -42,19 +42,22 @@ class Rutracker(object):
'login_password': headphones.CONFIG.RUTRACKER_PASSWORD,
'login': b'\xc2\xf5\xee\xe4' # '%C2%F5%EE%E4'
}
headers = {
'User-Agent' : 'Headphones'
}
logger.info("Attempting to log in to rutracker...")
try:
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False)
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers)
# try again
if not self.has_bb_session_cookie(r):
time.sleep(10)
if headphones.CONFIG.RUTRACKER_COOKIE:
logger.info("Attempting to log in using predefined cookie...")
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, cookies={'bb_session': headphones.CONFIG.RUTRACKER_COOKIE})
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers, cookies={'bb_session': headphones.CONFIG.RUTRACKER_COOKIE})
else:
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False)
r = self.session.post(loginpage, data=post_params, timeout=self.timeout, allow_redirects=False, headers=headers)
if self.has_bb_session_cookie(r):
self.loggedin = True
logger.info("Successfully logged in to rutracker")
@@ -113,7 +116,10 @@ class Rutracker(object):
Parse the search results and return valid torrent list
"""
try:
headers = {'Referer': self.search_referer}
headers = {
'Referer': self.search_referer,
'User-Agent' : 'Headphones'
}
r = self.session.get(url=searchurl, headers=headers, timeout=self.timeout)
soup = BeautifulSoup(r.content, 'html.parser')
@@ -183,7 +189,10 @@ class Rutracker(object):
downloadurl = 'https://rutracker.org/forum/dl.php?t=' + torrent_id
cookie = {'bb_dl': torrent_id}
try:
headers = {'Referer': url}
headers = {
'Referer': url,
'User-Agent' : 'Headphones'
}
r = self.session.post(url=downloadurl, cookies=cookie, headers=headers,
timeout=self.timeout)
return r.content

View File

@@ -37,10 +37,29 @@ from unidecode import unidecode
import headphones
from headphones.common import USER_AGENT
from headphones.helpers import (
bytes_to_mb,
has_token,
piratesize,
replace_all,
replace_illegal_chars,
sab_replace_dots,
sab_replace_spaces,
sab_sanitize_foldername,
split_string
)
from headphones.types import Result
from headphones import logger, db, helpers, classes, sab, nzbget, request
from headphones import utorrent, transmission, notifiers, rutracker, deluge, qbittorrent
from headphones import logger, db, classes, sab, nzbget, request
from headphones import (
bandcamp,
deluge,
notifiers,
qbittorrent,
rutracker,
soulseek,
transmission,
utorrent,
)
# Magnet to torrent services, for Black hole. Stolen from CouchPotato.
TORRENT_TO_MAGNET_SERVICES = [
@@ -137,7 +156,7 @@ def calculate_torrent_hash(link, data=None):
"""
if link.startswith("magnet:"):
torrent_hash = re.findall("urn:btih:([\w]{32,40})", link)[0]
torrent_hash = re.findall(r"urn:btih:([\w]{32,40})", link)[0]
if len(torrent_hash) == 32:
torrent_hash = b16encode(b32decode(torrent_hash)).lower()
elif data:
@@ -261,6 +280,8 @@ def strptime_musicbrainz(date_str):
def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
NZB_PROVIDERS = (headphones.CONFIG.HEADPHONES_INDEXER or
headphones.CONFIG.NEWZNAB or
headphones.CONFIG.NZBSORG or
@@ -284,25 +305,33 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
[album['AlbumID']])[0][0]
if headphones.CONFIG.PREFER_TORRENTS == 0 and not choose_specific_download:
if NZB_PROVIDERS and NZB_DOWNLOADERS:
results = searchNZB(album, new, losslessOnly, albumlength)
if not results and TORRENT_PROVIDERS:
results = searchTorrent(album, new, losslessOnly, albumlength)
elif headphones.CONFIG.PREFER_TORRENTS == 1 and not choose_specific_download:
if not results and headphones.CONFIG.BANDCAMP:
results = searchBandcamp(album, new, albumlength)
elif headphones.CONFIG.PREFER_TORRENTS == 1 and not choose_specific_download:
if TORRENT_PROVIDERS:
results = searchTorrent(album, new, losslessOnly, albumlength)
if not results and NZB_PROVIDERS and NZB_DOWNLOADERS:
results = searchNZB(album, new, losslessOnly, albumlength)
if not results and headphones.CONFIG.BANDCAMP:
results = searchBandcamp(album, new, albumlength)
elif headphones.CONFIG.PREFER_TORRENTS == 2 and not choose_specific_download:
results = searchSoulseek(album, new, losslessOnly, albumlength)
else:
nzb_results = None
torrent_results = None
bandcamp_results = None
if NZB_PROVIDERS and NZB_DOWNLOADERS:
nzb_results = searchNZB(album, new, losslessOnly, albumlength, choose_specific_download)
@@ -311,13 +340,16 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
torrent_results = searchTorrent(album, new, losslessOnly, albumlength,
choose_specific_download)
if headphones.CONFIG.BANDCAMP:
bandcamp_results = searchBandcamp(album, new, albumlength)
if not nzb_results:
nzb_results = []
if not torrent_results:
torrent_results = []
if not bandcamp_results:
bandcamp_results = []
results = nzb_results + torrent_results
results = nzb_results + torrent_results + bandcamp_results
if choose_specific_download:
return results
@@ -338,6 +370,7 @@ def do_sorted_search(album, new, losslessOnly, choose_specific_download=False):
(data, result) = preprocess(sorted_search_results)
if data and result:
#print(f'going to send stuff to downloader. data: {data}, album: {album}')
send_to_downloader(data, result, album)
@@ -360,7 +393,7 @@ def more_filtering(results, album, albumlength, new):
logger.debug('Target bitrate: %s kbps' % headphones.CONFIG.PREFERRED_BITRATE)
if albumlength:
targetsize = albumlength / 1000 * int(headphones.CONFIG.PREFERRED_BITRATE) * 128
logger.info('Target size: %s' % helpers.bytes_to_mb(targetsize))
logger.info('Target size: %s' % bytes_to_mb(targetsize))
if headphones.CONFIG.PREFERRED_BITRATE_LOW_BUFFER:
low_size_limit = targetsize * int(
headphones.CONFIG.PREFERRED_BITRATE_LOW_BUFFER) / 100
@@ -377,14 +410,14 @@ def more_filtering(results, album, albumlength, new):
if low_size_limit and result.size < low_size_limit:
logger.info(
f"{result.title} from {result.provider} is too small for this album. "
f"(Size: {result.size}, MinSize: {helpers.bytes_to_mb(low_size_limit)})"
f"(Size: {result.size}, MinSize: {bytes_to_mb(low_size_limit)})"
)
continue
if high_size_limit and result.size > high_size_limit:
logger.info(
f"{result.title} from {result.provider} is too large for this album. "
f"(Size: {result.size}, MaxSize: {helpers.bytes_to_mb(high_size_limit)})"
f"(Size: {result.size}, MaxSize: {bytes_to_mb(high_size_limit)})"
)
# Keep lossless results if there are no good lossy matches
if not (allow_lossless and 'flac' in result.title.lower()):
@@ -424,7 +457,7 @@ def sort_search_results(resultlist, album, new, albumlength):
# Add a priority if it has any of the preferred words
results_with_priority = []
preferred_words = helpers.split_string(headphones.CONFIG.PREFERRED_WORDS)
preferred_words = split_string(headphones.CONFIG.PREFERRED_WORDS)
for result in resultlist:
priority = 0
for word in preferred_words:
@@ -502,6 +535,10 @@ def get_year_from_release_date(release_date):
return year
def searchBandcamp(album, new=False, albumlength=None):
return bandcamp.search(album)
def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
choose_specific_download=False):
reldate = album['ReleaseDate']
@@ -521,8 +558,8 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
':': ''
}
cleanalbum = unidecode(helpers.replace_all(album['AlbumTitle'], replacements)).strip()
cleanartist = unidecode(helpers.replace_all(album['ArtistName'], replacements)).strip()
cleanalbum = unidecode(replace_all(album['AlbumTitle'], replacements)).strip()
cleanartist = unidecode(replace_all(album['ArtistName'], replacements)).strip()
# Use the provided search term if available, otherwise build a search term
if album['SearchTerm']:
@@ -542,8 +579,8 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
term = cleanartist + ' ' + cleanalbum
# Replace bad characters in the term
term = re.sub('[\.\-\/]', ' ', term)
artistterm = re.sub('[\.\-\/]', ' ', cleanartist)
term = re.sub(r'[\.\-\/]', r' ', term)
artistterm = re.sub(r'[\.\-\/]', r' ', cleanartist)
# If Preferred Bitrate and High Limit and Allow Lossless then get both lossy and lossless
if headphones.CONFIG.PREFERRED_QUALITY == 2 and headphones.CONFIG.PREFERRED_BITRATE and headphones.CONFIG.PREFERRED_BITRATE_HIGH_BUFFER and headphones.CONFIG.PREFERRED_BITRATE_ALLOW_LOSSLESS:
@@ -599,7 +636,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
size = int(item.links[1]['length'])
resultlist.append(Result(title, size, url, provider, 'nzb', True))
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
except Exception as e:
logger.error("An unknown error occurred trying to parse the feed: %s" % e)
@@ -670,7 +707,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
size = int(item.links[1]['length'])
if all(word.lower() in title.lower() for word in term.split()):
logger.info(
'Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
'Found %s. Size: %s' % (title, bytes_to_mb(size)))
resultlist.append(Result(title, size, url, provider, 'nzb', True))
else:
logger.info('Skipping %s, not all search term words found' % title)
@@ -720,7 +757,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
size = int(item.links[1]['length'])
resultlist.append(Result(title, size, url, provider, 'nzb', True))
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
except Exception as e:
logger.exception("Unhandled exception while parsing feed")
@@ -767,7 +804,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
size = int(item['sizebytes'])
resultlist.append(Result(title, size, url, provider, 'nzb', True))
logger.info('Found %s. Size: %s', title, helpers.bytes_to_mb(size))
logger.info('Found %s. Size: %s', title, bytes_to_mb(size))
except Exception as e:
logger.exception("Unhandled exception")
@@ -790,7 +827,7 @@ def searchNZB(album, new=False, losslessOnly=False, albumlength=None,
def send_to_downloader(data, result, album):
logger.info(
f"Found best result from {result.provider}: <a href=\"{result.url}\">"
f"{result.title}</a> - {helpers.bytes_to_mb(result.size)}"
f"{result.title}</a> - {bytes_to_mb(result.size)}"
)
# Get rid of any dodgy chars here so we can prevent sab from renaming our downloads
kind = result.kind
@@ -798,7 +835,7 @@ def send_to_downloader(data, result, album):
torrentid = None
if kind == 'nzb':
folder_name = helpers.sab_sanitize_foldername(result.title)
folder_name = sab_sanitize_foldername(result.title)
if headphones.CONFIG.NZB_DOWNLOADER == 1:
@@ -820,9 +857,9 @@ def send_to_downloader(data, result, album):
(replace_spaces, replace_dots) = sab.checkConfig()
if replace_dots:
folder_name = helpers.sab_replace_dots(folder_name)
folder_name = sab_replace_dots(folder_name)
if replace_spaces:
folder_name = helpers.sab_replace_spaces(folder_name)
folder_name = sab_replace_spaces(folder_name)
else:
nzb_name = folder_name + '.nzb'
@@ -839,6 +876,15 @@ def send_to_downloader(data, result, album):
except Exception as e:
logger.error('Couldn\'t write NZB file: %s', e)
return
elif kind == 'bandcamp':
folder_name = bandcamp.download(album, result)
logger.info("Setting folder_name to: {}".format(folder_name))
elif kind == 'soulseek':
soulseek.download(user=result.user, filelist=result.files)
folder_name = result.folder
else:
folder_name = '%s - %s [%s]' % (
unidecode(album['ArtistName']).replace('/', '_'),
@@ -849,7 +895,7 @@ def send_to_downloader(data, result, album):
if headphones.CONFIG.TORRENT_DOWNLOADER == 0:
# Get torrent name from .torrent, this is usually used by the torrent client as the folder name
torrent_name = helpers.replace_illegal_chars(folder_name) + '.torrent'
torrent_name = replace_illegal_chars(folder_name) + '.torrent'
download_path = os.path.join(headphones.CONFIG.TORRENTBLACKHOLE_DIR, torrent_name)
if result.url.lower().startswith("magnet:"):
@@ -954,10 +1000,6 @@ def send_to_downloader(data, result, album):
logger.error("Error sending torrent to Deluge. Are you sure it's running? Maybe the torrent already exists?")
return
# This pauses the torrent right after it is added
if headphones.CONFIG.DELUGE_PAUSED:
deluge.setTorrentPause({'hash': torrentid})
# Set Label
if headphones.CONFIG.DELUGE_LABEL:
deluge.setTorrentLabel({'hash': torrentid})
@@ -967,10 +1009,6 @@ def send_to_downloader(data, result, album):
if seed_ratio is not None:
deluge.setSeedRatio({'hash': torrentid, 'ratio': seed_ratio})
# Set move-to directory
if headphones.CONFIG.DELUGE_DONE_DIRECTORY or headphones.CONFIG.DOWNLOAD_TORRENT_DIR:
deluge.setTorrentPath({'hash': torrentid})
# Get folder name from Deluge, it's usually the torrent name
folder_name = deluge.getTorrentFolder({'hash': torrentid})
if folder_name:
@@ -1156,7 +1194,7 @@ def send_to_downloader(data, result, album):
def verifyresult(title, artistterm, term, lossless):
title = re.sub('[\.\-\/\_]', ' ', title)
title = re.sub(r'[\.\-\/\_]', r' ', title)
# if artistterm != 'Various Artists':
#
@@ -1188,16 +1226,16 @@ def verifyresult(title, artistterm, term, lossless):
return False
if headphones.CONFIG.IGNORED_WORDS:
for each_word in helpers.split_string(headphones.CONFIG.IGNORED_WORDS):
for each_word in split_string(headphones.CONFIG.IGNORED_WORDS):
if each_word.lower() in title.lower():
logger.info("Removed '%s' from results because it contains ignored word: '%s'",
title, each_word)
return False
if headphones.CONFIG.REQUIRED_WORDS:
for each_word in helpers.split_string(headphones.CONFIG.REQUIRED_WORDS):
for each_word in split_string(headphones.CONFIG.REQUIRED_WORDS):
if ' OR ' in each_word:
or_words = helpers.split_string(each_word, 'OR')
or_words = split_string(each_word, 'OR')
if any(word.lower() in title.lower() for word in or_words):
continue
else:
@@ -1219,23 +1257,23 @@ def verifyresult(title, artistterm, term, lossless):
title, each_word)
return False
tokens = re.split('\W', term, re.IGNORECASE | re.UNICODE)
tokens = re.split(r'\W', term, flags=re.IGNORECASE | re.UNICODE)
for token in tokens:
if not token:
continue
if token == 'Various' or token == 'Artists' or token == 'VA':
continue
if not re.search('(?:\W|^)+' + token + '(?:\W|$)+', title, re.IGNORECASE | re.UNICODE):
if not has_token(title, token):
cleantoken = ''.join(c for c in token if c not in string.punctuation)
if not not re.search('(?:\W|^)+' + cleantoken + '(?:\W|$)+', title,
re.IGNORECASE | re.UNICODE):
if not has_token(title, cleantoken):
dic = {'!': 'i', '$': 's'}
dumbtoken = helpers.replace_all(token, dic)
if not not re.search('(?:\W|^)+' + dumbtoken + '(?:\W|$)+', title,
re.IGNORECASE | re.UNICODE):
logger.info("Removed from results: %s (missing tokens: %s and %s)", title,
token, cleantoken)
dumbtoken = replace_all(token, dic)
if not has_token(title, dumbtoken):
logger.info(
"Removed from results: %s (missing tokens: [%s, %s, %s])",
title, token, cleantoken, dumbtoken)
return False
return True
@@ -1264,9 +1302,9 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
'*': ''
}
semi_cleanalbum = helpers.replace_all(album['AlbumTitle'], replacements)
semi_cleanalbum = replace_all(album['AlbumTitle'], replacements)
cleanalbum = unidecode(semi_cleanalbum)
semi_cleanartist = helpers.replace_all(album['ArtistName'], replacements)
semi_cleanartist = replace_all(album['ArtistName'], replacements)
cleanartist = unidecode(semi_cleanartist)
# Use provided term if available, otherwise build our own (this code needs to be cleaned up since a lot
@@ -1293,12 +1331,12 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
else:
usersearchterm = ''
semi_clean_artist_term = re.sub('[\.\-\/]', ' ', semi_cleanartist)
semi_clean_album_term = re.sub('[\.\-\/]', ' ', semi_cleanalbum)
semi_clean_artist_term = re.sub(r'[\.\-\/]', r' ', semi_cleanartist)
semi_clean_album_term = re.sub(r'[\.\-\/]', r' ', semi_cleanalbum)
# Replace bad characters in the term
term = re.sub('[\.\-\/]', ' ', term)
artistterm = re.sub('[\.\-\/]', ' ', cleanartist)
albumterm = re.sub('[\.\-\/]', ' ', cleanalbum)
term = re.sub(r'[\.\-\/]', r' ', term)
artistterm = re.sub(r'[\.\-\/]', r' ', cleanartist)
albumterm = re.sub(r'[\.\-\/]', r' ', cleanalbum)
# If Preferred Bitrate and High Limit and Allow Lossless then get both lossy and lossless
if headphones.CONFIG.PREFERRED_QUALITY == 2 and headphones.CONFIG.PREFERRED_BITRATE and headphones.CONFIG.PREFERRED_BITRATE_HIGH_BUFFER and headphones.CONFIG.PREFERRED_BITRATE_ALLOW_LOSSLESS:
@@ -1401,7 +1439,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
if all(word.lower() in title.lower() for word in term.split()):
if size < maxsize and minimumseeders < seeders:
logger.info('Found %s. Size: %s' % (title, helpers.bytes_to_mb(size)))
logger.info('Found %s. Size: %s' % (title, bytes_to_mb(size)))
resultlist.append(Result(title, size, url, provider, 'torrent', True))
else:
logger.info(
@@ -1477,7 +1515,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
size = int(desc_match.group(1))
url = item.link
resultlist.append(Result(title, size, url, provider, 'torrent', True))
logger.info('Found %s. Size: %s', title, helpers.bytes_to_mb(size))
logger.info('Found %s. Size: %s', title, bytes_to_mb(size))
except Exception as e:
logger.error(
"An error occurred while trying to parse the response from Waffles.ch: %s",
@@ -1761,7 +1799,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
# Pirate Bay
if headphones.CONFIG.PIRATEBAY:
provider = "The Pirate Bay"
tpb_term = term.replace("!", "").replace("'", " ")
tpb_term = term.replace("!", "").replace("'", " ").replace(" ", "%20")
# Use proxy if specified
if headphones.CONFIG.PIRATEBAY_PROXY_URL:
@@ -1793,6 +1831,8 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
# Process content
if data:
rows = data.select('table tbody tr')
if not rows:
rows = data.select('table tr')
if not rows:
logger.info("No results found from The Pirate Bay using term: %s" % tpb_term)
@@ -1820,7 +1860,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
formatted_size = re.search('Size (.*),', str(item)).group(1).replace(
'\xa0', ' ')
size = helpers.piratesize(formatted_size)
size = piratesize(formatted_size)
if size < maxsize and minimumseeders < seeds and url is not None:
match = True
@@ -1874,7 +1914,7 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
"href"] # Magnet link. The actual download link is not based on the URL
formatted_size = item.select("td.size-row")[0].text
size = helpers.piratesize(formatted_size)
size = piratesize(formatted_size)
if size < maxsize and minimumseeders < seeds and url is not None:
match = True
@@ -1901,22 +1941,49 @@ def searchTorrent(album, new=False, losslessOnly=False, albumlength=None,
return results
def searchSoulseek(album, new=False, losslessOnly=False, albumlength=None):
# Some of the input parameters are unused for now (and may stay that way)
replacements = {
'...': '',
' & ': ' ',
' = ': ' ',
'?': '',
'$': '',
' + ': ' ',
'"': '',
',': '',
'*': '',
'.': '',
':': ''
}
num_tracks = get_album_track_count(album['AlbumID'])
year = get_year_from_release_date(album['ReleaseDate'])
cleanalbum = unidecode(helpers.replace_all(album['AlbumTitle'], replacements)).strip()
cleanartist = unidecode(helpers.replace_all(album['ArtistName'], replacements)).strip()
results = soulseek.search(artist=cleanartist, album=cleanalbum, year=year, losslessOnly=losslessOnly, num_tracks=num_tracks)
return results
def get_album_track_count(album_id):
# Not sure if this should be considered a helper function.
myDB = db.DBConnection()
track_count = myDB.select('SELECT COUNT(*) as count FROM tracks WHERE AlbumID=?', [album_id])[0]['count']
return track_count
# THIS IS KIND OF A MESS AND PROBABLY NEEDS TO BE CLEANED UP
def preprocess(resultlist):
for result in resultlist:
headers = {'User-Agent': USER_AGENT}
if result.provider in ["The Pirate Bay", "Old Pirate Bay"]:
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 6.3; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) \
Chrome/41.0.2243.2 Safari/537.36'
}
else:
headers = {'User-Agent': USER_AGENT}
if result.kind == 'soulseek':
return True, result
if result.kind == 'torrent':
# rutracker always needs the torrent data
@@ -1962,12 +2029,24 @@ def preprocess(resultlist):
return True, result
# Download the torrent file
if result.provider in ["The Pirate Bay", "Old Pirate Bay"]:
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 6.3; Win64; x64) \
AppleWebKit/537.36 (KHTML, like Gecko) \
Chrome/41.0.2243.2 Safari/537.36'
}
return request.request_content(url=result.url, headers=headers), result
if result.kind == 'magnet':
elif result.kind == 'magnet':
magnet_link = result.url
return "d10:magnet-uri%d:%se" % (len(magnet_link), magnet_link), result
elif result.kind == 'bandcamp':
return True, result
else:
if result.provider == 'headphones':
return request.request_content(

headphones/soulseek.py Normal file
View File

@@ -0,0 +1,185 @@
from collections import defaultdict, namedtuple
import os
import time
import slskd_api
import headphones
from headphones import logger
from datetime import datetime, timedelta
Result = namedtuple('Result', ['title', 'size', 'user', 'provider', 'type', 'matches', 'bandwidth', 'hasFreeUploadSlot', 'queueLength', 'files', 'kind', 'url', 'folder'])
def initialize_soulseek_client():
host = headphones.CONFIG.SOULSEEK_API_URL
api_key = headphones.CONFIG.SOULSEEK_API_KEY
return slskd_api.SlskdClient(host=host, api_key=api_key)
# Search logic, calling search and processing functions
def search(artist, album, year, num_tracks, losslessOnly):
client = initialize_soulseek_client()
# Stage 1: Search with artist, album, year, and num_tracks
results = execute_search(client, artist, album, year, losslessOnly)
processed_results = process_results(results, losslessOnly, num_tracks)
if processed_results:
return processed_results
# Stage 2: If Stage 1 fails, search with artist, album, and num_tracks (excluding year)
logger.info("Soulseek search stage 1 did not meet criteria. Retrying without year...")
results = execute_search(client, artist, album, None, losslessOnly)
processed_results = process_results(results, losslessOnly, num_tracks)
if processed_results:
return processed_results
# Stage 3: Final attempt, same artist/album search, but ignore the track count when filtering
logger.info("Soulseek search stage 2 did not meet criteria. Final attempt, ignoring track count.")
results = execute_search(client, artist, album, None, losslessOnly)
processed_results = process_results(results, losslessOnly, num_tracks, ignore_track_count=True)
return processed_results
def execute_search(client, artist, album, year, losslessOnly):
search_text = f"{artist} {album}"
if year:
search_text += f" {year}"
if losslessOnly:
search_text += ".flac"
# Actual search
search_response = client.searches.search_text(searchText=search_text, filterResponses=True)
search_id = search_response.get('id')
# Wait for search completion and return response
while not client.searches.state(id=search_id).get('isComplete'):
time.sleep(2)
return client.searches.search_responses(id=search_id)
# Processing the search result passed
def process_results(results, losslessOnly, num_tracks, ignore_track_count=False):
valid_extensions = {'.flac'} if losslessOnly else {'.mp3', '.flac'}
albums = defaultdict(lambda: {'files': [], 'user': None, 'hasFreeUploadSlot': None, 'queueLength': None, 'uploadSpeed': None})
# Extract info from the api response and combine files at album level
for result in results:
user = result.get('username')
hasFreeUploadSlot = result.get('hasFreeUploadSlot')
queueLength = result.get('queueLength')
uploadSpeed = result.get('uploadSpeed')
# Only handle .mp3 and .flac
for file in result.get('files', []):
filename = file.get('filename')
file_extension = os.path.splitext(filename)[1].lower()
if file_extension in valid_extensions:
album_directory = os.path.dirname(filename)
albums[album_directory]['files'].append(file)
# Update metadata only once per album_directory
if albums[album_directory]['user'] is None:
albums[album_directory].update({
'user': user,
'hasFreeUploadSlot': hasFreeUploadSlot,
'queueLength': queueLength,
'uploadSpeed': uploadSpeed,
})
# Filter albums based on num_tracks, add bunch of useful info to the compiled album
final_results = []
for directory, album_data in albums.items():
if ignore_track_count or len(album_data['files']) == num_tracks:
album_title = os.path.basename(directory)
total_size = sum(file.get('size', 0) for file in album_data['files'])
final_results.append(Result(
title=album_title,
size=int(total_size),
user=album_data['user'],
provider="soulseek",
type="soulseek",
matches=True,
bandwidth=album_data['uploadSpeed'],
hasFreeUploadSlot=album_data['hasFreeUploadSlot'],
queueLength=album_data['queueLength'],
files=album_data['files'],
kind='soulseek',
url='http://thisisnot.needed', # URL is needed in other parts of the program.
folder=os.path.basename(directory)
))
return final_results
def download(user, filelist):
client = initialize_soulseek_client()
client.transfers.enqueue(username=user, files=filelist)
def download_completed():
client = initialize_soulseek_client()
all_downloads = client.transfers.get_all_downloads(includeRemoved=False)
album_completion_tracker = {} # Tracks completion state of each album's songs
album_errored_tracker = {} # Tracks albums with errored downloads
# Anything older than 24 hours will be canceled
cutoff_time = datetime.now() - timedelta(hours=24)
# Identify errored and completed albums
for download in all_downloads:
directories = download.get('directories', [])
for directory in directories:
album_part = directory.get('directory', '').split('\\')[-1]
files = directory.get('files', [])
for file_data in files:
state = file_data.get('state', '')
requested_at_str = file_data.get('requestedAt', '1900-01-01T00:00:00.000000')
requested_at = parse_datetime(requested_at_str)
# Initialize or update album entry in trackers
if album_part not in album_completion_tracker:
album_completion_tracker[album_part] = {'total': 0, 'completed': 0, 'errored': 0}
if album_part not in album_errored_tracker:
album_errored_tracker[album_part] = False
album_completion_tracker[album_part]['total'] += 1
if 'Completed, Succeeded' in state:
album_completion_tracker[album_part]['completed'] += 1
elif 'Completed, Errored' in state or requested_at < cutoff_time:
album_completion_tracker[album_part]['errored'] += 1
album_errored_tracker[album_part] = True # Mark album as having errored downloads
# Identify errored albums
errored_albums = {album for album, errored in album_errored_tracker.items() if errored}
# Cancel downloads for errored albums
for download in all_downloads:
directories = download.get('directories', [])
for directory in directories:
album_part = directory.get('directory', '').split('\\')[-1]
files = directory.get('files', [])
for file_data in files:
if album_part in errored_albums:
# Extract 'id' and 'username' for each file to cancel the download
file_id = file_data.get('id', '')
username = file_data.get('username', '')
success = client.transfers.cancel_download(username, file_id)
if not success:
print(f"Failed to cancel download for file ID: {file_id}")
# Clear completed/canceled/errored stuff from client downloads
try:
client.transfers.remove_completed_downloads()
except Exception as e:
print(f"Failed to remove completed downloads: {e}")
# Identify completed albums
completed_albums = {album for album, counts in album_completion_tracker.items() if counts['total'] == counts['completed']}
# Return both completed and errored albums
return completed_albums, errored_albums
def parse_datetime(datetime_string):
# Parse the datetime api response
if '.' in datetime_string:
datetime_string = datetime_string[:datetime_string.index('.')+7]
return datetime.strptime(datetime_string, '%Y-%m-%dT%H:%M:%S.%f')
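Taken together, a sketch of the end-to-end flow the new module supports (function names and Result fields as defined above; artist/album values are illustrative):

from headphones import soulseek

# Staged search, then enqueue the first acceptable match.
results = soulseek.search(artist='Example Artist', album='Example Album',
                          year='2024', num_tracks=10, losslessOnly=False)
if results:
    best = results[0]
    soulseek.download(user=best.user, filelist=best.files)

# Later, e.g. from the postprocessor's checkFolder() loop:
completed_albums, errored_albums = soulseek.download_completed()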

View File

@@ -1183,6 +1183,7 @@ class WebInterface(object):
"deluge_password": headphones.CONFIG.DELUGE_PASSWORD,
"deluge_label": headphones.CONFIG.DELUGE_LABEL,
"deluge_done_directory": headphones.CONFIG.DELUGE_DONE_DIRECTORY,
"deluge_download_directory": headphones.CONFIG.DELUGE_DOWNLOAD_DIRECTORY,
"deluge_paused": checked(headphones.CONFIG.DELUGE_PAUSED),
"utorrent_host": headphones.CONFIG.UTORRENT_HOST,
"utorrent_username": headphones.CONFIG.UTORRENT_USERNAME,
@@ -1197,6 +1198,8 @@ class WebInterface(object):
"torrent_downloader_deluge": radio(headphones.CONFIG.TORRENT_DOWNLOADER, 3),
"torrent_downloader_qbittorrent": radio(headphones.CONFIG.TORRENT_DOWNLOADER, 4),
"download_dir": headphones.CONFIG.DOWNLOAD_DIR,
"soulseek_download_dir": headphones.CONFIG.SOULSEEK_DOWNLOAD_DIR,
"soulseek_incomplete_download_dir": headphones.CONFIG.SOULSEEK_INCOMPLETE_DOWNLOAD_DIR,
"use_blackhole": checked(headphones.CONFIG.BLACKHOLE),
"blackhole_dir": headphones.CONFIG.BLACKHOLE_DIR,
"usenet_retention": headphones.CONFIG.USENET_RETENTION,
@@ -1296,6 +1299,7 @@ class WebInterface(object):
"prefer_torrents_0": radio(headphones.CONFIG.PREFER_TORRENTS, 0),
"prefer_torrents_1": radio(headphones.CONFIG.PREFER_TORRENTS, 1),
"prefer_torrents_2": radio(headphones.CONFIG.PREFER_TORRENTS, 2),
"prefer_torrents_3": radio(headphones.CONFIG.PREFER_TORRENTS, 3),
"magnet_links_0": radio(headphones.CONFIG.MAGNET_LINKS, 0),
"magnet_links_1": radio(headphones.CONFIG.MAGNET_LINKS, 1),
"magnet_links_2": radio(headphones.CONFIG.MAGNET_LINKS, 2),
@@ -1413,7 +1417,12 @@ class WebInterface(object):
"join_enabled": checked(headphones.CONFIG.JOIN_ENABLED),
"join_onsnatch": checked(headphones.CONFIG.JOIN_ONSNATCH),
"join_apikey": headphones.CONFIG.JOIN_APIKEY,
"join_deviceid": headphones.CONFIG.JOIN_DEVICEID
"join_deviceid": headphones.CONFIG.JOIN_DEVICEID,
"use_bandcamp": checked(headphones.CONFIG.BANDCAMP),
"bandcamp_dir": headphones.CONFIG.BANDCAMP_DIR,
"soulseek_api_url": headphones.CONFIG.SOULSEEK_API_URL,
"soulseek_api_key": headphones.CONFIG.SOULSEEK_API_KEY,
"use_soulseek": checked(headphones.CONFIG.SOULSEEK)
}
for k, v in config.items():
@@ -1482,7 +1491,7 @@ class WebInterface(object):
"songkick_enabled", "songkick_filter_enabled",
"mpc_enabled", "email_enabled", "email_ssl", "email_tls", "email_onsnatch",
"customauth", "idtag", "deluge_paused",
"join_enabled", "join_onsnatch"
"join_enabled", "join_onsnatch", "use_bandcamp"
]
for checked_config in checked_configs:
if checked_config not in kwargs:

View File

@@ -57,9 +57,11 @@ These API's are described in the `CherryPy specification
"""
try:
import pkg_resources
import importlib.metadata as importlib_metadata
except ImportError:
pass
# Fall back for Python <= 3.7.
# This try/except can be removed once Python <= 3.7 support is dropped.
import importlib_metadata
from threading import local as _local
@@ -109,7 +111,7 @@ tree = _cptree.Tree()
try:
__version__ = pkg_resources.require('cherrypy')[0].version
__version__ = importlib_metadata.version('cherrypy')
except Exception:
__version__ = 'unknown'
@@ -181,24 +183,28 @@ def quickstart(root=None, script_name='', config=None):
class _Serving(_local):
"""An interface for registering request and response objects.
Rather than have a separate "thread local" object for the request and
the response, this class works as a single threadlocal container for
both objects (and any others which developers wish to define). In this
way, we can easily dump those objects when we stop/start a new HTTP
conversation, yet still refer to them as module-level globals in a
thread-safe way.
Rather than have a separate "thread local" object for the request
and the response, this class works as a single threadlocal container
for both objects (and any others which developers wish to define).
In this way, we can easily dump those objects when we stop/start a
new HTTP conversation, yet still refer to them as module-level
globals in a thread-safe way.
"""
request = _cprequest.Request(_httputil.Host('127.0.0.1', 80),
_httputil.Host('127.0.0.1', 1111))
"""The request object for the current thread.
In the main thread, and any threads which are not receiving HTTP
requests, this is None.
"""
The request object for the current thread. In the main thread,
and any threads which are not receiving HTTP requests, this is None."""
response = _cprequest.Response()
"""The response object for the current thread.
In the main thread, and any threads which are not receiving HTTP
requests, this is None.
"""
The response object for the current thread. In the main thread,
and any threads which are not receiving HTTP requests, this is None."""
def load(self, request, response):
self.request = request
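In handler code the per-thread request and response are reached through the module-level proxies rather than through cherrypy.serving directly. A sketch:

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def index(self):
            # cherrypy.request proxies the thread-local object held
            # in cherrypy.serving for the current HTTP conversation
            return 'serving %s' % cherrypy.request.path_info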
@@ -316,8 +322,8 @@ class _GlobalLogManager(_cplogging.LogManager):
def __call__(self, *args, **kwargs):
"""Log the given message to the app.log or global log.
Log the given message to the app.log or global
log as appropriate.
Log the given message to the app.log or global log as
appropriate.
"""
# Do NOT use try/except here. See
# https://github.com/cherrypy/cherrypy/issues/945
@@ -330,8 +336,8 @@ class _GlobalLogManager(_cplogging.LogManager):
def access(self):
"""Log an access message to the app.log or global log.
Log the given message to the app.log or global
log as appropriate.
Log the given message to the app.log or global log as
appropriate.
"""
try:
return request.app.log.access()

View File

@@ -313,7 +313,10 @@ class Checker(object):
# -------------------- Specific config warnings -------------------- #
def check_localhost(self):
"""Warn if any socket_host is 'localhost'. See #711."""
"""Warn if any socket_host is 'localhost'.
See #711.
"""
for k, v in cherrypy.config.items():
if k == 'server.socket_host' and v == 'localhost':
warnings.warn("The use of 'localhost' as a socket host can "
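The warning exists because 'localhost' can resolve to either an IPv4 or an IPv6 address depending on the platform. Binding to an explicit address sidesteps the ambiguity; for example:

    import cherrypy

    cherrypy.config.update({
        'server.socket_host': '127.0.0.1',  # or '::1', or '0.0.0.0'
        'server.socket_port': 8080,
    })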

View File

@@ -1,5 +1,4 @@
"""
Configuration system for CherryPy.
"""Configuration system for CherryPy.
Configuration in CherryPy is implemented via dictionaries. Keys are strings
which name the mapped value, which may be of any type.
@@ -132,8 +131,8 @@ def _if_filename_register_autoreload(ob):
def merge(base, other):
"""Merge one app config (from a dict, file, or filename) into another.
If the given config is a filename, it will be appended to
the list of files to monitor for "autoreload" changes.
If the given config is a filename, it will be appended to the list
of files to monitor for "autoreload" changes.
"""
_if_filename_register_autoreload(other)
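Application code usually reaches this function through Application.merge. A sketch, with a hypothetical config filename (Root defined elsewhere):

    import cherrypy

    app = cherrypy.tree.mount(Root(), '/')
    # passing a filename rather than a dict also registers the file
    # for autoreload, per the docstring above
    app.merge('site.conf')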

View File

@@ -1,9 +1,10 @@
"""CherryPy dispatchers.
A 'dispatcher' is the object which looks up the 'page handler' callable
and collects config for the current request based on the path_info, other
request attributes, and the application architecture. The core calls the
dispatcher as early as possible, passing it a 'path_info' argument.
and collects config for the current request based on the path_info,
other request attributes, and the application architecture. The core
calls the dispatcher as early as possible, passing it a 'path_info'
argument.
The default dispatcher discovers the page handler by matching path_info
to a hierarchical arrangement of objects, starting at request.app.root.
@@ -21,7 +22,6 @@ import cherrypy
class PageHandler(object):
"""Callable which sets response.body."""
def __init__(self, callable, *args, **kwargs):
@@ -64,8 +64,7 @@ class PageHandler(object):
def test_callable_spec(callable, callable_args, callable_kwargs):
"""
Inspect callable and test to see if the given args are suitable for it.
"""Inspect callable and test to see if the given args are suitable for it.
When an error occurs during the handler's invoking stage there are 2
erroneous cases:
@@ -252,16 +251,16 @@ else:
class Dispatcher(object):
"""CherryPy Dispatcher which walks a tree of objects to find a handler.
The tree is rooted at cherrypy.request.app.root, and each hierarchical
component in the path_info argument is matched to a corresponding nested
attribute of the root object. Matching handlers must have an 'exposed'
attribute which evaluates to True. The special method name "index"
matches a URI which ends in a slash ("/"). The special method name
"default" may match a portion of the path_info (but only when no longer
substring of the path_info matches some other object).
The tree is rooted at cherrypy.request.app.root, and each
hierarchical component in the path_info argument is matched to a
corresponding nested attribute of the root object. Matching handlers
must have an 'exposed' attribute which evaluates to True. The
special method name "index" matches a URI which ends in a slash
("/"). The special method name "default" may match a portion of the
path_info (but only when no longer substring of the path_info
matches some other object).
This is the default, built-in dispatcher for CherryPy.
"""
@@ -306,9 +305,9 @@ class Dispatcher(object):
The second object returned will be a list of names which are
'virtual path' components: parts of the URL which are dynamic,
and were not used when looking up the handler.
These virtual path components are passed to the handler as
positional arguments.
and were not used when looking up the handler. These virtual
path components are passed to the handler as positional
arguments.
"""
request = cherrypy.serving.request
app = request.app
@@ -448,13 +447,11 @@ class Dispatcher(object):
class MethodDispatcher(Dispatcher):
"""Additional dispatch based on cherrypy.request.method.upper().
Methods named GET, POST, etc will be called on an exposed class.
The method names must be all caps; the appropriate Allow header
will be output showing all capitalized method names as allowable
HTTP verbs.
Methods named GET, POST, etc will be called on an exposed class. The
method names must be all caps; the appropriate Allow header will be
output showing all capitalized method names as allowable HTTP verbs.
Note that the containing class must be exposed, not the methods.
"""
@@ -492,16 +489,14 @@ class MethodDispatcher(Dispatcher):
class RoutesDispatcher(object):
"""A Routes based dispatcher for CherryPy."""
def __init__(self, full_result=False, **mapper_options):
"""
Routes dispatcher
"""Routes dispatcher.
Set full_result to True if you wish the controller
and the action to be passed on to the page handler
parameters. By default they won't be.
Set full_result to True if you wish the controller and the
action to be passed on to the page handler parameters. By
default they won't be.
"""
import routes
self.full_result = full_result
@@ -617,8 +612,7 @@ def XMLRPCDispatcher(next_dispatcher=Dispatcher()):
def VirtualHost(next_dispatcher=Dispatcher(), use_x_forwarded_host=True,
**domains):
"""
Select a different handler based on the Host header.
"""Select a different handler based on the Host header.
This can be useful when running multiple sites within one CP server.
It allows several domains to point to different parts of a single

View File

@@ -136,19 +136,17 @@ from cherrypy.lib import httputil as _httputil
class CherryPyException(Exception):
"""A base class for CherryPy exceptions."""
pass
class InternalRedirect(CherryPyException):
"""Exception raised to switch to the handler for a different URL.
This exception will redirect processing to another path within the site
(without informing the client). Provide the new path as an argument when
raising the exception. Provide any params in the querystring for the new
URL.
This exception will redirect processing to another path within the
site (without informing the client). Provide the new path as an
argument when raising the exception. Provide any params in the
querystring for the new URL.
"""
def __init__(self, path, query_string=''):
@@ -173,7 +171,6 @@ class InternalRedirect(CherryPyException):
class HTTPRedirect(CherryPyException):
"""Exception raised when the request should be redirected.
This exception will force a HTTP redirect to the URL or URL's you give it.
@@ -202,7 +199,7 @@ class HTTPRedirect(CherryPyException):
"""The list of URL's to emit."""
encoding = 'utf-8'
"""The encoding when passed urls are not native strings"""
"""The encoding when passed urls are not native strings."""
def __init__(self, urls, status=None, encoding=None):
self.urls = abs_urls = [
@@ -230,8 +227,7 @@ class HTTPRedirect(CherryPyException):
@classproperty
def default_status(cls):
"""
The default redirect status for the request.
"""The default redirect status for the request.
RFC 2616 indicates a 301 response code fits our goal; however,
browser support for 301 is quite messy. Use 302/303 instead. See
@@ -249,8 +245,9 @@ class HTTPRedirect(CherryPyException):
"""Modify cherrypy.response status, headers, and body to represent
self.
CherryPy uses this internally, but you can also use it to create an
HTTPRedirect object and set its output without *raising* the exception.
CherryPy uses this internally, but you can also use it to create
an HTTPRedirect object and set its output without *raising* the
exception.
"""
response = cherrypy.serving.response
response.status = status = self.status
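Raising the exception from a page handler is the common case; set_response covers the rarer non-raising use described above. A sketch:

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def old_page(self):
            # 303 under HTTP/1.1, 302 under HTTP/1.0, per default_status
            raise cherrypy.HTTPRedirect('/new-page')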
@@ -339,7 +336,6 @@ def clean_headers(status):
class HTTPError(CherryPyException):
"""Exception used to return an HTTP error code (4xx-5xx) to the client.
This exception can be used to automatically send a response using a
@@ -358,7 +354,9 @@ class HTTPError(CherryPyException):
"""
status = None
"""The HTTP status code. May be of type int or str (with a Reason-Phrase).
"""The HTTP status code.
May be of type int or str (with a Reason-Phrase).
"""
code = None
@@ -386,8 +384,9 @@ class HTTPError(CherryPyException):
"""Modify cherrypy.response status, headers, and body to represent
self.
CherryPy uses this internally, but you can also use it to create an
HTTPError object and set its output without *raising* the exception.
CherryPy uses this internally, but you can also use it to create
an HTTPError object and set its output without *raising* the
exception.
"""
response = cherrypy.serving.response
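As with HTTPRedirect, handlers normally raise the exception and let set_response do the work. A sketch:

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def gated(self):
            raise cherrypy.HTTPError(403, 'Forbidden by policy')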
@@ -426,11 +425,10 @@ class HTTPError(CherryPyException):
class NotFound(HTTPError):
"""Exception raised when a URL could not be mapped to any handler (404).
This is equivalent to raising
:class:`HTTPError("404 Not Found") <cherrypy._cperror.HTTPError>`.
This is equivalent to raising :class:`HTTPError("404 Not Found")
<cherrypy._cperror.HTTPError>`.
"""
def __init__(self, path=None):
@@ -477,8 +475,8 @@ _HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
def get_error_page(status, **kwargs):
"""Return an HTML page, containing a pretty error response.
status should be an int or a str.
kwargs will be interpolated into the page template.
status should be an int or a str. kwargs will be interpolated into
the page template.
"""
try:
code, reason, message = _httputil.valid_status(status)
@@ -595,8 +593,8 @@ def bare_error(extrabody=None):
"""Produce status, headers, body for a critical error.
Returns a triple without calling any other questionable functions,
so it should be as error-free as possible. Call it from an HTTP server
if you get errors outside of the request.
so it should be as error-free as possible. Call it from an HTTP
server if you get errors outside of the request.
If extrabody is None, a friendly but rather unhelpful error message
is set in the body. If extrabody is a string, it will be appended

View File

@@ -123,7 +123,6 @@ logfmt = logging.Formatter('%(message)s')
class NullHandler(logging.Handler):
"""A no-op logging handler to silence the logging.lastResort handler."""
def handle(self, record):
@@ -137,15 +136,16 @@ class NullHandler(logging.Handler):
class LogManager(object):
"""An object to assist both simple and advanced logging.
``cherrypy.log`` is an instance of this class.
"""
appid = None
"""The id() of the Application object which owns this log manager. If this
is a global log manager, appid is None."""
"""The id() of the Application object which owns this log manager.
If this is a global log manager, appid is None.
"""
error_log = None
"""The actual :class:`logging.Logger` instance for error messages."""
@@ -317,8 +317,8 @@ class LogManager(object):
def screen(self):
"""Turn stderr/stdout logging on or off.
If you set this to True, it'll add the appropriate StreamHandler for
you. If you set it to False, it will remove the handler.
If you set this to True, it'll add the appropriate StreamHandler
for you. If you set it to False, it will remove the handler.
"""
h = self._get_builtin_handler
has_h = h(self.error_log, 'screen') or h(self.access_log, 'screen')
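The screen property, like the file handlers, is normally driven from config. A sketch with hypothetical log paths:

    import cherrypy

    cherrypy.config.update({
        'log.screen': True,               # toggles the property above
        'log.access_file': 'access.log',
        'log.error_file': 'error.log',
    })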
@@ -414,7 +414,6 @@ class LogManager(object):
class WSGIErrorHandler(logging.Handler):
"A handler class which writes logging records to environ['wsgi.errors']."
def flush(self):
@@ -452,6 +451,8 @@ class WSGIErrorHandler(logging.Handler):
class LazyRfc3339UtcTime(object):
def __str__(self):
"""Return now() in RFC3339 UTC Format."""
now = datetime.datetime.now()
return now.isoformat('T') + 'Z'
"""Return datetime in RFC3339 UTC Format."""
iso_formatted_now = datetime.datetime.now(
datetime.timezone.utc,
).isoformat('T')
return f'{iso_formatted_now!s}Z'
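Note that isoformat() on a timezone-aware datetime already includes the '+00:00' offset, so the f-string above produces a trailing '+00:00Z'. A variant that yields a plain 'Z' suffix (a sketch, not the upstream code):

    import datetime

    now = datetime.datetime.now(datetime.timezone.utc)
    stamp = now.isoformat('T').replace('+00:00', 'Z')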

View File

@@ -1,4 +1,4 @@
"""Native adapter for serving CherryPy via mod_python
"""Native adapter for serving CherryPy via mod_python.
Basic usage:

View File

@@ -120,10 +120,10 @@ class NativeGateway(cheroot.server.Gateway):
class CPHTTPServer(cheroot.server.HTTPServer):
"""Wrapper for cheroot.server.HTTPServer.
cheroot has been designed to not reference CherryPy in any way,
so that it can be used in other frameworks and applications.
Therefore, we wrap it here, so we can apply some attributes
from config -> cherrypy.server -> HTTPServer.
cheroot has been designed to not reference CherryPy in any way, so
that it can be used in other frameworks and applications. Therefore,
we wrap it here, so we can apply some attributes from config ->
cherrypy.server -> HTTPServer.
"""
def __init__(self, server_adapter=cherrypy.server):

View File

@@ -248,7 +248,10 @@ def process_multipart_form_data(entity):
def _old_process_multipart(entity):
"""The behavior of 3.2 and lower. Deprecated and will be changed in 3.3."""
"""The behavior of 3.2 and lower.
Deprecated and will be changed in 3.3.
"""
process_multipart(entity)
params = entity.params
@@ -277,7 +280,6 @@ def _old_process_multipart(entity):
# -------------------------------- Entities --------------------------------- #
class Entity(object):
"""An HTTP request body, or MIME multipart body.
This class collects information about the HTTP request entity. When a
@@ -346,13 +348,15 @@ class Entity(object):
content_type = None
"""The value of the Content-Type request header.
If the Entity is part of a multipart payload, this will be the Content-Type
given in the MIME headers for this part.
If the Entity is part of a multipart payload, this will be the
Content-Type given in the MIME headers for this part.
"""
default_content_type = 'application/x-www-form-urlencoded'
"""This defines a default ``Content-Type`` to use if no Content-Type header
is given. The empty string is used for RequestBody, which results in the
is given.
The empty string is used for RequestBody, which results in the
request body not being read or parsed at all. This is by design; a missing
``Content-Type`` header in the HTTP request entity is an error at best,
and a security hole at worst. For multipart parts, however, the MIME spec
@@ -402,8 +406,8 @@ class Entity(object):
part_class = None
"""The class used for multipart parts.
You can replace this with custom subclasses to alter the processing of
multipart parts.
You can replace this with custom subclasses to alter the processing
of multipart parts.
"""
def __init__(self, fp, headers, params=None, parts=None):
@@ -509,7 +513,8 @@ class Entity(object):
"""Return a file-like object into which the request body will be read.
By default, this will return a TemporaryFile. Override as needed.
See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`."""
See also :attr:`cherrypy._cpreqbody.Part.maxrambytes`.
"""
return tempfile.TemporaryFile()
def fullvalue(self):
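make_file is what lets large bodies spool to disk instead of memory. A handler consuming a multipart Part built on top of it might look like this sketch:

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def upload(self, myfile):
            # myfile is a Part; its payload lives in the file-like
            # object returned by make_file()
            size = 0
            while True:
                chunk = myfile.file.read(8192)
                if not chunk:
                    break
                size += len(chunk)
            return 'received %d bytes' % size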
@@ -525,7 +530,7 @@ class Entity(object):
return value
def decode_entity(self, value):
"""Return a given byte encoded value as a string"""
"""Return a given byte encoded value as a string."""
for charset in self.attempt_charsets:
try:
value = value.decode(charset)
@@ -569,7 +574,6 @@ class Entity(object):
class Part(Entity):
"""A MIME part entity, part of a multipart entity."""
# "The default character set, which must be assumed in the absence of a
@@ -653,8 +657,8 @@ class Part(Entity):
def read_lines_to_boundary(self, fp_out=None):
"""Read bytes from self.fp and return or write them to a file.
If the 'fp_out' argument is None (the default), all bytes read are
returned in a single byte string.
If the 'fp_out' argument is None (the default), all bytes read
are returned in a single byte string.
If the 'fp_out' argument is not None, it must be a file-like
object that supports the 'write' method; all bytes read will be
@@ -755,15 +759,15 @@ class SizedReader:
def read(self, size=None, fp_out=None):
"""Read bytes from the request body and return or write them to a file.
A number of bytes less than or equal to the 'size' argument are read
off the socket. The actual number of bytes read are tracked in
self.bytes_read. The number may be smaller than 'size' when 1) the
client sends fewer bytes, 2) the 'Content-Length' request header
specifies fewer bytes than requested, or 3) the number of bytes read
exceeds self.maxbytes (in which case, 413 is raised).
A number of bytes less than or equal to the 'size' argument are
read off the socket. The actual number of bytes read are tracked
in self.bytes_read. The number may be smaller than 'size' when
1) the client sends fewer bytes, 2) the 'Content-Length' request
header specifies fewer bytes than requested, or 3) the number of
bytes read exceeds self.maxbytes (in which case, 413 is raised).
If the 'fp_out' argument is None (the default), all bytes read are
returned in a single byte string.
If the 'fp_out' argument is None (the default), all bytes read
are returned in a single byte string.
If the 'fp_out' argument is not None, it must be a file-like
object that supports the 'write' method; all bytes read will be
@@ -918,7 +922,6 @@ class SizedReader:
class RequestBody(Entity):
"""The entity of the HTTP request."""
bufsize = 8 * 1024

View File

@@ -16,7 +16,6 @@ from cherrypy.lib import httputil, reprconf, encoding
class Hook(object):
"""A callback and its metadata: failsafe, priority, and kwargs."""
callback = None
@@ -30,10 +29,12 @@ class Hook(object):
from the same call point raise exceptions."""
priority = 50
"""Defines the order of execution for a list of Hooks.
Priority numbers should be limited to the closed interval [0, 100],
but values outside this range are acceptable, as are fractional
values.
"""
Defines the order of execution for a list of Hooks. Priority numbers
should be limited to the closed interval [0, 100], but values outside
this range are acceptable, as are fractional values."""
kwargs = {}
"""
@@ -74,7 +75,6 @@ class Hook(object):
class HookMap(dict):
"""A map of call points to lists of callbacks (Hook objects)."""
def __new__(cls, points=None):
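Priorities matter when several callbacks share a hook point. A sketch of registering a custom Tool that hooks 'before_finalize' with a non-default priority:

    import cherrypy

    def add_header():
        cherrypy.serving.response.headers['X-Example'] = 'yes'

    # runs after the default priority-50 hooks on the same point;
    # enable per-app with 'tools.add_header.on': True
    cherrypy.tools.add_header = cherrypy.Tool(
        'before_finalize', add_header, priority=60)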
@@ -190,23 +190,23 @@ hookpoints = ['on_start_resource', 'before_request_body',
class Request(object):
"""An HTTP request.
This object represents the metadata of an HTTP request message;
that is, it contains attributes which describe the environment
in which the request URL, headers, and body were sent (if you
want tools to interpret the headers and body, those are elsewhere,
mostly in Tools). This 'metadata' consists of socket data,
transport characteristics, and the Request-Line. This object
also contains data regarding the configuration in effect for
the given URL, and the execution plan for generating a response.
This object represents the metadata of an HTTP request message; that
is, it contains attributes which describe the environment in which
the request URL, headers, and body were sent (if you want tools to
interpret the headers and body, those are elsewhere, mostly in
Tools). This 'metadata' consists of socket data, transport
characteristics, and the Request-Line. This object also contains
data regarding the configuration in effect for the given URL, and
the execution plan for generating a response.
"""
prev = None
"""The previous Request object (if any).
This should be None unless we are processing an InternalRedirect.
"""
The previous Request object (if any). This should be None
unless we are processing an InternalRedirect."""
# Conversation/connection attributes
local = httputil.Host('127.0.0.1', 80)
@@ -216,9 +216,10 @@ class Request(object):
'An httputil.Host(ip, port, hostname) object for the client socket.'
scheme = 'http'
"""The protocol used between client and server.
In most cases, this will be either 'http' or 'https'.
"""
The protocol used between client and server. In most cases,
this will be either 'http' or 'https'."""
server_protocol = 'HTTP/1.1'
"""
@@ -227,25 +228,30 @@ class Request(object):
base = ''
"""The (scheme://host) portion of the requested URL.
In some cases (e.g. when proxying via mod_rewrite), this may contain
path segments which cherrypy.url uses when constructing url's, but
which otherwise are ignored by CherryPy. Regardless, this value
MUST NOT end in a slash."""
which otherwise are ignored by CherryPy. Regardless, this value MUST
NOT end in a slash.
"""
# Request-Line attributes
request_line = ''
"""The complete Request-Line received from the client.
This is a single string consisting of the request method, URI, and
protocol version (joined by spaces). Any final CRLF is removed.
"""
The complete Request-Line received from the client. This is a
single string consisting of the request method, URI, and protocol
version (joined by spaces). Any final CRLF is removed."""
method = 'GET'
"""Indicates the HTTP method to be performed on the resource identified by
the Request-URI.
Common methods include GET, HEAD, POST, PUT, and DELETE. CherryPy
allows any extension method; however, various HTTP servers and
gateways may restrict the set of allowable methods. CherryPy
applications SHOULD restrict the set (on a per-URI basis).
"""
Indicates the HTTP method to be performed on the resource identified
by the Request-URI. Common methods include GET, HEAD, POST, PUT, and
DELETE. CherryPy allows any extension method; however, various HTTP
servers and gateways may restrict the set of allowable methods.
CherryPy applications SHOULD restrict the set (on a per-URI basis)."""
query_string = ''
"""
@@ -277,22 +283,26 @@ class Request(object):
A dict which combines query string (GET) and request entity (POST)
variables. This is populated in two stages: GET params are added
before the 'on_start_resource' hook, and POST params are added
between the 'before_request_body' and 'before_handler' hooks."""
between the 'before_request_body' and 'before_handler' hooks.
"""
# Message attributes
header_list = []
"""A list of the HTTP request headers as (name, value) tuples.
In general, you should use request.headers (a dict) instead.
"""
A list of the HTTP request headers as (name, value) tuples.
In general, you should use request.headers (a dict) instead."""
headers = httputil.HeaderMap()
"""
A dict-like object containing the request headers. Keys are header
"""A dict-like object containing the request headers.
Keys are header
names (in Title-Case format); however, you may get and set them in
a case-insensitive manner. That is, headers['Content-Type'] and
headers['content-type'] refer to the same value. Values are header
values (decoded according to :rfc:`2047` if necessary). See also:
httputil.HeaderMap, httputil.HeaderElement."""
httputil.HeaderMap, httputil.HeaderElement.
"""
cookie = SimpleCookie()
"""See help(Cookie)."""
@@ -336,7 +346,8 @@ class Request(object):
or multipart, this will be None. Otherwise, this will be an instance
of :class:`RequestBody<cherrypy._cpreqbody.RequestBody>` (which you
can .read()); this value is set between the 'before_request_body' and
'before_handler' hooks (assuming that process_request_body is True)."""
'before_handler' hooks (assuming that process_request_body is True).
"""
# Dispatch attributes
dispatch = cherrypy.dispatch.Dispatcher()
@@ -347,23 +358,24 @@ class Request(object):
calls the dispatcher as early as possible, passing it a 'path_info'
argument.
The default dispatcher discovers the page handler by matching path_info
to a hierarchical arrangement of objects, starting at request.app.root.
See help(cherrypy.dispatch) for more information."""
The default dispatcher discovers the page handler by matching
path_info to a hierarchical arrangement of objects, starting at
request.app.root. See help(cherrypy.dispatch) for more information.
"""
script_name = ''
"""
The 'mount point' of the application which is handling this request.
"""The 'mount point' of the application which is handling this request.
This attribute MUST NOT end in a slash. If the script_name refers to
the root of the URI, it MUST be an empty string (not "/").
"""
path_info = '/'
"""The 'relative path' portion of the Request-URI.
This is relative to the script_name ('mount point') of the
application which is handling this request.
"""
The 'relative path' portion of the Request-URI. This is relative
to the script_name ('mount point') of the application which is
handling this request."""
login = None
"""
@@ -391,14 +403,16 @@ class Request(object):
of the form: {Toolbox.namespace: {Tool.name: config dict}}."""
config = None
"""A flat dict of all configuration entries which apply to the current
request.
These entries are collected from global config, application config
(based on request.path_info), and from handler config (exactly how
is governed by the request.dispatch object in effect for this
request; by default, handler config can be attached anywhere in the
tree between request.app.root and the final handler, and inherits
downward).
"""
A flat dict of all configuration entries which apply to the
current request. These entries are collected from global config,
application config (based on request.path_info), and from handler
config (exactly how is governed by the request.dispatch object in
effect for this request; by default, handler config can be attached
anywhere in the tree between request.app.root and the final handler,
and inherits downward)."""
is_index = None
"""
@@ -409,13 +423,14 @@ class Request(object):
the trailing slash. See cherrypy.tools.trailing_slash."""
hooks = HookMap(hookpoints)
"""
A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
"""A HookMap (dict-like object) of the form: {hookpoint: [hook, ...]}.
Each key is a str naming the hook point, and each value is a list
of hooks which will be called at that hook point during this request.
The list of hooks is generally populated as early as possible (mostly
from Tools specified in config), but may be extended at any time.
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools."""
See also: _cprequest.Hook, _cprequest.HookMap, and cherrypy.tools.
"""
error_response = cherrypy.HTTPError(500).set_response
"""
@@ -428,12 +443,11 @@ class Request(object):
error response to the user-agent."""
error_page = {}
"""
A dict of {error code: response filename or callable} pairs.
"""A dict of {error code: response filename or callable} pairs.
The error code must be an int representing a given HTTP error code,
or the string 'default', which will be used if no matching entry
is found for a given numeric code.
or the string 'default', which will be used if no matching entry is
found for a given numeric code.
If a filename is provided, the file should contain a Python string-
formatting template, and can expect by default to receive format
@@ -447,8 +461,8 @@ class Request(object):
iterable of strings which will be set to response.body. It may also
override headers or perform any other processing.
If no entry is given for an error code, and no 'default' entry exists,
a default template will be used.
If no entry is given for an error code, and no 'default' entry
exists, a default template will be used.
"""
show_tracebacks = True
@@ -473,9 +487,10 @@ class Request(object):
"""True once the close method has been called, False otherwise."""
stage = None
"""A string containing the stage reached in the request-handling process.
This is useful when debugging a live server with hung requests.
"""
A string containing the stage reached in the request-handling process.
This is useful when debugging a live server with hung requests."""
unique_id = None
"""A lazy object generating and memorizing UUID4 on ``str()`` render."""
@@ -492,9 +507,10 @@ class Request(object):
server_protocol='HTTP/1.1'):
"""Populate a new Request object.
local_host should be an httputil.Host object with the server info.
remote_host should be an httputil.Host object with the client info.
scheme should be a string, either "http" or "https".
local_host should be an httputil.Host object with the server
info. remote_host should be an httputil.Host object with the
client info. scheme should be a string, either "http" or
"https".
"""
self.local = local_host
self.remote = remote_host
@@ -514,7 +530,10 @@ class Request(object):
self.unique_id = LazyUUID4()
def close(self):
"""Run cleanup code. (Core)"""
"""Run cleanup code.
(Core)
"""
if not self.closed:
self.closed = True
self.stage = 'on_end_request'
@@ -551,7 +570,6 @@ class Request(object):
Consumer code (HTTP servers) should then access these response
attributes to build the outbound stream.
"""
response = cherrypy.serving.response
self.stage = 'run'
@@ -631,7 +649,10 @@ class Request(object):
return response
def respond(self, path_info):
"""Generate a response for the resource at self.path_info. (Core)"""
"""Generate a response for the resource at self.path_info.
(Core)
"""
try:
try:
try:
@@ -702,7 +723,10 @@ class Request(object):
response.finalize()
def process_query_string(self):
"""Parse the query string into Python structures. (Core)"""
"""Parse the query string into Python structures.
(Core)
"""
try:
p = httputil.parse_query_string(
self.query_string, encoding=self.query_string_encoding)
@@ -715,7 +739,10 @@ class Request(object):
self.params.update(p)
def process_headers(self):
"""Parse HTTP header data into Python structures. (Core)"""
"""Parse HTTP header data into Python structures.
(Core)
"""
# Process the headers into self.headers
headers = self.headers
for name, value in self.header_list:
@@ -751,7 +778,10 @@ class Request(object):
self.base = '%s://%s' % (self.scheme, host)
def get_resource(self, path):
"""Call a dispatcher (which sets self.handler and .config). (Core)"""
"""Call a dispatcher (which sets self.handler and .config).
(Core)
"""
# First, see if there is a custom dispatch at this URI. Custom
# dispatchers can only be specified in app.config, not in _cp_config
# (since custom dispatchers may not even have an app.root).
@@ -762,7 +792,10 @@ class Request(object):
dispatch(path)
def handle_error(self):
"""Handle the last unanticipated exception. (Core)"""
"""Handle the last unanticipated exception.
(Core)
"""
try:
self.hooks.run('before_error_response')
if self.error_response:
@@ -776,7 +809,6 @@ class Request(object):
class ResponseBody(object):
"""The body of the HTTP response (the response entity)."""
unicode_err = ('Page handlers MUST return bytes. Use tools.encode '
@@ -802,18 +834,18 @@ class ResponseBody(object):
class Response(object):
"""An HTTP Response, including status, headers, and body."""
status = ''
"""The HTTP Status-Code and Reason-Phrase."""
header_list = []
"""
A list of the HTTP response headers as (name, value) tuples.
"""A list of the HTTP response headers as (name, value) tuples.
In general, you should use response.headers (a dict) instead. This
attribute is generated from response.headers and is not valid until
after the finalize phase."""
after the finalize phase.
"""
headers = httputil.HeaderMap()
"""
@@ -833,7 +865,10 @@ class Response(object):
"""The body (entity) of the HTTP response."""
time = None
"""The value of time.time() when created. Use in HTTP dates."""
"""The value of time.time() when created.
Use in HTTP dates.
"""
stream = False
"""If False, buffer the response body."""
@@ -861,15 +896,15 @@ class Response(object):
return new_body
def _flush_body(self):
"""
Discard self.body but consume any generator such that
any finalization can occur, such as is required by
caching.tee_output().
"""
"""Discard self.body but consume any generator such that any
finalization can occur, such as is required by caching.tee_output()."""
consume(iter(self.body))
def finalize(self):
"""Transform headers (and cookies) into self.header_list. (Core)"""
"""Transform headers (and cookies) into self.header_list.
(Core)
"""
try:
code, reason, _ = httputil.valid_status(self.status)
except ValueError:

View File

@@ -50,7 +50,8 @@ class Server(ServerAdapter):
"""If given, the name of the UNIX socket to use instead of TCP/IP.
When this option is not None, the `socket_host` and `socket_port` options
are ignored."""
are ignored.
"""
socket_queue_size = 5
"""The 'backlog' argument to socket.listen(); specifies the maximum number
@@ -79,17 +80,24 @@ class Server(ServerAdapter):
"""The number of worker threads to start up in the pool."""
thread_pool_max = -1
"""The maximum size of the worker-thread pool. Use -1 to indicate no limit.
"""The maximum size of the worker-thread pool.
Use -1 to indicate no limit.
"""
max_request_header_size = 500 * 1024
"""The maximum number of bytes allowable in the request headers.
If exceeded, the HTTP server should return "413 Request Entity Too Large".
If exceeded, the HTTP server should return "413 Request Entity Too
Large".
"""
max_request_body_size = 100 * 1024 * 1024
"""The maximum number of bytes allowable in the request body. If exceeded,
the HTTP server should return "413 Request Entity Too Large"."""
"""The maximum number of bytes allowable in the request body.
If exceeded, the HTTP server should return "413 Request Entity Too
Large".
"""
instance = None
"""If not None, this should be an HTTP server instance (such as
@@ -119,7 +127,8 @@ class Server(ServerAdapter):
the builtin WSGI server. Builtin options are: 'builtin' (to
use the SSL library built into recent versions of Python).
You may also register your own classes in the
cheroot.server.ssl_adapters dict."""
cheroot.server.ssl_adapters dict.
"""
statistics = False
"""Turns statistics-gathering on or off for aware HTTP servers."""
@@ -129,11 +138,13 @@ class Server(ServerAdapter):
wsgi_version = (1, 0)
"""The WSGI version tuple to use with the builtin WSGI server.
The provided options are (1, 0) [which includes support for PEP 3333,
which declares it covers WSGI version 1.0.1 but still mandates the
wsgi.version (1, 0)] and ('u', 0), an experimental unicode version.
You may create and register your own experimental versions of the WSGI
protocol by adding custom classes to the cheroot.server.wsgi_gateways dict.
The provided options are (1, 0) [which includes support for PEP
3333, which declares it covers WSGI version 1.0.1 but still mandates
the wsgi.version (1, 0)] and ('u', 0), an experimental unicode
version. You may create and register your own experimental versions
of the WSGI protocol by adding custom classes to the
cheroot.server.wsgi_gateways dict.
"""
peercreds = False
@@ -184,7 +195,8 @@ class Server(ServerAdapter):
def bind_addr(self):
"""Return bind address.
A (host, port) tuple for TCP sockets or a str for Unix domain sockts.
A (host, port) tuple for TCP sockets or a str for Unix domain
sockets.
"""
if self.socket_file:
return self.socket_file
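These adapter attributes are usually set through the server config namespace. A sketch touching the values documented above:

    import cherrypy

    cherrypy.config.update({
        'server.thread_pool': 30,
        'server.socket_queue_size': 5,
        'server.max_request_body_size': 100 * 1024 * 1024,
    })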

View File

@@ -1,7 +1,7 @@
"""CherryPy tools. A "tool" is any helper, adapted to CP.
Tools are usually designed to be used in a variety of ways (although some
may only offer one if they choose):
Tools are usually designed to be used in a variety of ways (although
some may only offer one if they choose):
Library calls
All tools are callables that can be used wherever needed.
@@ -48,10 +48,10 @@ _attr_error = (
class Tool(object):
"""A registered function for use with CherryPy request-processing hooks.
help(tool.callable) should give you more information about this Tool.
help(tool.callable) should give you more information about this
Tool.
"""
namespace = 'tools'
@@ -135,8 +135,8 @@ class Tool(object):
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
The standard CherryPy request object will automatically call
this method when the tool is "turned on" in config.
"""
conf = self._merged_args()
p = conf.pop('priority', None)
@@ -147,15 +147,15 @@ class Tool(object):
class HandlerTool(Tool):
"""Tool which is called 'before main', that may skip normal handlers.
If the tool successfully handles the request (by setting response.body),
it should return True. This will cause CherryPy to skip any 'normal' page
handler. If the tool did not handle the request, it should return False
to tell CherryPy to continue on and call the normal page handler. If the
tool is declared AS a page handler (see the 'handler' method), returning
False will raise NotFound.
If the tool successfully handles the request (by setting
response.body), it should return True. This will cause CherryPy to
skip any 'normal' page handler. If the tool did not handle the
request, it should return False to tell CherryPy to continue on and
call the normal page handler. If the tool is declared AS a page
handler (see the 'handler' method), returning False will raise
NotFound.
"""
def __init__(self, callable, name=None):
@@ -185,8 +185,8 @@ class HandlerTool(Tool):
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
The standard CherryPy request object will automatically call
this method when the tool is "turned on" in config.
"""
conf = self._merged_args()
p = conf.pop('priority', None)
@@ -197,7 +197,6 @@ class HandlerTool(Tool):
class HandlerWrapperTool(Tool):
"""Tool which wraps request.handler in a provided wrapper function.
The 'newhandler' arg must be a handler wrapper function that takes a
@@ -232,7 +231,6 @@ class HandlerWrapperTool(Tool):
class ErrorTool(Tool):
"""Tool which is used to replace the default request.error_response."""
def __init__(self, callable, name=None):
@@ -244,8 +242,8 @@ class ErrorTool(Tool):
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
The standard CherryPy request object will automatically call
this method when the tool is "turned on" in config.
"""
cherrypy.serving.request.error_response = self._wrapper
@@ -254,7 +252,6 @@ class ErrorTool(Tool):
class SessionTool(Tool):
"""Session Tool for CherryPy.
sessions.locking
@@ -282,8 +279,8 @@ class SessionTool(Tool):
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
The standard CherryPy request object will automatically call
this method when the tool is "turned on" in config.
"""
hooks = cherrypy.serving.request.hooks
@@ -325,7 +322,6 @@ class SessionTool(Tool):
class XMLRPCController(object):
"""A Controller (page handler collection) for XML-RPC.
To use it, have your controllers subclass this base class (it will
@@ -392,7 +388,6 @@ class SessionAuthTool(HandlerTool):
class CachingTool(Tool):
"""Caching Tool for CherryPy."""
def _wrapper(self, **kwargs):
@@ -416,11 +411,11 @@ class CachingTool(Tool):
class Toolbox(object):
"""A collection of Tools.
This object also functions as a config namespace handler for itself.
Custom toolboxes should be added to each Application's toolboxes dict.
Custom toolboxes should be added to each Application's toolboxes
dict.
"""
def __init__(self, namespace):

View File

@@ -10,19 +10,22 @@ from cherrypy.lib import httputil, reprconf
class Application(object):
"""A CherryPy Application.
Servers and gateways should not instantiate Request objects directly.
Instead, they should ask an Application object for a request object.
Servers and gateways should not instantiate Request objects
directly. Instead, they should ask an Application object for a
request object.
An instance of this class may also be used as a WSGI callable
(WSGI application object) for itself.
An instance of this class may also be used as a WSGI callable (WSGI
application object) for itself.
"""
root = None
"""The top-most container of page handlers for this app. Handlers should
be arranged in a hierarchy of attributes, matching the expected URI
hierarchy; the default dispatcher then searches this hierarchy for a
matching handler. When using a dispatcher other than the default,
this value may be None."""
"""The top-most container of page handlers for this app.
Handlers should be arranged in a hierarchy of attributes, matching
the expected URI hierarchy; the default dispatcher then searches
this hierarchy for a matching handler. When using a dispatcher other
than the default, this value may be None.
"""
config = {}
"""A dict of {path: pathconf} pairs, where 'pathconf' is itself a dict
@@ -32,10 +35,16 @@ class Application(object):
toolboxes = {'tools': cherrypy.tools}
log = None
"""A LogManager instance. See _cplogging."""
"""A LogManager instance.
See _cplogging.
"""
wsgiapp = None
"""A CPWSGIApp instance. See _cpwsgi."""
"""A CPWSGIApp instance.
See _cpwsgi.
"""
request_class = _cprequest.Request
response_class = _cprequest.Response
@@ -82,12 +91,15 @@ class Application(object):
def script_name(self): # noqa: D401; irrelevant for properties
"""The URI "mount point" for this app.
A mount point is that portion of the URI which is constant for all URIs
that are serviced by this application; it does not include scheme,
host, or proxy ("virtual host") portions of the URI.
A mount point is that portion of the URI which is constant for
all URIs that are serviced by this application; it does not
include scheme, host, or proxy ("virtual host") portions of the
URI.
For example, if script_name is "/my/cool/app", then the URL
"http://www.example.com/my/cool/app/page1" might be handled by a
For example, if script_name is "/my/cool/app", then the URL
"http://www.example.com/my/cool/app/page1" might be handled by a
"page1" method on the root object.
The value of script_name MUST NOT end in a slash. If the script_name
@@ -171,9 +183,9 @@ class Application(object):
class Tree(object):
"""A registry of CherryPy applications, mounted at diverse points.
An instance of this class may also be used as a WSGI callable
(WSGI application object), in which case it dispatches to all
mounted apps.
An instance of this class may also be used as a WSGI callable (WSGI
application object), in which case it dispatches to all mounted
apps.
"""
apps = {}
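Mounting gives each app its script_name; a sketch matching the example in the docstring above:

    import cherrypy

    class CoolApp(object):
        @cherrypy.expose
        def page1(self):
            return 'page1'

    # the mount point must not end in a slash
    cherrypy.tree.mount(CoolApp(), '/my/cool/app')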

View File

@@ -1,10 +1,10 @@
"""WSGI interface (see PEP 333 and 3333).
Note that WSGI environ keys and values are 'native strings'; that is,
whatever the type of "" is. For Python 2, that's a byte string; for Python 3,
it's a unicode string. But PEP 3333 says: "even if Python's str type is
actually Unicode "under the hood", the content of native strings must
still be translatable to bytes via the Latin-1 encoding!"
whatever the type of "" is. For Python 2, that's a byte string; for
Python 3, it's a unicode string. But PEP 3333 says: "even if Python's
str type is actually Unicode "under the hood", the content of native
strings must still be translatable to bytes via the Latin-1 encoding!"
"""
import sys as _sys
@@ -34,7 +34,6 @@ def downgrade_wsgi_ux_to_1x(environ):
class VirtualHost(object):
"""Select a different WSGI application based on the Host header.
This can be useful when running multiple sites within one CP server.
@@ -56,7 +55,10 @@ class VirtualHost(object):
cherrypy.tree.graft(vhost)
"""
default = None
"""Required. The default WSGI application."""
"""Required.
The default WSGI application.
"""
use_x_forwarded_host = True
"""If True (the default), any "X-Forwarded-Host"
@@ -65,11 +67,12 @@ class VirtualHost(object):
domains = {}
"""A dict of {host header value: application} pairs.
The incoming "Host" request header is looked up in this dict,
and, if a match is found, the corresponding WSGI application
will be called instead of the default. Note that you often need
separate entries for "example.com" and "www.example.com".
In addition, "Host" headers may contain the port number.
The incoming "Host" request header is looked up in this dict, and,
if a match is found, the corresponding WSGI application will be
called instead of the default. Note that you often need separate
entries for "example.com" and "www.example.com". In addition, "Host"
headers may contain the port number.
"""
def __init__(self, default, domains=None, use_x_forwarded_host=True):
@@ -89,7 +92,6 @@ class VirtualHost(object):
class InternalRedirector(object):
"""WSGI middleware that handles raised cherrypy.InternalRedirect."""
def __init__(self, nextapp, recursive=False):
@@ -137,7 +139,6 @@ class InternalRedirector(object):
class ExceptionTrapper(object):
"""WSGI middleware that traps exceptions."""
def __init__(self, nextapp, throws=(KeyboardInterrupt, SystemExit)):
@@ -226,7 +227,6 @@ class _TrappedResponse(object):
class AppResponse(object):
"""WSGI response iterable for CherryPy applications."""
def __init__(self, environ, start_response, cpapp):
@@ -277,7 +277,10 @@ class AppResponse(object):
return next(self.iter_response)
def close(self):
"""Close and de-reference the current request and response. (Core)"""
"""Close and de-reference the current request and response.
(Core)
"""
streaming = _cherrypy.serving.response.stream
self.cpapp.release_serving()
@@ -380,18 +383,20 @@ class AppResponse(object):
class CPWSGIApp(object):
"""A WSGI application object for a CherryPy Application."""
pipeline = [
('ExceptionTrapper', ExceptionTrapper),
('InternalRedirector', InternalRedirector),
]
"""A list of (name, wsgiapp) pairs. Each 'wsgiapp' MUST be a
constructor that takes an initial, positional 'nextapp' argument,
plus optional keyword arguments, and returns a WSGI application
(that takes environ and start_response arguments). The 'name' can
be any you choose, and will correspond to keys in self.config."""
"""A list of (name, wsgiapp) pairs.
Each 'wsgiapp' MUST be a constructor that takes an initial,
positional 'nextapp' argument, plus optional keyword arguments, and
returns a WSGI application (that takes environ and start_response
arguments). The 'name' can be any you choose, and will correspond to
keys in self.config.
"""
head = None
"""Rather than nest all apps in the pipeline on each call, it's only
@@ -399,9 +404,12 @@ class CPWSGIApp(object):
this to None again if you change self.pipeline after calling self."""
config = {}
"""A dict whose keys match names listed in the pipeline. Each
value is a further dict which will be passed to the corresponding
named WSGI callable (from the pipeline) as keyword arguments."""
"""A dict whose keys match names listed in the pipeline.
Each value is a further dict which will be passed to the
corresponding named WSGI callable (from the pipeline) as keyword
arguments.
"""
response_class = AppResponse
"""The class to instantiate and return as the next app in the WSGI chain.
@@ -417,8 +425,8 @@ class CPWSGIApp(object):
def tail(self, environ, start_response):
"""WSGI application callable for the actual CherryPy application.
You probably shouldn't call this; call self.__call__ instead,
so that any WSGI middleware in self.pipeline can run first.
You probably shouldn't call this; call self.__call__ instead, so
that any WSGI middleware in self.pipeline can run first.
"""
return self.response_class(environ, start_response, self.cpapp)

View File

@@ -1,7 +1,7 @@
"""
WSGI server interface (see PEP 333).
"""WSGI server interface (see PEP 333).
This adds some CP-specific bits to the framework-agnostic cheroot package.
This adds some CP-specific bits to the framework-agnostic cheroot
package.
"""
import sys
@@ -35,10 +35,11 @@ class CPWSGIHTTPRequest(cheroot.server.HTTPRequest):
class CPWSGIServer(cheroot.wsgi.Server):
"""Wrapper for cheroot.wsgi.Server.
cheroot has been designed to not reference CherryPy in any way,
so that it can be used in other frameworks and applications. Therefore,
we wrap it here, so we can set our own mount points from cherrypy.tree
and apply some attributes from config -> cherrypy.server -> wsgi.Server.
cheroot has been designed to not reference CherryPy in any way, so
that it can be used in other frameworks and applications. Therefore,
we wrap it here, so we can set our own mount points from
cherrypy.tree and apply some attributes from config ->
cherrypy.server -> wsgi.Server.
"""
fmt = 'CherryPy/{cherrypy.__version__} {cheroot.wsgi.Server.version}'

View File

@@ -137,7 +137,6 @@ def popargs(*args, **kwargs):
class Root:
def index(self):
#...
"""
# Since keyword arg comes after *args, we have to process it ourselves
# for lower versions of python.
@@ -201,16 +200,17 @@ def url(path='', qs='', script_name=None, base=None, relative=None):
If it does not start with a slash, this returns
(base + script_name [+ request.path_info] + path + qs).
If script_name is None, cherrypy.request will be used
to find a script_name, if available.
If script_name is None, cherrypy.request will be used to find a
script_name, if available.
If base is None, cherrypy.request.base will be used (if available).
Note that you can use cherrypy.tools.proxy to change this.
Finally, note that this function can be used to obtain an absolute URL
for the current request path (minus the querystring) by passing no args.
If you call url(qs=cherrypy.request.query_string), you should get the
original browser URL (assuming no internal redirections).
Finally, note that this function can be used to obtain an absolute
URL for the current request path (minus the querystring) by passing
no args. If you call url(qs=cherrypy.request.query_string), you
should get the original browser URL (assuming no internal
redirections).
If relative is None or not provided, request.app.relative_urls will
be used (if available, else False). If False, the output will be an
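The query-string round trip mentioned above, as a sketch:

    import cherrypy

    class Root(object):
        @cherrypy.expose
        def index(self):
            # reconstructs the original browser URL, assuming no
            # internal redirections
            return cherrypy.url(qs=cherrypy.request.query_string)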
@@ -320,8 +320,8 @@ def normalize_path(path):
class _ClassPropertyDescriptor(object):
"""Descript for read-only class-based property.
Turns a classmethod-decorated func into a read-only property of that class
type (means the value cannot be set).
Turns a classmethod-decorated func into a read-only property of that
class type (means the value cannot be set).
"""
def __init__(self, fget, fset=None):

View File

@@ -1,5 +1,4 @@
"""
JSON support.
"""JSON support.
Expose preferred json module as json and provide encode/decode
convenience functions.
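The usual entry points are the json_in and json_out tools built on these functions. A sketch:

    import cherrypy

    class API(object):
        @cherrypy.expose
        @cherrypy.tools.json_out()
        def status(self):
            return {'ok': True}   # serialized to JSON for the client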

View File

@@ -6,8 +6,8 @@ def is_iterator(obj):
(i.e. like a generator).
This will return False for objects which are iterable,
but not iterators themselves.
This will return False for objects which are iterable, but not
iterators themselves.
"""
from types import GeneratorType
if isinstance(obj, GeneratorType):

View File

@@ -18,7 +18,6 @@ as the credentials store::
'tools.auth_basic.accept_charset': 'UTF-8',
}
app_config = { '/' : basic_auth }
"""
import binascii

View File

@@ -55,7 +55,7 @@ def TRACE(msg):
def get_ha1_dict_plain(user_password_dict):
"""Returns a get_ha1 function which obtains a plaintext password from a
"""Return a get_ha1 function which obtains a plaintext password from a
dictionary of the form: {username : password}.
If you want a simple dictionary-based authentication scheme, with plaintext
@@ -72,7 +72,7 @@ def get_ha1_dict_plain(user_password_dict):
def get_ha1_dict(user_ha1_dict):
"""Returns a get_ha1 function which obtains a HA1 password hash from a
"""Return a get_ha1 function which obtains a HA1 password hash from a
dictionary of the form: {username : HA1}.
If you want a dictionary-based authentication scheme, but with
@@ -87,7 +87,7 @@ def get_ha1_dict(user_ha1_dict):
def get_ha1_file_htdigest(filename):
"""Returns a get_ha1 function which obtains a HA1 password hash from a
"""Return a get_ha1 function which obtains a HA1 password hash from a
flat file with lines of the same format as that produced by the Apache
htdigest utility. For example, for realm 'wonderland', username 'alice',
and password '4x5istwelve', the htdigest line would be::
@@ -135,7 +135,7 @@ def synthesize_nonce(s, key, timestamp=None):
def H(s):
"""The hash function H"""
"""The hash function H."""
return md5_hex(s)
@@ -259,10 +259,11 @@ class HttpDigestAuthorization(object):
return False
def is_nonce_stale(self, max_age_seconds=600):
"""Returns True if a validated nonce is stale. The nonce contains a
timestamp in plaintext and also a secure hash of the timestamp.
You should first validate the nonce to ensure the plaintext
timestamp is not spoofed.
"""Return True if a validated nonce is stale.
The nonce contains a timestamp in plaintext and also a secure
hash of the timestamp. You should first validate the nonce to
ensure the plaintext timestamp is not spoofed.
"""
try:
timestamp, hashpart = self.nonce.split(':', 1)
@@ -275,7 +276,10 @@ class HttpDigestAuthorization(object):
return True
def HA2(self, entity_body=''):
"""Returns the H(A2) string. See :rfc:`2617` section 3.2.2.3."""
"""Return the H(A2) string.
See :rfc:`2617` section 3.2.2.3.
"""
# RFC 2617 3.2.2.3
# If the "qop" directive's value is "auth" or is unspecified,
# then A2 is:
@@ -306,7 +310,6 @@ class HttpDigestAuthorization(object):
4.3. This refers to the entity the user agent sent in the
request which has the Authorization header. Typically GET
requests don't have an entity, and POST requests do.
"""
ha2 = self.HA2(entity_body)
# Request-Digest -- RFC 2617 3.2.2.1
@@ -395,7 +398,6 @@ def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'):
key
A secret string known only to the server, used in the synthesis
of nonces.
"""
request = cherrypy.serving.request
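Wiring the tool up, reusing the realm and credentials from the get_ha1_file_htdigest docstring above (the key is a hypothetical secret):

    from cherrypy.lib import auth_digest

    conf = {'/protected': {
        'tools.auth_digest.on': True,
        'tools.auth_digest.realm': 'wonderland',
        'tools.auth_digest.get_ha1': auth_digest.get_ha1_dict_plain(
            {'alice': '4x5istwelve'}),
        'tools.auth_digest.key': 'a565c27146791cfb',
    }}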
@@ -447,9 +449,7 @@ def digest_auth(realm, get_ha1, key, debug=False, accept_charset='utf-8'):
def _respond_401(realm, key, accept_charset, debug, **kwargs):
"""
Respond with 401 status and a WWW-Authenticate header
"""
"""Respond with 401 status and a WWW-Authenticate header."""
header = www_authenticate(
realm, key,
accept_charset=accept_charset,

View File

@@ -42,7 +42,6 @@ from cherrypy.lib import cptools, httputil
class Cache(object):
"""Base class for Cache implementations."""
def get(self):
@@ -64,17 +63,16 @@ class Cache(object):
# ------------------------------ Memory Cache ------------------------------- #
class AntiStampedeCache(dict):
"""A storage system for cached items which reduces stampede collisions."""
def wait(self, key, timeout=5, debug=False):
"""Return the cached value for the given key, or None.
If timeout is not None, and the value is already
being calculated by another thread, wait until the given timeout has
elapsed. If the value is available before the timeout expires, it is
returned. If not, None is returned, and a sentinel placed in the cache
to signal other threads to wait.
If timeout is not None, and the value is already being
calculated by another thread, wait until the given timeout has
elapsed. If the value is available before the timeout expires,
it is returned. If not, None is returned, and a sentinel placed
in the cache to signal other threads to wait.
If timeout is None, no waiting is performed nor sentinels used.
"""
@@ -127,7 +125,6 @@ class AntiStampedeCache(dict):
class MemoryCache(Cache):
"""An in-memory cache for varying response content.
Each key in self.store is a URI, and each value is an AntiStampedeCache.
@@ -381,7 +378,10 @@ def get(invalid_methods=('POST', 'PUT', 'DELETE'), debug=False, **kwargs):
def tee_output():
"""Tee response output to cache storage. Internal."""
"""Tee response output to cache storage.
Internal.
"""
# Used by CachingTool by attaching to request.hooks
request = cherrypy.serving.request
@@ -441,7 +441,6 @@ def expires(secs=0, force=False, debug=False):
* Expires
If any are already present, none of the above response headers are set.
"""
response = cherrypy.serving.response
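Both tools are normally enabled from config. A sketch:

    conf = {'/': {
        'tools.caching.on': True,
        'tools.caching.delay': 3600,   # seconds before re-rendering
        'tools.expires.on': True,
        'tools.expires.secs': 3600,    # emits the headers listed above
    }}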

View File

@@ -184,7 +184,6 @@ To report statistics::
To format statistics reports::
See 'Reporting', above.
"""
import logging
@@ -254,7 +253,6 @@ def proc_time(s):
class ByteCountWrapper(object):
"""Wraps a file-like object, counting the number of bytes read."""
def __init__(self, rfile):
@@ -307,7 +305,6 @@ def _get_threading_ident():
class StatsTool(cherrypy.Tool):
"""Record various information about the current request."""
def __init__(self):
@@ -316,8 +313,8 @@ class StatsTool(cherrypy.Tool):
def _setup(self):
"""Hook this tool into cherrypy.request.
The standard CherryPy request object will automatically call this
method when the tool is "turned on" in config.
The standard CherryPy request object will automatically call
this method when the tool is "turned on" in config.
"""
if appstats.get('Enabled', False):
cherrypy.Tool._setup(self)

View File

@@ -94,8 +94,8 @@ def validate_etags(autotags=False, debug=False):
def validate_since():
"""Validate the current Last-Modified against If-Modified-Since headers.
If no code has set the Last-Modified response header, then no validation
will be performed.
If no code has set the Last-Modified response header, then no
validation will be performed.
"""
response = cherrypy.serving.response
lastmod = response.headers.get('Last-Modified')
@@ -123,9 +123,9 @@ def validate_since():
def allow(methods=None, debug=False):
"""Raise 405 if request.method not in methods (default ['GET', 'HEAD']).
The given methods are case-insensitive, and may be in any order.
If only one method is allowed, you may supply a single string;
if more than one, supply a list of strings.
The given methods are case-insensitive, and may be in any order. If
only one method is allowed, you may supply a single string; if more
than one, supply a list of strings.
Regardless of whether the current method is allowed or not, this
also emits an 'Allow' response header, containing the given methods.
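A hedged usage sketch (class and handler names are illustrative; the tool is exposed as ``cherrypy.tools.allow``)::

    import cherrypy

    class Resource:
        @cherrypy.expose
        @cherrypy.tools.allow(methods=['GET', 'POST'])
        def index(self):
            # A PUT here now gets 405, and every response carries an
            # Allow header listing the permitted methods (HEAD comes
            # along with GET).
            return 'ok'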
@@ -154,22 +154,23 @@ def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For',
scheme='X-Forwarded-Proto', debug=False):
"""Change the base URL (scheme://host[:port][/path]).
For running a CP server behind Apache, lighttpd, or other HTTP server.
For running a CP server behind Apache, lighttpd, or other HTTP
server.
For Apache and lighttpd, you should leave the 'local' argument at the
default value of 'X-Forwarded-Host'. For Squid, you probably want to set
tools.proxy.local = 'Origin'.
For Apache and lighttpd, you should leave the 'local' argument at
the default value of 'X-Forwarded-Host'. For Squid, you probably
want to set tools.proxy.local = 'Origin'.
If you want the new request.base to include path info (not just the host),
you must explicitly set base to the full base path, and ALSO set 'local'
to '', so that the X-Forwarded-Host request header (which never includes
path info) does not override it. Regardless, the value for 'base' MUST
NOT end in a slash.
If you want the new request.base to include path info (not just the
host), you must explicitly set base to the full base path, and ALSO
set 'local' to '', so that the X-Forwarded-Host request header
(which never includes path info) does not override it. Regardless,
the value for 'base' MUST NOT end in a slash.
cherrypy.request.remote.ip (the IP address of the client) will be
rewritten if the header specified by the 'remote' arg is valid.
By default, 'remote' is set to 'X-Forwarded-For'. If you do not
want to rewrite remote.ip, set the 'remote' arg to an empty string.
rewritten if the header specified by the 'remote' arg is valid. By
default, 'remote' is set to 'X-Forwarded-For'. If you do not want to
rewrite remote.ip, set the 'remote' arg to an empty string.
"""
request = cherrypy.serving.request
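An illustrative reverse-proxy configuration following the rules above (values are examples, not defaults to copy blindly)::

    config = {
        '/': {
            'tools.proxy.on': True,
            # Apache/lighttpd: keep the default below; Squid: use 'Origin'.
            'tools.proxy.local': 'X-Forwarded-Host',
            # Only if path info is needed; base MUST NOT end in a slash,
            # and requires local='' so X-Forwarded-Host cannot override it:
            # 'tools.proxy.base': 'https://example.com/app',
            # 'tools.proxy.local': '',
        },
    }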
@@ -217,8 +218,8 @@ def proxy(base=None, local='X-Forwarded-Host', remote='X-Forwarded-For',
def ignore_headers(headers=('Range',), debug=False):
"""Delete request headers whose field names are included in 'headers'.
This is a useful tool for working behind certain HTTP servers;
for example, Apache duplicates the work that CP does for 'Range'
This is a useful tool for working behind certain HTTP servers; for
example, Apache duplicates the work that CP does for 'Range'
headers, and will doubly-truncate the response.
"""
request = cherrypy.serving.request
@@ -281,7 +282,6 @@ def referer(pattern, accept=True, accept_missing=False, error=403,
class SessionAuth(object):
"""Assert that the user is logged in."""
session_key = 'username'
@@ -319,7 +319,10 @@ Message: %(error_msg)s
</body></html>""") % vars()).encode('utf-8')
def do_login(self, username, password, from_page='..', **kwargs):
"""Login. May raise redirect, or return True if request handled."""
"""Login.
May raise redirect, or return True if request handled.
"""
response = cherrypy.serving.response
error_msg = self.check_username_and_password(username, password)
if error_msg:
@@ -336,7 +339,10 @@ Message: %(error_msg)s
raise cherrypy.HTTPRedirect(from_page or '/')
def do_logout(self, from_page='..', **kwargs):
"""Logout. May raise redirect, or return True if request handled."""
"""Logout.
May raise redirect, or return True if request handled.
"""
sess = cherrypy.session
username = sess.get(self.session_key)
sess[self.session_key] = None
@@ -346,7 +352,9 @@ Message: %(error_msg)s
raise cherrypy.HTTPRedirect(from_page)
def do_check(self):
"""Assert username. Raise redirect, or return True if request handled.
"""Assert username.
Raise redirect, or return True if request handled.
"""
sess = cherrypy.session
request = cherrypy.serving.request
@@ -408,8 +416,7 @@ def session_auth(**kwargs):
Any attribute of the SessionAuth class may be overridden
via a keyword arg to this function:
""" + '\n '.join(
""" + '\n' + '\n '.join(
'{!s}: {!s}'.format(k, type(getattr(SessionAuth, k)).__name__)
for k in dir(SessionAuth)
if not k.startswith('__')
@@ -490,8 +497,8 @@ def trailing_slash(missing=True, extra=False, status=None, debug=False):
def flatten(debug=False):
"""Wrap response.body in a generator that recursively iterates over body.
This allows cherrypy.response.body to consist of 'nested generators';
that is, a set of generators that yield generators.
This allows cherrypy.response.body to consist of 'nested
generators'; that is, a set of generators that yield generators.
"""
def flattener(input):
numchunks = 0
@@ -622,13 +629,15 @@ def autovary(ignore=None, debug=False):
def convert_params(exception=ValueError, error=400):
"""Convert request params based on function annotations, with error handling.
"""Convert request params based on function annotations.
exception
Exception class to catch.
This function also processes errors that are subclasses of ``exception``.
status
The HTTP error code to return to the client on failure.
:param BaseException exception: Exception class to catch.
:type exception: BaseException
:param error: The HTTP status code to return to the client on failure.
:type error: int
"""
request = cherrypy.serving.request
types = request.handler.callable.__annotations__
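Assuming the tool is registered as ``tools.params`` (its name in recent CherryPy releases), a usage sketch with illustrative names::

    import cherrypy

    class Root:
        @cherrypy.expose
        @cherrypy.tools.params()
        def resize(self, width: int, height: int):
            # Query-string values arrive as str; the annotations drive
            # int() conversion, and a failed conversion becomes HTTP 400.
            return 'resized to %dx%d' % (width, height)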

View File

@@ -261,9 +261,7 @@ class ResponseEncoder:
def prepare_iter(value):
"""
Ensure response body is iterable and resolves to False when empty.
"""
"""Ensure response body is iterable and resolves to False when empty."""
if isinstance(value, text_or_bytes):
# strings get wrapped in a list because iterating over a single
# item list is much faster than iterating over every character
@@ -360,7 +358,6 @@ def gzip(compress_level=5, mime_types=['text/html', 'text/plain'],
* No 'gzip' or 'x-gzip' is present in the Accept-Encoding header
* No 'gzip' or 'x-gzip' with a qvalue > 0 is present
* The 'identity' value is given with a qvalue > 0.
"""
request = cherrypy.serving.request
response = cherrypy.serving.response
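Typical configuration for this tool (the values shown mirror the defaults in the signature above)::

    config = {
        '/': {
            'tools.gzip.on': True,
            'tools.gzip.compress_level': 5,
            'tools.gzip.mime_types': ['text/html', 'text/plain'],
        },
    }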

View File

@@ -14,7 +14,6 @@ from cherrypy.process.plugins import SimplePlugin
class ReferrerTree(object):
"""An object which gathers all referrers of an object to a given depth."""
peek_length = 40
@@ -132,7 +131,6 @@ def get_context(obj):
class GCRoot(object):
"""A CherryPy page handler for testing reference leaks."""
classes = [

View File

@@ -71,10 +71,10 @@ def protocol_from_http(protocol_str):
def get_ranges(headervalue, content_length):
"""Return a list of (start, stop) indices from a Range header, or None.
Each (start, stop) tuple will be composed of two ints, which are suitable
for use in a slicing operation. That is, the header "Range: bytes=3-6",
if applied against a Python string, is requesting resource[3:7]. This
function will return the list [(3, 7)].
Each (start, stop) tuple will be composed of two ints, which are
suitable for use in a slicing operation. That is, the header "Range:
bytes=3-6", if applied against a Python string, is requesting
resource[3:7]. This function will return the list [(3, 7)].
If this function returns an empty list, you should return HTTP 416.
"""
@@ -127,7 +127,6 @@ def get_ranges(headervalue, content_length):
class HeaderElement(object):
"""An element (with parameters) from an HTTP header's element list."""
def __init__(self, value, params=None):
@@ -169,14 +168,14 @@ q_separator = re.compile(r'; *q *=')
class AcceptElement(HeaderElement):
"""An element (with parameters) from an Accept* header's element list.
AcceptElement objects are comparable; the more-preferred object will be
"less than" the less-preferred object. They are also therefore sortable;
if you sort a list of AcceptElement objects, they will be listed in
priority order; the most preferred value will be first. Yes, it should
have been the other way around, but it's too late to fix now.
AcceptElement objects are comparable; the more-preferred object will
be "less than" the less-preferred object. They are also therefore
sortable; if you sort a list of AcceptElement objects, they will be
listed in priority order; the most preferred value will be first.
Yes, it should have been the other way around, but it's too late to
fix now.
"""
@classmethod
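Taking the docstring at its word, a small illustration (the expected output is shown as a comment)::

    from cherrypy.lib import httputil

    elems = [httputil.AcceptElement.from_str(s)
             for s in ('text/plain;q=0.5', 'text/html')]
    # Sorting lists the most preferred element first:
    print([str(e) for e in sorted(elems)])   # ['text/html', 'text/plain;q=0.5']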
@@ -249,8 +248,7 @@ def header_elements(fieldname, fieldvalue):
def decode_TEXT(value):
r"""
Decode :rfc:`2047` TEXT
r"""Decode :rfc:`2047` TEXT.
>>> decode_TEXT("=?utf-8?q?f=C3=BCr?=") == b'f\xfcr'.decode('latin-1')
True
@@ -265,9 +263,7 @@ def decode_TEXT(value):
def decode_TEXT_maybe(value):
"""
Decode the text but only if '=?' appears in it.
"""
"""Decode the text but only if '=?' appears in it."""
return decode_TEXT(value) if '=?' in value else value
@@ -388,7 +384,6 @@ def parse_query_string(query_string, keep_blank_values=True, encoding='utf-8'):
class CaseInsensitiveDict(jaraco.collections.KeyTransformingDict):
"""A case-insensitive dict subclass.
Each key is changed on entry to title case.
@@ -417,7 +412,6 @@ else:
class HeaderMap(CaseInsensitiveDict):
"""A dict subclass for HTTP request and response headers.
Each key is changed on entry to str(key).title(). This allows headers
@@ -494,7 +488,6 @@ class HeaderMap(CaseInsensitiveDict):
class Host(object):
"""An internet address.
name

View File

@@ -7,22 +7,22 @@ class NeverExpires(object):
class Timer(object):
"""
A simple timer that will indicate when an expiration time has passed.
"""
"""A simple timer that will indicate when an expiration time has passed."""
def __init__(self, expiration):
'Create a timer that expires at `expiration` (UTC datetime)'
self.expiration = expiration
@classmethod
def after(cls, elapsed):
"""
Return a timer that will expire after `elapsed` passes.
"""
return cls(datetime.datetime.utcnow() + elapsed)
"""Return a timer that will expire after `elapsed` passes."""
return cls(
datetime.datetime.now(datetime.timezone.utc) + elapsed,
)
def expired(self):
return datetime.datetime.utcnow() >= self.expiration
return datetime.datetime.now(
datetime.timezone.utc,
) >= self.expiration
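Context for the change: ``datetime.datetime.utcnow()`` is deprecated as of Python 3.12, so the expiry math now builds timezone-aware datetimes; roughly::

    import datetime

    UTC = datetime.timezone.utc
    expiration = datetime.datetime.now(UTC) + datetime.timedelta(seconds=30)
    expired = datetime.datetime.now(UTC) >= expiration   # False for ~30s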
class LockTimeout(Exception):
@@ -30,9 +30,7 @@ class LockTimeout(Exception):
class LockChecker(object):
"""
Keep track of the time and detect if a timeout has expired
"""
"""Keep track of the time and detect if a timeout has expired."""
def __init__(self, session_id, timeout):
self.session_id = session_id
if timeout:

View File

@@ -30,7 +30,6 @@ to get a quick sanity-check on overall CP performance. Use the
``--profile`` flag when running the test suite. Then, use the ``serve()``
function to browse the results in a web browser. If you run this
module from the command line, it will call ``serve()`` for you.
"""
import io
@@ -47,7 +46,9 @@ try:
import pstats
def new_func_strip_path(func_name):
"""Make profiler output more readable by adding `__init__` modules' parents
"""Add ``__init__`` modules' parents.
This makes the profiler output more readable.
"""
filename, line, name = func_name
if filename.endswith('__init__.py'):

View File

@@ -27,18 +27,17 @@ from cherrypy._cpcompat import text_or_bytes
class NamespaceSet(dict):
"""A dict of config namespace names and handlers.
Each config entry should begin with a namespace name; the corresponding
namespace handler will be called once for each config entry in that
namespace, and will be passed two arguments: the config key (with the
namespace removed) and the config value.
Each config entry should begin with a namespace name; the
corresponding namespace handler will be called once for each config
entry in that namespace, and will be passed two arguments: the
config key (with the namespace removed) and the config value.
Namespace handlers may be any Python callable; they may also be
context managers, in which case their __enter__
method should return a callable to be used as the handler.
See cherrypy.tools (the Toolbox class) for an example.
context managers, in which case their __enter__ method should return
a callable to be used as the handler. See cherrypy.tools (the
Toolbox class) for an example.
"""
def __call__(self, config):
@@ -48,9 +47,10 @@ class NamespaceSet(dict):
A flat dict, where keys use dots to separate
namespaces, and values are arbitrary.
The first name in each config key is used to look up the corresponding
namespace handler. For example, a config entry of {'tools.gzip.on': v}
will call the 'tools' namespace handler with the args: ('gzip.on', v)
The first name in each config key is used to look up the
corresponding namespace handler. For example, a config entry of
{'tools.gzip.on': v} will call the 'tools' namespace handler
with the args: ('gzip.on', v)
"""
# Separate the given config into namespaces
ns_confs = {}
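A minimal sketch of the dispatch described in the docstring (the handler body is illustrative)::

    ns = NamespaceSet()
    seen = {}
    ns['tools'] = lambda key, value: seen.update({key: value})
    ns({'tools.gzip.on': True})
    # seen == {'gzip.on': True}: the 'tools.' prefix was stripped before
    # the handler was called; unregistered namespaces are skipped.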
@@ -103,7 +103,6 @@ class NamespaceSet(dict):
class Config(dict):
"""A dict-like set of configuration data, with defaults and namespaces.
May take a file, filename, or dict.
@@ -167,7 +166,7 @@ class Parser(configparser.ConfigParser):
self._read(fp, filename)
def as_dict(self, raw=False, vars=None):
"""Convert an INI file to a dictionary"""
"""Convert an INI file to a dictionary."""
# Load INI file into a dict
result = {}
for section in self.sections():
@@ -188,7 +187,7 @@ class Parser(configparser.ConfigParser):
def dict_from_file(self, file):
if hasattr(file, 'read'):
self.readfp(file)
self.read_file(file)
else:
self.read(file)
return self.as_dict()
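``ConfigParser.readfp()`` was long deprecated and removed in Python 3.12, hence the switch; ``read_file()`` accepts the same open file or file-like object::

    import configparser
    import io

    parser = configparser.ConfigParser()
    parser.read_file(io.StringIO('[global]\nserver.socket_port = 8080\n'))
    print(parser.get('global', 'server.socket_port'))   # -> 8080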

View File

@@ -120,7 +120,6 @@ missing = object()
class Session(object):
"""A CherryPy dict-like Session object (one per request)."""
_id = None
@@ -148,9 +147,11 @@ class Session(object):
to session data."""
loaded = False
"""If True, data has been retrieved from storage.
This should happen automatically on the first attempt to access
session data.
"""
If True, data has been retrieved from storage. This should happen
automatically on the first attempt to access session data."""
clean_thread = None
'Class-level Monitor which calls self.clean_up.'
@@ -165,9 +166,10 @@ class Session(object):
'True if the session requested by the client did not exist.'
regenerated = False
"""True if the application called session.regenerate().
This is not set by internal calls to regenerate the session id.
"""
True if the application called session.regenerate(). This is not set by
internal calls to regenerate the session id."""
debug = False
'If True, log debug information.'
@@ -335,8 +337,9 @@ class Session(object):
def pop(self, key, default=missing):
"""Remove the specified key and return the corresponding value.
If key is not found, default is returned if given,
otherwise KeyError is raised.
If key is not found, default is returned if given, otherwise
KeyError is raised.
"""
if not self.loaded:
self.load()
@@ -351,13 +354,19 @@ class Session(object):
return key in self._data
def get(self, key, default=None):
"""D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
"""D.get(k[,d]) -> D[k] if k in D, else d.
d defaults to None.
"""
if not self.loaded:
self.load()
return self._data.get(key, default)
def update(self, d):
"""D.update(E) -> None. Update D from E: for k in E: D[k] = E[k]."""
"""D.update(E) -> None.
Update D from E: for k in E: D[k] = E[k].
"""
if not self.loaded:
self.load()
self._data.update(d)
@@ -369,7 +378,10 @@ class Session(object):
return self._data.setdefault(key, default)
def clear(self):
"""D.clear() -> None. Remove all items from D."""
"""D.clear() -> None.
Remove all items from D.
"""
if not self.loaded:
self.load()
self._data.clear()
@@ -492,7 +504,8 @@ class FileSession(Session):
"""Set up the storage system for file-based sessions.
This should only be called once per process; this will be done
automatically when using sessions.init (as the built-in Tool does).
automatically when using sessions.init (as the built-in Tool
does).
"""
# The 'storage_path' arg is required for file-based sessions.
kwargs['storage_path'] = os.path.abspath(kwargs['storage_path'])
@@ -616,7 +629,8 @@ class MemcachedSession(Session):
"""Set up the storage system for memcached-based sessions.
This should only be called once per process; this will be done
automatically when using sessions.init (as the built-in Tool does).
automatically when using sessions.init (as the built-in Tool
does).
"""
for k, v in kwargs.items():
setattr(cls, k, v)

View File

@@ -1,19 +1,18 @@
"""Module with helpers for serving static files."""
import mimetypes
import os
import platform
import re
import stat
import mimetypes
import urllib.parse
import unicodedata
import urllib.parse
from email.generator import _make_boundary as make_boundary
from io import UnsupportedOperation
import cherrypy
from cherrypy._cpcompat import ntob
from cherrypy.lib import cptools, httputil, file_generator_limited
from cherrypy.lib import cptools, file_generator_limited, httputil
def _setup_mimetypes():
@@ -57,15 +56,15 @@ def serve_file(path, content_type=None, disposition=None, name=None,
debug=False):
"""Set status, headers, and body in order to serve the given path.
The Content-Type header will be set to the content_type arg, if provided.
If not provided, the Content-Type will be guessed by the file extension
of the 'path' argument.
The Content-Type header will be set to the content_type arg, if
provided. If not provided, the Content-Type will be guessed by the
file extension of the 'path' argument.
If disposition is not None, the Content-Disposition header will be set
to "<disposition>; filename=<name>; filename*=utf-8''<name>"
as described in :rfc:`6266#appendix-D`.
If name is None, it will be set to the basename of path.
If disposition is None, no Content-Disposition header will be written.
If disposition is not None, the Content-Disposition header will be
set to "<disposition>; filename=<name>; filename*=utf-8''<name>" as
described in :rfc:`6266#appendix-D`. If name is None, it will be set
to the basename of path. If disposition is None, no
Content-Disposition header will be written.
"""
response = cherrypy.serving.response
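A usage sketch with an illustrative path and filename::

    import cherrypy
    from cherrypy.lib.static import serve_file

    class Root:
        @cherrypy.expose
        def report(self):
            return serve_file('/srv/files/report.pdf',
                              disposition='attachment', name='report.pdf')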
@@ -185,7 +184,10 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None,
def _serve_fileobj(fileobj, content_type, content_length, debug=False):
"""Internal. Set response.body to the given file object, perhaps ranged."""
"""Set ``response.body`` to the given file object, perhaps ranged.
Internal helper.
"""
response = cherrypy.serving.response
# HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code

View File

@@ -31,7 +31,6 @@ _module__file__base = os.getcwd()
class SimplePlugin(object):
"""Plugin base class which auto-subscribes methods for known channels."""
bus = None
@@ -59,7 +58,6 @@ class SimplePlugin(object):
class SignalHandler(object):
"""Register bus channels (and listeners) for system signals.
You can modify what signals your application listens for, and what it does
@@ -171,8 +169,8 @@ class SignalHandler(object):
If the optional 'listener' argument is provided, it will be
subscribed as a listener for the given signal's channel.
If the given signal name or number is not available on the current
platform, ValueError is raised.
If the given signal name or number is not available on the
current platform, ValueError is raised.
"""
if isinstance(signal, text_or_bytes):
signum = getattr(_signal, signal, None)
@@ -218,11 +216,10 @@ except ImportError:
class DropPrivileges(SimplePlugin):
"""Drop privileges. uid/gid arguments not available on Windows.
Special thanks to `Gavin Baker
<http://antonym.org/2005/12/dropping-privileges-in-python.html>`_
<http://antonym.org/2005/12/dropping-privileges-in-python.html>`_.
"""
def __init__(self, bus, umask=None, uid=None, gid=None):
@@ -234,7 +231,10 @@ class DropPrivileges(SimplePlugin):
@property
def uid(self):
"""The uid under which to run. Availability: Unix."""
"""The uid under which to run.
Availability: Unix.
"""
return self._uid
@uid.setter
@@ -250,7 +250,10 @@ class DropPrivileges(SimplePlugin):
@property
def gid(self):
"""The gid under which to run. Availability: Unix."""
"""The gid under which to run.
Availability: Unix.
"""
return self._gid
@gid.setter
@@ -332,7 +335,6 @@ class DropPrivileges(SimplePlugin):
class Daemonizer(SimplePlugin):
"""Daemonize the running script.
Use this with a Web Site Process Bus via::
@@ -423,7 +425,6 @@ class Daemonizer(SimplePlugin):
class PIDFile(SimplePlugin):
"""Maintain a PID file via a WSPBus."""
def __init__(self, bus, pidfile):
@@ -453,12 +454,11 @@ class PIDFile(SimplePlugin):
class PerpetualTimer(threading.Timer):
"""A responsive subclass of threading.Timer whose run() method repeats.
Use this timer only when you really need a very interruptible timer;
this checks its 'finished' condition up to 20 times a second, which can
results in pretty high CPU usage
this checks its 'finished' condition up to 20 times a second, which
can result in pretty high CPU usage.
"""
def __init__(self, *args, **kwargs):
@@ -483,14 +483,14 @@ class PerpetualTimer(threading.Timer):
class BackgroundTask(threading.Thread):
"""A subclass of threading.Thread whose run() method repeats.
Use this class for most repeating tasks. It uses time.sleep() to wait
for each interval, which isn't very responsive; that is, even if you call
self.cancel(), you'll have to wait until the sleep() call finishes before
the thread stops. To compensate, it defaults to being daemonic, which means
it won't delay stopping the whole process.
Use this class for most repeating tasks. It uses time.sleep() to
wait for each interval, which isn't very responsive; that is, even
if you call self.cancel(), you'll have to wait until the sleep()
call finishes before the thread stops. To compensate, it defaults to
being daemonic, which means it won't delay stopping the whole
process.
"""
def __init__(self, interval, function, args=[], kwargs={}, bus=None):
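A usage sketch with an illustrative callback::

    from cherrypy.process.plugins import BackgroundTask

    def heartbeat():
        print('tick')

    task = BackgroundTask(60, heartbeat)   # run roughly every 60 seconds
    task.start()
    # ...
    task.cancel()   # takes effect after the current sleep() interval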
@@ -525,7 +525,6 @@ class BackgroundTask(threading.Thread):
class Monitor(SimplePlugin):
"""WSPBus listener to periodically run a callback in its own thread."""
callback = None
@@ -582,7 +581,6 @@ class Monitor(SimplePlugin):
class Autoreloader(Monitor):
"""Monitor which re-executes the process when files change.
This :ref:`plugin<plugins>` restarts the process (via :func:`os.execv`)
@@ -699,20 +697,20 @@ class Autoreloader(Monitor):
class ThreadManager(SimplePlugin):
"""Manager for HTTP request threads.
If you have control over thread creation and destruction, publish to
the 'acquire_thread' and 'release_thread' channels (for each thread).
This will register/unregister the current thread and publish to
'start_thread' and 'stop_thread' listeners in the bus as needed.
the 'acquire_thread' and 'release_thread' channels (for each
thread). This will register/unregister the current thread and
publish to 'start_thread' and 'stop_thread' listeners in the bus as
needed.
If threads are created and destroyed by code you do not control
(e.g., Apache), then, at the beginning of every HTTP request,
publish to 'acquire_thread' only. You should not publish to
'release_thread' in this case, since you do not know whether
the thread will be re-used or not. The bus will call
'stop_thread' listeners for you when it stops.
'release_thread' in this case, since you do not know whether the
thread will be re-used or not. The bus will call 'stop_thread'
listeners for you when it stops.
"""
threads = None

View File

@@ -132,7 +132,6 @@ class Timeouts:
class ServerAdapter(object):
"""Adapter for an HTTP server.
If you need to start more than one HTTP server (to serve on multiple
@@ -188,9 +187,7 @@ class ServerAdapter(object):
@property
def description(self):
"""
A description about where this server is bound.
"""
"""A description about where this server is bound."""
if self.bind_addr is None:
on_what = 'unknown interface (dynamic?)'
elif isinstance(self.bind_addr, tuple):
@@ -292,7 +289,6 @@ class ServerAdapter(object):
class FlupCGIServer(object):
"""Adapter for a flup.server.cgi.WSGIServer."""
def __init__(self, *args, **kwargs):
@@ -316,7 +312,6 @@ class FlupCGIServer(object):
class FlupFCGIServer(object):
"""Adapter for a flup.server.fcgi.WSGIServer."""
def __init__(self, *args, **kwargs):
@@ -362,7 +357,6 @@ class FlupFCGIServer(object):
class FlupSCGIServer(object):
"""Adapter for a flup.server.scgi.WSGIServer."""
def __init__(self, *args, **kwargs):

View File

@@ -1,4 +1,7 @@
"""Windows service. Requires pywin32."""
"""Windows service.
Requires pywin32.
"""
import os
import win32api
@@ -11,7 +14,6 @@ from cherrypy.process import wspbus, plugins
class ConsoleCtrlHandler(plugins.SimplePlugin):
"""A WSPBus plugin for handling Win32 console events (like Ctrl-C)."""
def __init__(self, bus):
@@ -69,10 +71,10 @@ class ConsoleCtrlHandler(plugins.SimplePlugin):
class Win32Bus(wspbus.Bus):
"""A Web Site Process Bus implementation for Win32.
Instead of time.sleep, this bus blocks using native win32event objects.
Instead of time.sleep, this bus blocks using native win32event
objects.
"""
def __init__(self):
@@ -120,7 +122,6 @@ class Win32Bus(wspbus.Bus):
class _ControlCodes(dict):
"""Control codes used to "signal" a service via ControlService.
User-defined control codes are in the range 128-255. We generally use
@@ -152,7 +153,6 @@ def signal_child(service, command):
class PyWebService(win32serviceutil.ServiceFramework):
"""Python Web Service."""
_svc_name_ = 'Python Web Service'

View File

@@ -57,7 +57,6 @@ the new state.::
| \ |
| V V
STARTED <-- STARTING
"""
import atexit
@@ -65,7 +64,7 @@ import atexit
try:
import ctypes
except ImportError:
"""Google AppEngine is shipped without ctypes
"""Google AppEngine is shipped without ctypes.
:seealso: http://stackoverflow.com/a/6523777/70170
"""
@@ -165,8 +164,8 @@ class Bus(object):
All listeners for a given channel are guaranteed to be called even
if others at the same channel fail. Each failure is logged, but
execution proceeds on to the next listener. The only way to stop all
processing from inside a listener is to raise SystemExit and stop the
whole server.
processing from inside a listener is to raise SystemExit and stop
the whole server.
"""
states = states
@@ -312,8 +311,9 @@ class Bus(object):
def restart(self):
"""Restart the process (may close connections).
This method does not restart the process from the calling thread;
instead, it stops the bus and asks the main thread to call execv.
This method does not restart the process from the calling
thread; instead, it stops the bus and asks the main thread to
call execv.
"""
self.execv = True
self.exit()
@@ -327,10 +327,11 @@ class Bus(object):
"""Wait for the EXITING state, KeyboardInterrupt or SystemExit.
This function is intended to be called only by the main thread.
After waiting for the EXITING state, it also waits for all threads
to terminate, and then calls os.execv if self.execv is True. This
design allows another thread to call bus.restart, yet have the main
thread perform the actual execv call (required on some platforms).
After waiting for the EXITING state, it also waits for all
threads to terminate, and then calls os.execv if self.execv is
True. This design allows another thread to call bus.restart, yet
have the main thread perform the actual execv call (required on
some platforms).
"""
try:
self.wait(states.EXITING, interval=interval, channel='main')
@@ -379,13 +380,14 @@ class Bus(object):
def _do_execv(self):
"""Re-execute the current process.
This must be called from the main thread, because certain platforms
(OS X) don't allow execv to be called in a child thread very well.
This must be called from the main thread, because certain
platforms (OS X) don't allow execv to be called in a child
thread very well.
"""
try:
args = self._get_true_argv()
except NotImplementedError:
"""It's probably win32 or GAE"""
"""It's probably win32 or GAE."""
args = [sys.executable] + self._get_interpreter_argv() + sys.argv
self.log('Re-spawning %s' % ' '.join(args))
@@ -472,7 +474,7 @@ class Bus(object):
c_ind = None
if is_module:
"""It's containing `-m -m` sequence of arguments"""
"""It's containing `-m -m` sequence of arguments."""
if is_command and c_ind < m_ind:
"""There's `-c -c` before `-m`"""
raise RuntimeError(
@@ -481,7 +483,7 @@ class Bus(object):
# Survive module argument here
original_module = sys.argv[0]
if not os.access(original_module, os.R_OK):
"""There's no such module exist"""
"""There's no such module exist."""
raise AttributeError(
"{} doesn't seem to be a module "
'accessible by current user'.format(original_module))
@@ -489,12 +491,12 @@ class Bus(object):
# ... and substitute it with the original module path:
_argv.insert(m_ind, original_module)
elif is_command:
"""It's containing just `-c -c` sequence of arguments"""
"""It's containing just `-c -c` sequence of arguments."""
raise RuntimeError(
"Cannot reconstruct command from '-c'. "
'Ref: https://github.com/cherrypy/cherrypy/issues/1545')
except AttributeError:
"""It looks Py_GetArgcArgv is completely absent in some environments
"""It looks Py_GetArgcArgv's completely absent in some environments
It is known, that there's no Py_GetArgcArgv in MS Windows and
``ctypes`` module is completely absent in Google AppEngine
@@ -512,13 +514,13 @@ class Bus(object):
"""Prepend current working dir to PATH environment variable if needed.
If sys.path[0] is an empty string, the interpreter was likely
invoked with -m and the effective path is about to change on
re-exec. Add the current directory to $PYTHONPATH to ensure
that the new process sees the same path.
invoked with -m and the effective path is about to change on
re-exec. Add the current directory to $PYTHONPATH to ensure that
the new process sees the same path.
This issue cannot be addressed in the general case because
Python cannot reliably reconstruct the
original command line (http://bugs.python.org/issue14208).
Python cannot reliably reconstruct the original command line
(http://bugs.python.org/issue14208).
(This idea filched from tornado.autoreload)
"""
@@ -536,10 +538,10 @@ class Bus(object):
"""Set the CLOEXEC flag on all open files (except stdin/out/err).
If self.max_cloexec_files is an integer (the default), then on
platforms which support it, it represents the max open files setting
for the operating system. This function will be called just before
the process is restarted via os.execv() to prevent open files
from persisting into the new process.
platforms which support it, it represents the max open files
setting for the operating system. This function will be called
just before the process is restarted via os.execv() to prevent
open files from persisting into the new process.
Set self.max_cloexec_files to 0 to disable this behavior.
"""
@@ -578,7 +580,10 @@ class Bus(object):
return t
def log(self, msg='', level=20, traceback=False):
"""Log the given message. Append the last traceback if requested."""
"""Log the given message.
Append the last traceback if requested.
"""
if traceback:
msg += '\n' + ''.join(_traceback.format_exception(*sys.exc_info()))
self.publish('log', msg, level)

View File

@@ -9,7 +9,6 @@ Even before any tweaking, this should serve a few demonstration pages.
Change to this directory and run:
cherryd -c site.conf
"""
import cherrypy

View File

@@ -1,8 +1,3 @@
#!/usr/bin/env python
# -*- coding: iso-8859-1 -*-
# Documentation is intended to be processed by Epydoc.
"""
Introduction
============
@@ -11,266 +6,10 @@ The Munkres module provides an implementation of the Munkres algorithm
(also called the Hungarian algorithm or the Kuhn-Munkres algorithm),
useful for solving the Assignment Problem.
Assignment Problem
==================
Let *C* be an *n*\ x\ *n* matrix representing the costs of each of *n* workers
to perform any of *n* jobs. The assignment problem is to assign jobs to
workers in a way that minimizes the total cost. Since each worker can perform
only one job and each job can be assigned to only one worker the assignments
represent an independent set of the matrix *C*.
One way to generate the optimal set is to create all permutations of
the indexes necessary to traverse the matrix so that no row and column
are used more than once. For instance, given this matrix (expressed in
Python)::
matrix = [[5, 9, 1],
[10, 3, 2],
[8, 7, 4]]
You could use this code to generate the traversal indexes::
def permute(a, results):
if len(a) == 1:
results.insert(len(results), a)
else:
for i in range(0, len(a)):
element = a[i]
a_copy = [a[j] for j in range(0, len(a)) if j != i]
subresults = []
permute(a_copy, subresults)
for subresult in subresults:
result = [element] + subresult
results.insert(len(results), result)
results = []
permute(range(len(matrix)), results) # [0, 1, 2] for a 3x3 matrix
After the call to permute(), the results matrix would look like this::
[[0, 1, 2],
[0, 2, 1],
[1, 0, 2],
[1, 2, 0],
[2, 0, 1],
[2, 1, 0]]
You could then use that index matrix to loop over the original cost matrix
and calculate the smallest cost of the combinations::
n = len(matrix)
minval = sys.maxsize
for row in range(n):
cost = 0
for col in range(n):
cost += matrix[row][col]
minval = min(cost, minval)
print minval
While this approach works fine for small matrices, it does not scale. It
executes in O(*n*!) time: Calculating the permutations for an *n*\ x\ *n*
matrix requires *n*! operations. For a 12x12 matrix, that's 479,001,600
traversals. Even if you could manage to perform each traversal in just one
millisecond, it would still take more than 133 hours to perform the entire
traversal. A 20x20 matrix would take 2,432,902,008,176,640,000 operations. At
an optimistic millisecond per operation, that's more than 77 million years.
The Munkres algorithm runs in O(*n*\ ^3) time, rather than O(*n*!). This
package provides an implementation of that algorithm.
This version is based on
http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html.
This version was written for Python by Brian Clapper from the (Ada) algorithm
at the above web site. (The ``Algorithm::Munkres`` Perl version, in CPAN, was
clearly adapted from the same web site.)
Usage
=====
Construct a Munkres object::
from munkres import Munkres
m = Munkres()
Then use it to compute the lowest cost assignment from a cost matrix. Here's
a sample program::
from munkres import Munkres, print_matrix
matrix = [[5, 9, 1],
[10, 3, 2],
[8, 7, 4]]
m = Munkres()
indexes = m.compute(matrix)
print_matrix(matrix, msg='Lowest cost through this matrix:')
total = 0
for row, column in indexes:
value = matrix[row][column]
total += value
print '(%d, %d) -> %d' % (row, column, value)
print 'total cost: %d' % total
Running that program produces::
Lowest cost through this matrix:
[5, 9, 1]
[10, 3, 2]
[8, 7, 4]
(0, 0) -> 5
(1, 1) -> 3
(2, 2) -> 4
total cost=12
The instantiated Munkres object can be used multiple times on different
matrices.
Non-square Cost Matrices
========================
The Munkres algorithm assumes that the cost matrix is square. However, it's
possible to use a rectangular matrix if you first pad it with 0 values to make
it square. This module automatically pads rectangular cost matrices to make
them square.
Notes:
- The module operates on a *copy* of the caller's matrix, so any padding will
not be seen by the caller.
- The cost matrix must be rectangular or square. An irregular matrix will
*not* work.
Calculating Profit, Rather than Cost
====================================
The cost matrix is just that: A cost matrix. The Munkres algorithm finds
the combination of elements (one from each row and column) that results in
the smallest cost. It's also possible to use the algorithm to maximize
profit. To do that, however, you have to convert your profit matrix to a
cost matrix. The simplest way to do that is to subtract all elements from a
large value. For example::
from munkres import Munkres, print_matrix
matrix = [[5, 9, 1],
[10, 3, 2],
[8, 7, 4]]
cost_matrix = []
for row in matrix:
cost_row = []
for col in row:
cost_row += [sys.maxsize - col]
cost_matrix += [cost_row]
m = Munkres()
indexes = m.compute(cost_matrix)
print_matrix(matrix, msg='Highest profit through this matrix:')
total = 0
for row, column in indexes:
value = matrix[row][column]
total += value
print '(%d, %d) -> %d' % (row, column, value)
print 'total profit=%d' % total
Running that program produces::
Highest profit through this matrix:
[5, 9, 1]
[10, 3, 2]
[8, 7, 4]
(0, 1) -> 9
(1, 0) -> 10
(2, 2) -> 4
total profit=23
The ``munkres`` module provides a convenience method for creating a cost
matrix from a profit matrix. Since it doesn't know whether the matrix contains
floating point numbers, decimals, or integers, you have to provide the
conversion function; but the convenience method takes care of the actual
creation of the cost matrix::
import munkres
cost_matrix = munkres.make_cost_matrix(matrix,
lambda cost: sys.maxsize - cost)
So, the above profit-calculation program can be recast as::
from munkres import Munkres, print_matrix, make_cost_matrix
matrix = [[5, 9, 1],
[10, 3, 2],
[8, 7, 4]]
cost_matrix = make_cost_matrix(matrix, lambda cost: sys.maxsize - cost)
m = Munkres()
indexes = m.compute(cost_matrix)
print_matrix(matrix, msg='Lowest cost through this matrix:')
total = 0
for row, column in indexes:
value = matrix[row][column]
total += value
print '(%d, %d) -> %d' % (row, column, value)
print 'total profit=%d' % total
References
==========
1. http://www.public.iastate.edu/~ddoty/HungarianAlgorithm.html
2. Harold W. Kuhn. The Hungarian Method for the assignment problem.
*Naval Research Logistics Quarterly*, 2:83-97, 1955.
3. Harold W. Kuhn. Variants of the Hungarian method for assignment
problems. *Naval Research Logistics Quarterly*, 3: 253-258, 1956.
4. Munkres, J. Algorithms for the Assignment and Transportation Problems.
*Journal of the Society of Industrial and Applied Mathematics*,
5(1):32-38, March, 1957.
5. http://en.wikipedia.org/wiki/Hungarian_algorithm
Copyright and License
=====================
This software is released under a BSD license, adapted from
<http://opensource.org/licenses/bsd-license.php>
Copyright (c) 2008 Brian M. Clapper
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name "clapper.org" nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
For complete usage documentation, see: https://software.clapper.org/munkres/
"""
__docformat__ = 'restructuredtext'
__docformat__ = 'markdown'
# ---------------------------------------------------------------------------
# Imports
@@ -278,23 +17,43 @@ __docformat__ = 'restructuredtext'
import sys
import copy
from typing import Union, NewType, Sequence, Tuple, Optional, Callable
# ---------------------------------------------------------------------------
# Exports
# ---------------------------------------------------------------------------
__all__ = ['Munkres', 'make_cost_matrix']
__all__ = ['Munkres', 'make_cost_matrix', 'DISALLOWED']
# ---------------------------------------------------------------------------
# Globals
# ---------------------------------------------------------------------------
AnyNum = NewType('AnyNum', Union[int, float])
Matrix = NewType('Matrix', Sequence[Sequence[AnyNum]])
# Info about the module
__version__ = "1.0.6"
__version__ = "1.1.4"
__author__ = "Brian Clapper, bmc@clapper.org"
__url__ = "http://software.clapper.org/munkres/"
__copyright__ = "(c) 2008 Brian M. Clapper"
__license__ = "BSD-style license"
__url__ = "https://software.clapper.org/munkres/"
__copyright__ = "(c) 2008-2020 Brian M. Clapper"
__license__ = "Apache Software License"
# Constants
class DISALLOWED_OBJ(object):
pass
DISALLOWED = DISALLOWED_OBJ()
DISALLOWED_PRINTVAL = "D"
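A sketch of how DISALLOWED is meant to be used (cost values are illustrative)::

    from munkres import DISALLOWED, Munkres

    matrix = [[5, 9, DISALLOWED],
              [10, DISALLOWED, 2],
              [8, 7, 4]]
    # compute() never pairs a row with a DISALLOWED column; if no legal
    # assignment exists at all, UnsolvableMatrix (below) is raised.
    indexes = Munkres().compute(matrix)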
# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------
class UnsolvableMatrix(Exception):
"""
Exception raised for unsolvable matrices
"""
pass
# ---------------------------------------------------------------------------
# Classes
@@ -317,30 +76,18 @@ class Munkres:
self.marked = None
self.path = None
def make_cost_matrix(profit_matrix, inversion_function):
"""
**DEPRECATED**
Please use the module function ``make_cost_matrix()``.
"""
import munkres
return munkres.make_cost_matrix(profit_matrix, inversion_function)
make_cost_matrix = staticmethod(make_cost_matrix)
def pad_matrix(self, matrix, pad_value=0):
def pad_matrix(self, matrix: Matrix, pad_value: int=0) -> Matrix:
"""
Pad a possibly non-square matrix to make it square.
:Parameters:
matrix : list of lists
matrix to pad
**Parameters**
pad_value : int
value to use to pad the matrix
- `matrix` (list of lists of numbers): matrix to pad
- `pad_value` (`int`): value to use to pad the matrix
:rtype: list of lists
:return: a new, possibly padded, matrix
**Returns**
a new, possibly padded, matrix
"""
max_columns = 0
total_rows = len(matrix)
@@ -356,34 +103,35 @@ class Munkres:
new_row = row[:]
if total_rows > row_len:
# Row too short. Pad it.
new_row += [0] * (total_rows - row_len)
new_row += [pad_value] * (total_rows - row_len)
new_matrix += [new_row]
while len(new_matrix) < total_rows:
new_matrix += [[0] * total_rows]
new_matrix += [[pad_value] * total_rows]
return new_matrix
def compute(self, cost_matrix):
def compute(self, cost_matrix: Matrix) -> Sequence[Tuple[int, int]]:
"""
Compute the indexes for the lowest-cost pairings between rows and
columns in the database. Returns a list of (row, column) tuples
columns in the database. Returns a list of `(row, column)` tuples
that can be used to traverse the matrix.
:Parameters:
cost_matrix : list of lists
The cost matrix. If this cost matrix is not square, it
will be padded with zeros, via a call to ``pad_matrix()``.
(This method does *not* modify the caller's matrix. It
operates on a copy of the matrix.)
**WARNING**: This code handles square and rectangular matrices. It
does *not* handle irregular matrices.
**WARNING**: This code handles square and rectangular
matrices. It does *not* handle irregular matrices.
**Parameters**
:rtype: list
:return: A list of ``(row, column)`` tuples that describe the lowest
cost path through the matrix
- `cost_matrix` (list of lists of numbers): The cost matrix. If this
cost matrix is not square, it will be padded with zeros, via a call
to `pad_matrix()`. (This method does *not* modify the caller's
matrix. It operates on a copy of the matrix.)
**Returns**
A list of `(row, column)` tuples that describe the lowest cost path
through the matrix
"""
self.C = self.pad_matrix(cost_matrix)
self.n = len(self.C)
@@ -422,18 +170,18 @@ class Munkres:
return results
def __copy_matrix(self, matrix):
def __copy_matrix(self, matrix: Matrix) -> Matrix:
"""Return an exact copy of the supplied matrix"""
return copy.deepcopy(matrix)
def __make_matrix(self, n, val):
def __make_matrix(self, n: int, val: AnyNum) -> Matrix:
"""Create an *n*x*n* matrix, populating it with the specific value."""
matrix = []
for i in range(n):
matrix += [[val for j in range(n)]]
return matrix
def __step1(self):
def __step1(self) -> int:
"""
For each row of the matrix, find the smallest element and
subtract it from every element in its row. Go to Step 2.
@@ -441,15 +189,22 @@ class Munkres:
C = self.C
n = self.n
for i in range(n):
minval = min(self.C[i])
vals = [x for x in self.C[i] if x is not DISALLOWED]
if len(vals) == 0:
# All values in this row are DISALLOWED. This matrix is
# unsolvable.
raise UnsolvableMatrix(
"Row {0} is entirely DISALLOWED.".format(i)
)
minval = min(vals)
# Find the minimum value for this row and subtract that minimum
# from every element in the row.
for j in range(n):
self.C[i][j] -= minval
if self.C[i][j] is not DISALLOWED:
self.C[i][j] -= minval
return 2
def __step2(self):
def __step2(self) -> int:
"""
Find a zero (Z) in the resulting matrix. If there is no starred
zero in its row or column, star Z. Repeat for each element in the
@@ -464,11 +219,12 @@ class Munkres:
self.marked[i][j] = 1
self.col_covered[j] = True
self.row_covered[i] = True
break
self.__clear_covers()
return 3
def __step3(self):
def __step3(self) -> int:
"""
Cover each column containing a starred zero. If K columns are
covered, the starred zeros describe a complete set of unique
@@ -478,7 +234,7 @@ class Munkres:
count = 0
for i in range(n):
for j in range(n):
if self.marked[i][j] == 1:
if self.marked[i][j] == 1 and not self.col_covered[j]:
self.col_covered[j] = True
count += 1
@@ -489,7 +245,7 @@ class Munkres:
return step
def __step4(self):
def __step4(self) -> int:
"""
Find a noncovered zero and prime it. If there is no starred zero
in the row containing this primed zero, Go to Step 5. Otherwise,
@@ -499,11 +255,11 @@ class Munkres:
"""
step = 0
done = False
row = -1
col = -1
row = 0
col = 0
star_col = -1
while not done:
(row, col) = self.__find_a_zero()
(row, col) = self.__find_a_zero(row, col)
if row < 0:
done = True
step = 6
@@ -522,7 +278,7 @@ class Munkres:
return step
def __step5(self):
def __step5(self) -> int:
"""
Construct a series of alternating primed and starred zeros as
follows. Let Z0 represent the uncovered primed zero found in Step 4.
@@ -558,7 +314,7 @@ class Munkres:
self.__erase_primes()
return 3
def __step6(self):
def __step6(self) -> int:
"""
Add the value found in Step 4 to every element of each covered
row, and subtract it from every element of each uncovered column.
@@ -566,34 +322,44 @@ class Munkres:
lines.
"""
minval = self.__find_smallest()
events = 0 # track actual changes to matrix
for i in range(self.n):
for j in range(self.n):
if self.C[i][j] is DISALLOWED:
continue
if self.row_covered[i]:
self.C[i][j] += minval
events += 1
if not self.col_covered[j]:
self.C[i][j] -= minval
events += 1
if self.row_covered[i] and not self.col_covered[j]:
events -= 2 # change reversed, no real difference
if (events == 0):
raise UnsolvableMatrix("Matrix cannot be solved!")
return 4
def __find_smallest(self):
def __find_smallest(self) -> AnyNum:
"""Find the smallest uncovered value in the matrix."""
minval = sys.maxsize
for i in range(self.n):
for j in range(self.n):
if (not self.row_covered[i]) and (not self.col_covered[j]):
if minval > self.C[i][j]:
if self.C[i][j] is not DISALLOWED and minval > self.C[i][j]:
minval = self.C[i][j]
return minval
def __find_a_zero(self):
def __find_a_zero(self, i0: int = 0, j0: int = 0) -> Tuple[int, int]:
"""Find the first uncovered element with value 0"""
row = -1
col = -1
i = 0
i = i0
n = self.n
done = False
while not done:
j = 0
j = j0
while True:
if (self.C[i][j] == 0) and \
(not self.row_covered[i]) and \
@@ -601,16 +367,16 @@ class Munkres:
row = i
col = j
done = True
j += 1
if j >= n:
j = (j + 1) % n
if j == j0:
break
i += 1
if i >= n:
i = (i + 1) % n
if i == i0:
done = True
return (row, col)
def __find_star_in_row(self, row):
def __find_star_in_row(self, row: Sequence[AnyNum]) -> int:
"""
Find the first starred element in the specified row. Returns
the column index, or -1 if no starred element was found.
@@ -623,7 +389,7 @@ class Munkres:
return col
def __find_star_in_col(self, col):
def __find_star_in_col(self, col: Sequence[AnyNum]) -> int:
"""
Find the first starred element in the specified row. Returns
the row index, or -1 if no starred element was found.
@@ -636,7 +402,7 @@ class Munkres:
return row
def __find_prime_in_row(self, row):
def __find_prime_in_row(self, row) -> int:
"""
Find the first prime element in the specified row. Returns
the column index, or -1 if no starred element was found.
@@ -649,20 +415,22 @@ class Munkres:
return col
def __convert_path(self, path, count):
def __convert_path(self,
path: Sequence[Sequence[int]],
count: int) -> None:
for i in range(count+1):
if self.marked[path[i][0]][path[i][1]] == 1:
self.marked[path[i][0]][path[i][1]] = 0
else:
self.marked[path[i][0]][path[i][1]] = 1
def __clear_covers(self):
def __clear_covers(self) -> None:
"""Clear all covered matrix cells"""
for i in range(self.n):
self.row_covered[i] = False
self.col_covered[i] = False
def __erase_primes(self):
def __erase_primes(self) -> None:
"""Erase all prime markings"""
for i in range(self.n):
for j in range(self.n):
@@ -673,51 +441,56 @@ class Munkres:
# Functions
# ---------------------------------------------------------------------------
def make_cost_matrix(profit_matrix, inversion_function):
def make_cost_matrix(
profit_matrix: Matrix,
inversion_function: Optional[Callable[[AnyNum], AnyNum]] = None
) -> Matrix:
"""
Create a cost matrix from a profit matrix by calling
'inversion_function' to invert each value. The inversion
function must take one numeric argument (of any type) and return
another numeric argument which is presumed to be the cost inverse
of the original profit.
Create a cost matrix from a profit matrix by calling `inversion_function()`
to invert each value. The inversion function must take one numeric argument
(of any type) and return another numeric argument which is presumed to be
the cost inverse of the original profit value. If the inversion function
is not provided, a given cell's inverted value is calculated as
`max(matrix) - value`.
This is a static method. Call it like this:
.. python::
from munkres import Munkres
cost_matrix = Munkres.make_cost_matrix(matrix, inversion_func)
For example:
.. python::
from munkres import Munkres
cost_matrix = Munkres.make_cost_matrix(matrix, lambda x : sys.maxsize - x)
:Parameters:
profit_matrix : list of lists
The matrix to convert from a profit to a cost matrix
**Parameters**
inversion_function : function
The function to use to invert each entry in the profit matrix
- `profit_matrix` (list of lists of numbers): The matrix to convert from
profit to cost values.
- `inversion_function` (`function`): The function to use to invert each
entry in the profit matrix.
:rtype: list of lists
:return: The converted matrix
**Returns**
A new matrix representing the inversion of `profit_matrix`.
"""
if not inversion_function:
maximum = max(max(row) for row in profit_matrix)
inversion_function = lambda x: maximum - x
cost_matrix = []
for row in profit_matrix:
cost_matrix.append([inversion_function(value) for value in row])
return cost_matrix
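With the new default inversion, each cell becomes ``max(matrix) - value``; for example::

    import munkres

    profit = [[5, 9, 1],
              [10, 3, 2],
              [8, 7, 4]]
    cost = munkres.make_cost_matrix(profit)
    # max(profit) is 10, so cost == [[5, 1, 9],
    #                                [0, 7, 8],
    #                                [2, 3, 6]]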
def print_matrix(matrix, msg=None):
def print_matrix(matrix: Matrix, msg: Optional[str] = None) -> None:
"""
Convenience function: Displays the contents of a matrix of integers.
Convenience function: Displays the contents of a matrix.
:Parameters:
matrix : list of lists
Matrix to print
**Parameters**
msg : str
Optional message to print before displaying the matrix
- `matrix` (list of lists of numbers): The matrix to print
- `msg` (`str`): Optional message to print before displaying the matrix
"""
import math
@@ -728,16 +501,21 @@ def print_matrix(matrix, msg=None):
width = 0
for row in matrix:
for val in row:
width = max(width, int(math.log10(val)) + 1)
if val is DISALLOWED:
val = DISALLOWED_PRINTVAL
width = max(width, len(str(val)))
# Make the format string
format = '%%%dd' % width
format = ('%%%d' % width)
# Print the matrix
for row in matrix:
sep = '['
for val in row:
sys.stdout.write(sep + format % val)
if val is DISALLOWED:
val = DISALLOWED_PRINTVAL
formatted = ((format + 's') % val)
sys.stdout.write(sep + formatted)
sep = ', '
sys.stdout.write(']\n')
@@ -767,11 +545,51 @@ if __name__ == '__main__':
[9, 7, 4]],
18),
# Square variant with floating point value
([[10.1, 10.2, 8.3],
[9.4, 8.5, 1.6],
[9.7, 7.8, 4.9]],
19.5),
# Rectangular variant
([[10, 10, 8, 11],
[9, 8, 1, 1],
[9, 7, 4, 10]],
15)]
15),
# Rectangular variant with floating point value
([[10.01, 10.02, 8.03, 11.04],
[9.05, 8.06, 1.07, 1.08],
[9.09, 7.1, 4.11, 10.12]],
15.2),
# Rectangular with DISALLOWED
([[4, 5, 6, DISALLOWED],
[1, 9, 12, 11],
[DISALLOWED, 5, 4, DISALLOWED],
[12, 12, 12, 10]],
20),
# Rectangular variant with DISALLOWED and floating point value
([[4.001, 5.002, 6.003, DISALLOWED],
[1.004, 9.005, 12.006, 11.007],
[DISALLOWED, 5.008, 4.009, DISALLOWED],
[12.01, 12.011, 12.012, 10.013]],
20.028),
# DISALLOWED to force pairings
([[1, DISALLOWED, DISALLOWED, DISALLOWED],
[DISALLOWED, 2, DISALLOWED, DISALLOWED],
[DISALLOWED, DISALLOWED, 3, DISALLOWED],
[DISALLOWED, DISALLOWED, DISALLOWED, 4]],
10),
# DISALLOWED to force pairings with floating point value
([[1.1, DISALLOWED, DISALLOWED, DISALLOWED],
[DISALLOWED, 2.2, DISALLOWED, DISALLOWED],
[DISALLOWED, DISALLOWED, 3.3, DISALLOWED],
[DISALLOWED, DISALLOWED, DISALLOWED, 4.4]],
11.0)]
m = Munkres()
for cost_matrix, expected_total in matrices:
@@ -781,6 +599,6 @@ if __name__ == '__main__':
for r, c in indexes:
x = cost_matrix[r][c]
total_cost += x
print(('(%d, %d) -> %d' % (r, c, x)))
print(('lowest cost=%d' % total_cost))
print(('(%d, %d) -> %s' % (r, c, x)))
print(('lowest cost=%s' % total_cost))
assert expected_total == total_cost

File diff suppressed because it is too large

View File

@@ -1,608 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2005-2010 ActiveState Software Inc.
# Copyright (c) 2013 Eddy Petrișor
"""Utilities for determining application-specific dirs.
See <http://github.com/ActiveState/appdirs> for details and usage.
"""
# Dev Notes:
# - MSDN on where to store app data files:
# http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html
__version_info__ = (1, 4, 3)
__version__ = '.'.join(map(str, __version_info__))
import sys
import os
PY3 = sys.version_info[0] == 3
if PY3:
unicode = str
if sys.platform.startswith('java'):
import platform
os_name = platform.java_ver()[3][0]
if os_name.startswith('Windows'): # "Windows XP", "Windows 7", etc.
system = 'win32'
elif os_name.startswith('Mac'): # "Mac OS X", etc.
system = 'darwin'
else: # "Linux", "SunOS", "FreeBSD", etc.
# Setting this to "linux2" is not ideal, but only Windows or Mac
# are actually checked for and the rest of the module expects
# *sys.platform* style strings.
system = 'linux2'
else:
system = sys.platform
def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user data directories are:
Mac OS X: ~/Library/Application Support/<AppName>
Unix: ~/.local/share/<AppName> # or in $XDG_DATA_HOME, if defined
Win XP (not roaming): C:\Documents and Settings\<username>\Application Data\<AppAuthor>\<AppName>
Win XP (roaming): C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>
Win 7 (not roaming): C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>
Win 7 (roaming): C:\Users\<username>\AppData\Roaming\<AppAuthor>\<AppName>
For Unix, we follow the XDG spec and support $XDG_DATA_HOME.
That means, by default "~/.local/share/<AppName>".
"""
if system == "win32":
if appauthor is None:
appauthor = appname
const = roaming and "CSIDL_APPDATA" or "CSIDL_LOCAL_APPDATA"
path = os.path.normpath(_get_win_folder(const))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('~/Library/Application Support/')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_DATA_HOME', os.path.expanduser("~/.local/share"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of data dirs should be
returned. By default, the first item from XDG_DATA_DIRS is
returned, or '/usr/local/share/<AppName>',
if XDG_DATA_DIRS is not set
Typical site data directories are:
Mac OS X: /Library/Application Support/<AppName>
Unix: /usr/local/share/<AppName> or /usr/share/<AppName>
Win XP: C:\Documents and Settings\All Users\Application Data\<AppAuthor>\<AppName>
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
Win 7: C:\ProgramData\<AppAuthor>\<AppName> # Hidden, but writeable on Win 7.
For Unix, this is using the $XDG_DATA_DIRS[0] default.
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_COMMON_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
elif system == 'darwin':
path = os.path.expanduser('/Library/Application Support')
if appname:
path = os.path.join(path, appname)
else:
# XDG default for $XDG_DATA_DIRS
# only first, if multipath is False
path = os.getenv('XDG_DATA_DIRS',
os.pathsep.join(['/usr/local/share', '/usr/share']))
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
if appname and version:
path = os.path.join(path, version)
return path
def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific config dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user config directories are:
Mac OS X: same as user_data_dir
Unix: ~/.config/<AppName> # or in $XDG_CONFIG_HOME, if defined
Win *: same as user_data_dir
For Unix, we follow the XDG spec and support $XDG_CONFIG_HOME.
That means, by default "~/.config/<AppName>".
"""
if system in ["win32", "darwin"]:
path = user_data_dir(appname, appauthor, None, roaming)
else:
path = os.getenv('XDG_CONFIG_HOME', os.path.expanduser("~/.config"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
r"""Return full path to the user-shared data dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"multipath" is an optional parameter only applicable to *nix
which indicates that the entire list of config dirs should be
returned. By default, the first item from XDG_CONFIG_DIRS is
returned, or '/etc/xdg/<AppName>', if XDG_CONFIG_DIRS is not set
Typical site config directories are:
Mac OS X: same as site_data_dir
Unix: /etc/xdg/<AppName> or $XDG_CONFIG_DIRS[i]/<AppName> for each value in
$XDG_CONFIG_DIRS
Win *: same as site_data_dir
Vista: (Fail! "C:\ProgramData" is a hidden *system* directory on Vista.)
For Unix, this is using the $XDG_CONFIG_DIRS[0] default, if multipath=False
WARNING: Do not use this on Windows. See the Vista-Fail note above for why.
"""
if system in ["win32", "darwin"]:
path = site_data_dir(appname, appauthor)
if appname and version:
path = os.path.join(path, version)
else:
# XDG default for $XDG_CONFIG_DIRS
# only first, if multipath is False
path = os.getenv('XDG_CONFIG_DIRS', '/etc/xdg')
pathlist = [os.path.expanduser(x.rstrip(os.sep)) for x in path.split(os.pathsep)]
if appname:
if version:
appname = os.path.join(appname, version)
pathlist = [os.sep.join([x, appname]) for x in pathlist]
if multipath:
path = os.pathsep.join(pathlist)
else:
path = pathlist[0]
return path
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific cache dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Cache" to the base app data dir for Windows. See
discussion below.
Typical user cache directories are:
Mac OS X: ~/Library/Caches/<AppName>
Unix: ~/.cache/<AppName> (XDG default)
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Cache
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Cache
On Windows the only suggestion in the MSDN docs is that local settings go in
the `CSIDL_LOCAL_APPDATA` directory. This is identical to the non-roaming
app data dir (the default returned by `user_data_dir` above). Apps typically
put cache data somewhere *under* the given dir here. Some examples:
...\Mozilla\Firefox\Profiles\<ProfileName>\Cache
...\Acme\SuperApp\Cache\1.0
OPINION: This function appends "Cache" to the `CSIDL_LOCAL_APPDATA` value.
This can be disabled with the `opinion=False` option.
"""
if system == "win32":
if appauthor is None:
appauthor = appname
path = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
if appname:
if appauthor is not False:
path = os.path.join(path, appauthor, appname)
else:
path = os.path.join(path, appname)
if opinion:
path = os.path.join(path, "Cache")
elif system == 'darwin':
path = os.path.expanduser('~/Library/Caches')
if appname:
path = os.path.join(path, appname)
else:
path = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def user_state_dir(appname=None, appauthor=None, version=None, roaming=False):
r"""Return full path to the user-specific state dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"roaming" (boolean, default False) can be set True to use the Windows
roaming appdata directory. That means that for users on a Windows
network setup for roaming profiles, this user data will be
sync'd on login. See
<http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>
for a discussion of issues.
Typical user state directories are:
Mac OS X: same as user_data_dir
Unix: ~/.local/state/<AppName> # or in $XDG_STATE_HOME, if defined
Win *: same as user_data_dir
For Unix, we follow this Debian proposal <https://wiki.debian.org/XDGBaseDirectorySpecification#state>
to extend the XDG spec and support $XDG_STATE_HOME.
That means, by default "~/.local/state/<AppName>".
"""
if system in ["win32", "darwin"]:
path = user_data_dir(appname, appauthor, None, roaming)
else:
path = os.getenv('XDG_STATE_HOME', os.path.expanduser("~/.local/state"))
if appname:
path = os.path.join(path, appname)
if appname and version:
path = os.path.join(path, version)
return path
def user_log_dir(appname=None, appauthor=None, version=None, opinion=True):
r"""Return full path to the user-specific log dir for this application.
"appname" is the name of application.
If None, just the system directory is returned.
"appauthor" (only used on Windows) is the name of the
appauthor or distributing body for this application. Typically
it is the owning company name. This falls back to appname. You may
pass False to disable it.
"version" is an optional version path element to append to the
path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this
would typically be "<major>.<minor>".
Only applied when appname is present.
"opinion" (boolean) can be False to disable the appending of
"Logs" to the base app data dir for Windows, and "log" to the
base cache dir for Unix. See discussion below.
Typical user log directories are:
Mac OS X: ~/Library/Logs/<AppName>
Unix: ~/.cache/<AppName>/log # or under $XDG_CACHE_HOME if defined
Win XP: C:\Documents and Settings\<username>\Local Settings\Application Data\<AppAuthor>\<AppName>\Logs
Vista: C:\Users\<username>\AppData\Local\<AppAuthor>\<AppName>\Logs
On Windows the only suggestion in the MSDN docs is that local settings
go in the `CSIDL_LOCAL_APPDATA` directory. (Note: I'm interested in
examples of what some Windows apps use for a logs dir.)
OPINION: This function appends "Logs" to the `CSIDL_LOCAL_APPDATA`
value for Windows and appends "log" to the user cache dir for Unix.
This can be disabled with the `opinion=False` option.
"""
if system == "darwin":
path = os.path.join(
os.path.expanduser('~/Library/Logs'),
appname)
elif system == "win32":
path = user_data_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "Logs")
else:
path = user_cache_dir(appname, appauthor, version)
version = False
if opinion:
path = os.path.join(path, "log")
if appname and version:
path = os.path.join(path, version)
return path
class AppDirs(object):
"""Convenience wrapper for getting application dirs."""
def __init__(self, appname=None, appauthor=None, version=None,
roaming=False, multipath=False):
self.appname = appname
self.appauthor = appauthor
self.version = version
self.roaming = roaming
self.multipath = multipath
@property
def user_data_dir(self):
return user_data_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_data_dir(self):
return site_data_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_config_dir(self):
return user_config_dir(self.appname, self.appauthor,
version=self.version, roaming=self.roaming)
@property
def site_config_dir(self):
return site_config_dir(self.appname, self.appauthor,
version=self.version, multipath=self.multipath)
@property
def user_cache_dir(self):
return user_cache_dir(self.appname, self.appauthor,
version=self.version)
@property
def user_state_dir(self):
return user_state_dir(self.appname, self.appauthor,
version=self.version)
@property
def user_log_dir(self):
return user_log_dir(self.appname, self.appauthor,
version=self.version)
#---- internal support stuff
def _get_win_folder_from_registry(csidl_name):
"""This is a fallback technique at best. I'm not sure if using the
registry for this guarantees us the correct answer for all CSIDL_*
names.
"""
if PY3:
import winreg as _winreg
else:
import _winreg
shell_folder_name = {
"CSIDL_APPDATA": "AppData",
"CSIDL_COMMON_APPDATA": "Common AppData",
"CSIDL_LOCAL_APPDATA": "Local AppData",
}[csidl_name]
key = _winreg.OpenKey(
_winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir, type = _winreg.QueryValueEx(key, shell_folder_name)
return dir
def _get_win_folder_with_pywin32(csidl_name):
from win32com.shell import shellcon, shell
dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0)
# Try to make this a unicode path because SHGetFolderPath does
# not return unicode strings when there is unicode data in the
# path.
try:
dir = unicode(dir)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
try:
import win32api
dir = win32api.GetShortPathName(dir)
except ImportError:
pass
except UnicodeError:
pass
return dir
def _get_win_folder_with_ctypes(csidl_name):
import ctypes
csidl_const = {
"CSIDL_APPDATA": 26,
"CSIDL_COMMON_APPDATA": 35,
"CSIDL_LOCAL_APPDATA": 28,
}[csidl_name]
buf = ctypes.create_unicode_buffer(1024)
ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf)
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in buf:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf2 = ctypes.create_unicode_buffer(1024)
if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024):
buf = buf2
return buf.value
def _get_win_folder_with_jna(csidl_name):
import array
from com.sun import jna
from com.sun.jna.platform import win32
buf_size = win32.WinDef.MAX_PATH * 2
buf = array.zeros('c', buf_size)
shell = win32.Shell32.INSTANCE
shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf)
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
# Downgrade to short path name if have highbit chars. See
# <http://bugs.activestate.com/show_bug.cgi?id=85099>.
has_high_char = False
for c in dir:
if ord(c) > 255:
has_high_char = True
break
if has_high_char:
buf = array.zeros('c', buf_size)
kernel = win32.Kernel32.INSTANCE
if kernel.GetShortPathName(dir, buf, buf_size):
dir = jna.Native.toString(buf.tostring()).rstrip("\0")
return dir
if system == "win32":
try:
import win32com.shell
_get_win_folder = _get_win_folder_with_pywin32
except ImportError:
try:
from ctypes import windll
_get_win_folder = _get_win_folder_with_ctypes
except ImportError:
try:
import com.sun.jna
_get_win_folder = _get_win_folder_with_jna
except ImportError:
_get_win_folder = _get_win_folder_from_registry
#---- self test code
if __name__ == "__main__":
appname = "MyApp"
appauthor = "MyCompany"
props = ("user_data_dir",
"user_config_dir",
"user_cache_dir",
"user_state_dir",
"user_log_dir",
"site_data_dir",
"site_config_dir")
print("-- app dirs %s --" % __version__)
print("-- app dirs (with optional 'version')")
dirs = AppDirs(appname, appauthor, version="1.0")
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'version')")
dirs = AppDirs(appname, appauthor)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (without optional 'appauthor')")
dirs = AppDirs(appname)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
print("\n-- app dirs (with disabled 'appauthor')")
dirs = AppDirs(appname, appauthor=False)
for prop in props:
print("%s: %s" % (prop, getattr(dirs, prop)))
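For reference, a hedged sketch of how the module-level helpers above are typically called (the application and author names are hypothetical, and the resulting paths depend on the OS):

    from appdirs import user_data_dir, user_cache_dir

    # Per-user locations; 'version' and 'roaming' are optional, as documented above.
    print(user_data_dir("MyApp", "MyCompany", version="1.0"))
    print(user_cache_dir("MyApp", "MyCompany"))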

View File

@@ -0,0 +1 @@
importlib_resources

View File

@@ -0,0 +1,36 @@
"""Read resources contained within a package."""
from ._common import (
as_file,
files,
Package,
)
from ._legacy import (
contents,
open_binary,
read_binary,
open_text,
read_text,
is_resource,
path,
Resource,
)
from .abc import ResourceReader
__all__ = [
'Package',
'Resource',
'ResourceReader',
'as_file',
'contents',
'files',
'is_resource',
'open_binary',
'open_text',
'path',
'read_binary',
'read_text',
]
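A short usage sketch of the files() API re-exported above (the package and resource names are hypothetical):

    import importlib_resources as resources

    # Traverse into a package and read a bundled resource as text.
    text = resources.files("mypkg").joinpath("config.json").read_text(encoding="utf-8")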

View File

@@ -0,0 +1,170 @@
from contextlib import suppress
from io import TextIOWrapper
from . import abc
class SpecLoaderAdapter:
"""
Adapt a package spec to adapt the underlying loader.
"""
def __init__(self, spec, adapter=lambda spec: spec.loader):
self.spec = spec
self.loader = adapter(spec)
def __getattr__(self, name):
return getattr(self.spec, name)
class TraversableResourcesLoader:
"""
Adapt a loader to provide TraversableResources.
"""
def __init__(self, spec):
self.spec = spec
def get_resource_reader(self, name):
return CompatibilityFiles(self.spec)._native()
def _io_wrapper(file, mode='r', *args, **kwargs):
if mode == 'r':
return TextIOWrapper(file, *args, **kwargs)
elif mode == 'rb':
return file
raise ValueError(
"Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode)
)
class CompatibilityFiles:
"""
Adapter for an existing or non-existent resource reader
to provide a compatibility .files().
"""
class SpecPath(abc.Traversable):
"""
Path tied to a module spec.
Can be read and exposes the resource reader children.
"""
def __init__(self, spec, reader):
self._spec = spec
self._reader = reader
def iterdir(self):
if not self._reader:
return iter(())
return iter(
CompatibilityFiles.ChildPath(self._reader, path)
for path in self._reader.contents()
)
def is_file(self):
return False
is_dir = is_file
def joinpath(self, other):
if not self._reader:
return CompatibilityFiles.OrphanPath(other)
return CompatibilityFiles.ChildPath(self._reader, other)
@property
def name(self):
return self._spec.name
def open(self, mode='r', *args, **kwargs):
return _io_wrapper(self._reader.open_resource(None), mode, *args, **kwargs)
class ChildPath(abc.Traversable):
"""
Path tied to a resource reader child.
Can be read but doesn't expose any meaningful children.
"""
def __init__(self, reader, name):
self._reader = reader
self._name = name
def iterdir(self):
return iter(())
def is_file(self):
return self._reader.is_resource(self.name)
def is_dir(self):
return not self.is_file()
def joinpath(self, other):
return CompatibilityFiles.OrphanPath(self.name, other)
@property
def name(self):
return self._name
def open(self, mode='r', *args, **kwargs):
return _io_wrapper(
self._reader.open_resource(self.name), mode, *args, **kwargs
)
class OrphanPath(abc.Traversable):
"""
Orphan path, not tied to a module spec or resource reader.
Can't be read and doesn't expose any meaningful children.
"""
def __init__(self, *path_parts):
if len(path_parts) < 1:
raise ValueError('Need at least one path part to construct a path')
self._path = path_parts
def iterdir(self):
return iter(())
def is_file(self):
return False
is_dir = is_file
def joinpath(self, other):
return CompatibilityFiles.OrphanPath(*self._path, other)
@property
def name(self):
return self._path[-1]
def open(self, mode='r', *args, **kwargs):
raise FileNotFoundError("Can't open orphan path")
def __init__(self, spec):
self.spec = spec
@property
def _reader(self):
with suppress(AttributeError):
return self.spec.loader.get_resource_reader(self.spec.name)
def _native(self):
"""
Return the native reader if it supports files().
"""
reader = self._reader
return reader if hasattr(reader, 'files') else self
def __getattr__(self, attr):
return getattr(self._reader, attr)
def files(self):
return CompatibilityFiles.SpecPath(self.spec, self._reader)
def wrap_spec(package):
"""
Construct a package spec with traversable compatibility
on the spec/loader/reader.
"""
return SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)

View File

@@ -0,0 +1,207 @@
import os
import pathlib
import tempfile
import functools
import contextlib
import types
import importlib
import inspect
import warnings
import itertools
from typing import Union, Optional, cast
from .abc import ResourceReader, Traversable
from ._compat import wrap_spec
Package = Union[types.ModuleType, str]
Anchor = Package
def package_to_anchor(func):
"""
Replace 'package' parameter as 'anchor' and warn about the change.
Other errors should fall through.
>>> files('a', 'b')
Traceback (most recent call last):
TypeError: files() takes from 0 to 1 positional arguments but 2 were given
"""
undefined = object()
@functools.wraps(func)
def wrapper(anchor=undefined, package=undefined):
if package is not undefined:
if anchor is not undefined:
return func(anchor, package)
warnings.warn(
"First parameter to files is renamed to 'anchor'",
DeprecationWarning,
stacklevel=2,
)
return func(package)
elif anchor is undefined:
return func()
return func(anchor)
return wrapper
@package_to_anchor
def files(anchor: Optional[Anchor] = None) -> Traversable:
"""
Get a Traversable resource for an anchor.
"""
return from_package(resolve(anchor))
def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
"""
Return the package's loader if it's a ResourceReader.
"""
# We can't use
# an issubclass() check here because apparently abc.'s __subclasscheck__()
# hook wants to create a weak reference to the object, but
# zipimport.zipimporter does not support weak references, resulting in a
# TypeError. That seems terrible.
spec = package.__spec__
reader = getattr(spec.loader, 'get_resource_reader', None) # type: ignore
if reader is None:
return None
return reader(spec.name) # type: ignore
@functools.singledispatch
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
return cast(types.ModuleType, cand)
@resolve.register
def _(cand: str) -> types.ModuleType:
return importlib.import_module(cand)
@resolve.register
def _(cand: None) -> types.ModuleType:
return resolve(_infer_caller().f_globals['__name__'])
def _infer_caller():
"""
Walk the stack and find the frame of the first caller not in this module.
"""
def is_this_file(frame_info):
return frame_info.filename == __file__
def is_wrapper(frame_info):
return frame_info.function == 'wrapper'
not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
# also exclude 'wrapper' due to singledispatch in the call stack
callers = itertools.filterfalse(is_wrapper, not_this_file)
return next(callers).frame
def from_package(package: types.ModuleType):
"""
Return a Traversable object for the given package.
"""
spec = wrap_spec(package)
reader = spec.loader.get_resource_reader(spec.name)
return reader.files()
@contextlib.contextmanager
def _tempfile(
reader,
suffix='',
# gh-93353: Keep a reference to call os.remove() in late Python
# finalization.
*,
_os_remove=os.remove,
):
# Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
# blocks due to the need to close the temporary file to work on Windows
# properly.
fd, raw_path = tempfile.mkstemp(suffix=suffix)
try:
try:
os.write(fd, reader())
finally:
os.close(fd)
del reader
yield pathlib.Path(raw_path)
finally:
try:
_os_remove(raw_path)
except FileNotFoundError:
pass
def _temp_file(path):
return _tempfile(path.read_bytes, suffix=path.name)
def _is_present_dir(path: Traversable) -> bool:
"""
Some Traversables implement ``is_dir()`` to raise an
exception (i.e. ``FileNotFoundError``) when the
directory doesn't exist. This function wraps that call
to always return a boolean and only return True
if there's a dir and it exists.
"""
with contextlib.suppress(FileNotFoundError):
return path.is_dir()
return False
@functools.singledispatch
def as_file(path):
"""
Given a Traversable object, return that object as a
path on the local file system in a context manager.
"""
return _temp_dir(path) if _is_present_dir(path) else _temp_file(path)
@as_file.register(pathlib.Path)
@contextlib.contextmanager
def _(path):
"""
Degenerate behavior for pathlib.Path objects.
"""
yield path
@contextlib.contextmanager
def _temp_path(dir: tempfile.TemporaryDirectory):
"""
Wrap tempfile.TemporaryDirectory to return a pathlib object.
"""
with dir as result:
yield pathlib.Path(result)
@contextlib.contextmanager
def _temp_dir(path):
"""
Given a traversable dir, recursively replicate the whole tree
to the file system in a context manager.
"""
assert path.is_dir()
with _temp_path(tempfile.TemporaryDirectory()) as temp_dir:
yield _write_contents(temp_dir, path)
def _write_contents(target, source):
child = target.joinpath(source.name)
if source.is_dir():
child.mkdir()
for item in source.iterdir():
_write_contents(child, item)
else:
child.write_bytes(source.read_bytes())
return child
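A sketch of as_file() as defined above: it always yields a real pathlib.Path, extracting to a temporary file only when the resource is not already on disk (the names are hypothetical):

    import importlib_resources as resources

    with resources.as_file(resources.files("mypkg") / "model.bin") as path:
        data = path.read_bytes()  # works the same for zip-backed packages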

View File

@@ -0,0 +1,108 @@
# flake8: noqa
import abc
import os
import sys
import pathlib
from contextlib import suppress
from typing import Union
if sys.version_info >= (3, 10):
from zipfile import Path as ZipPath # type: ignore
else:
from ..zipp import Path as ZipPath # type: ignore
try:
from typing import runtime_checkable # type: ignore
except ImportError:
def runtime_checkable(cls): # type: ignore
return cls
try:
from typing import Protocol # type: ignore
except ImportError:
Protocol = abc.ABC # type: ignore
class TraversableResourcesLoader:
"""
Adapt loaders to provide TraversableResources and other
compatibility.
Used primarily for Python 3.9 and earlier where the native
loaders do not yet implement TraversableResources.
"""
def __init__(self, spec):
self.spec = spec
@property
def path(self):
return self.spec.origin
def get_resource_reader(self, name):
from . import readers, _adapters
def _zip_reader(spec):
with suppress(AttributeError):
return readers.ZipReader(spec.loader, spec.name)
def _namespace_reader(spec):
with suppress(AttributeError, ValueError):
return readers.NamespaceReader(spec.submodule_search_locations)
def _available_reader(spec):
with suppress(AttributeError):
return spec.loader.get_resource_reader(spec.name)
def _native_reader(spec):
reader = _available_reader(spec)
return reader if hasattr(reader, 'files') else None
def _file_reader(spec):
try:
path = pathlib.Path(self.path)
except TypeError:
return None
if path.exists():
return readers.FileReader(self)
return (
# native reader if it supplies 'files'
_native_reader(self.spec)
or
# local ZipReader if a zip module
_zip_reader(self.spec)
or
# local NamespaceReader if a namespace module
_namespace_reader(self.spec)
or
# local FileReader
_file_reader(self.spec)
# fallback - adapt the spec ResourceReader to TraversableReader
or _adapters.CompatibilityFiles(self.spec)
)
def wrap_spec(package):
"""
Construct a package spec with traversable compatibility
on the spec/loader/reader.
Supersedes _adapters.wrap_spec to use TraversableResourcesLoader
from above for older Python compatibility (<3.10).
"""
from . import _adapters
return _adapters.SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader)
if sys.version_info >= (3, 9):
StrPath = Union[str, os.PathLike[str]]
else:
# PathLike is only subscriptable at runtime in 3.9+
StrPath = Union[str, "os.PathLike[str]"]

View File

@@ -0,0 +1,35 @@
from itertools import filterfalse
from typing import (
Callable,
Iterable,
Iterator,
Optional,
Set,
TypeVar,
Union,
)
# Type and type variable definitions
_T = TypeVar('_T')
_U = TypeVar('_U')
def unique_everseen(
iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None
) -> Iterator[_T]:
"List unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') --> A B C D
# unique_everseen('ABBCcAD', str.lower) --> A B C D
seen: Set[Union[_T, _U]] = set()
seen_add = seen.add
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen_add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen_add(k)
yield element
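The recipe above can be checked directly against its doctest-style comments:

    from importlib_resources._itertools import unique_everseen

    assert list(unique_everseen('AAAABBBCCDAABBB')) == ['A', 'B', 'C', 'D']
    assert list(unique_everseen('ABBCcAD', key=str.lower)) == ['A', 'B', 'C', 'D']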

View File

@@ -0,0 +1,120 @@
import functools
import os
import pathlib
import types
import warnings
from typing import Union, Iterable, ContextManager, BinaryIO, TextIO, Any
from . import _common
Package = Union[types.ModuleType, str]
Resource = str
def deprecated(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
warnings.warn(
f"{func.__name__} is deprecated. Use files() instead. "
"Refer to https://importlib-resources.readthedocs.io"
"/en/latest/using.html#migrating-from-legacy for migration advice.",
DeprecationWarning,
stacklevel=2,
)
return func(*args, **kwargs)
return wrapper
def normalize_path(path: Any) -> str:
"""Normalize a path by ensuring it is a string.
If the resulting string contains path separators, an exception is raised.
"""
str_path = str(path)
parent, file_name = os.path.split(str_path)
if parent:
raise ValueError(f'{path!r} must be only a file name')
return file_name
@deprecated
def open_binary(package: Package, resource: Resource) -> BinaryIO:
"""Return a file-like object opened for binary reading of the resource."""
return (_common.files(package) / normalize_path(resource)).open('rb')
@deprecated
def read_binary(package: Package, resource: Resource) -> bytes:
"""Return the binary contents of the resource."""
return (_common.files(package) / normalize_path(resource)).read_bytes()
@deprecated
def open_text(
package: Package,
resource: Resource,
encoding: str = 'utf-8',
errors: str = 'strict',
) -> TextIO:
"""Return a file-like object opened for text reading of the resource."""
return (_common.files(package) / normalize_path(resource)).open(
'r', encoding=encoding, errors=errors
)
@deprecated
def read_text(
package: Package,
resource: Resource,
encoding: str = 'utf-8',
errors: str = 'strict',
) -> str:
"""Return the decoded string of the resource.
The decoding-related arguments have the same semantics as those of
bytes.decode().
"""
with open_text(package, resource, encoding, errors) as fp:
return fp.read()
@deprecated
def contents(package: Package) -> Iterable[str]:
"""Return an iterable of entries in `package`.
Note that not all entries are resources. Specifically, directories are
not considered resources. Use `is_resource()` on each entry returned here
to check if it is a resource or not.
"""
return [path.name for path in _common.files(package).iterdir()]
@deprecated
def is_resource(package: Package, name: str) -> bool:
"""True if `name` is a resource inside `package`.
Directories are *not* resources.
"""
resource = normalize_path(name)
return any(
traversable.name == resource and traversable.is_file()
for traversable in _common.files(package).iterdir()
)
@deprecated
def path(
package: Package,
resource: Resource,
) -> ContextManager[pathlib.Path]:
"""A context manager providing a file path object to the resource.
If the resource does not already exist on its own on the file system,
a temporary file will be created. If the file was created, the file
will be deleted upon exiting the context manager (no exception is
raised if the file was deleted prior to the context manager
exiting).
"""
return _common.as_file(_common.files(package) / normalize_path(resource))
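A migration sketch matching the deprecation warning emitted above (the package and file names are hypothetical):

    import importlib_resources as resources

    # Deprecated: resources.read_text("mypkg", "data.txt")
    # Preferred:
    text = resources.files("mypkg").joinpath("data.txt").read_text(encoding="utf-8")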

View File

@@ -0,0 +1,170 @@
import abc
import io
import itertools
import pathlib
from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional
from ._compat import runtime_checkable, Protocol, StrPath
__all__ = ["ResourceReader", "Traversable", "TraversableResources"]
class ResourceReader(metaclass=abc.ABCMeta):
"""Abstract base class for loaders to provide resource reading support."""
@abc.abstractmethod
def open_resource(self, resource: Text) -> BinaryIO:
"""Return an opened, file-like object for binary reading.
The 'resource' argument is expected to represent only a file name.
If the resource cannot be found, FileNotFoundError is raised.
"""
# This deliberately raises FileNotFoundError instead of
# NotImplementedError so that if this method is accidentally called,
# it'll still do the right thing.
raise FileNotFoundError
@abc.abstractmethod
def resource_path(self, resource: Text) -> Text:
"""Return the file system path to the specified resource.
The 'resource' argument is expected to represent only a file name.
If the resource does not exist on the file system, raise
FileNotFoundError.
"""
# This deliberately raises FileNotFoundError instead of
# NotImplementedError so that if this method is accidentally called,
# it'll still do the right thing.
raise FileNotFoundError
@abc.abstractmethod
def is_resource(self, path: Text) -> bool:
"""Return True if the named 'path' is a resource.
Files are resources, directories are not.
"""
raise FileNotFoundError
@abc.abstractmethod
def contents(self) -> Iterable[str]:
"""Return an iterable of entries in `package`."""
raise FileNotFoundError
class TraversalError(Exception):
pass
@runtime_checkable
class Traversable(Protocol):
"""
An object with a subset of pathlib.Path methods suitable for
traversing directories and opening files.
Any exceptions that occur when accessing the backing resource
may propagate unaltered.
"""
@abc.abstractmethod
def iterdir(self) -> Iterator["Traversable"]:
"""
Yield Traversable objects in self
"""
def read_bytes(self) -> bytes:
"""
Read contents of self as bytes
"""
with self.open('rb') as strm:
return strm.read()
def read_text(self, encoding: Optional[str] = None) -> str:
"""
Read contents of self as text
"""
with self.open(encoding=encoding) as strm:
return strm.read()
@abc.abstractmethod
def is_dir(self) -> bool:
"""
Return True if self is a directory
"""
@abc.abstractmethod
def is_file(self) -> bool:
"""
Return True if self is a file
"""
def joinpath(self, *descendants: StrPath) -> "Traversable":
"""
Return Traversable resolved with any descendants applied.
Each descendant should be a path segment relative to self
and each may contain multiple levels separated by
``posixpath.sep`` (``/``).
"""
if not descendants:
return self
names = itertools.chain.from_iterable(
path.parts for path in map(pathlib.PurePosixPath, descendants)
)
target = next(names)
matches = (
traversable for traversable in self.iterdir() if traversable.name == target
)
try:
match = next(matches)
except StopIteration:
raise TraversalError(
"Target not found during traversal.", target, list(names)
)
return match.joinpath(*names)
def __truediv__(self, child: StrPath) -> "Traversable":
"""
Return Traversable child in self
"""
return self.joinpath(child)
@abc.abstractmethod
def open(self, mode='r', *args, **kwargs):
"""
mode may be 'r' or 'rb' to open as text or binary. Return a handle
suitable for reading (same as pathlib.Path.open).
When opening as text, accepts encoding parameters such as those
accepted by io.TextIOWrapper.
"""
@property
@abc.abstractmethod
def name(self) -> str:
"""
The base name of this object without any parent references.
"""
class TraversableResources(ResourceReader):
"""
The required interface for providing traversable
resources.
"""
@abc.abstractmethod
def files(self) -> "Traversable":
"""Return a Traversable object for the loaded package."""
def open_resource(self, resource: StrPath) -> io.BufferedReader:
return self.files().joinpath(resource).open('rb')
def resource_path(self, resource: Any) -> NoReturn:
raise FileNotFoundError(resource)
def is_resource(self, path: StrPath) -> bool:
return self.files().joinpath(path).is_file()
def contents(self) -> Iterator[str]:
return (item.name for item in self.files().iterdir())
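Per the joinpath() contract above, multiple segments and '/'-separated strings traverse identically (the package name is hypothetical):

    import importlib_resources as resources

    root = resources.files("mypkg")
    a = root.joinpath("data", "sub", "file.txt")
    b = root / "data/sub/file.txt"  # resolves to the same Traversable as `a`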

View File

@@ -0,0 +1,120 @@
import collections
import pathlib
import operator
from . import abc
from ._itertools import unique_everseen
from ._compat import ZipPath
def remove_duplicates(items):
return iter(collections.OrderedDict.fromkeys(items))
class FileReader(abc.TraversableResources):
def __init__(self, loader):
self.path = pathlib.Path(loader.path).parent
def resource_path(self, resource):
"""
Return the file system path to prevent
`resources.path()` from creating a temporary
copy.
"""
return str(self.path.joinpath(resource))
def files(self):
return self.path
class ZipReader(abc.TraversableResources):
def __init__(self, loader, module):
_, _, name = module.rpartition('.')
self.prefix = loader.prefix.replace('\\', '/') + name + '/'
self.archive = loader.archive
def open_resource(self, resource):
try:
return super().open_resource(resource)
except KeyError as exc:
raise FileNotFoundError(exc.args[0])
def is_resource(self, path):
# workaround for `zipfile.Path.is_file` returning true
# for non-existent paths.
target = self.files().joinpath(path)
return target.is_file() and target.exists()
def files(self):
return ZipPath(self.archive, self.prefix)
class MultiplexedPath(abc.Traversable):
"""
Given a series of Traversable objects, implement a merged
version of the interface across all objects. Useful for
namespace packages which may be multihomed at a single
name.
"""
def __init__(self, *paths):
self._paths = list(map(pathlib.Path, remove_duplicates(paths)))
if not self._paths:
message = 'MultiplexedPath must contain at least one path'
raise FileNotFoundError(message)
if not all(path.is_dir() for path in self._paths):
raise NotADirectoryError('MultiplexedPath only supports directories')
def iterdir(self):
files = (file for path in self._paths for file in path.iterdir())
return unique_everseen(files, key=operator.attrgetter('name'))
def read_bytes(self):
raise FileNotFoundError(f'{self} is not a file')
def read_text(self, *args, **kwargs):
raise FileNotFoundError(f'{self} is not a file')
def is_dir(self):
return True
def is_file(self):
return False
def joinpath(self, *descendants):
try:
return super().joinpath(*descendants)
except abc.TraversalError:
# One of the paths did not resolve (a directory does not exist).
# Just return something that will not exist.
return self._paths[0].joinpath(*descendants)
def open(self, *args, **kwargs):
raise FileNotFoundError(f'{self} is not a file')
@property
def name(self):
return self._paths[0].name
def __repr__(self):
paths = ', '.join(f"'{path}'" for path in self._paths)
return f'MultiplexedPath({paths})'
class NamespaceReader(abc.TraversableResources):
def __init__(self, namespace_path):
if 'NamespacePath' not in str(namespace_path):
raise ValueError('Invalid path')
self.path = MultiplexedPath(*list(namespace_path))
def resource_path(self, resource):
"""
Return the file system path to prevent
`resources.path()` from creating a temporary
copy.
"""
return str(self.path.joinpath(resource))
def files(self):
return self.path
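A hedged sketch of MultiplexedPath as a merged, read-only view over several directories, e.g. for a namespace package split across locations (the paths are hypothetical and must already exist):

    from importlib_resources.readers import MultiplexedPath

    merged = MultiplexedPath("/venv/site-packages/ns_pkg", "/src/ns_pkg")
    names = [entry.name for entry in merged.iterdir()]  # de-duplicated union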

View File

@@ -0,0 +1,106 @@
"""
Interface adapters for low-level readers.
"""
import abc
import io
import itertools
from typing import BinaryIO, List
from .abc import Traversable, TraversableResources
class SimpleReader(abc.ABC):
"""
The minimum, low-level interface required from a resource
provider.
"""
@property
@abc.abstractmethod
def package(self) -> str:
"""
The name of the package for which this reader loads resources.
"""
@abc.abstractmethod
def children(self) -> List['SimpleReader']:
"""
Obtain an iterable of SimpleReader for available
child containers (e.g. directories).
"""
@abc.abstractmethod
def resources(self) -> List[str]:
"""
Obtain available named resources for this virtual package.
"""
@abc.abstractmethod
def open_binary(self, resource: str) -> BinaryIO:
"""
Obtain a File-like for a named resource.
"""
@property
def name(self):
return self.package.split('.')[-1]
class ResourceContainer(Traversable):
"""
Traversable container for a package's resources via its reader.
"""
def __init__(self, reader: SimpleReader):
self.reader = reader
def is_dir(self):
return True
def is_file(self):
return False
def iterdir(self):
files = (ResourceHandle(self, name) for name in self.reader.resources)
dirs = map(ResourceContainer, self.reader.children())
return itertools.chain(files, dirs)
def open(self, *args, **kwargs):
raise IsADirectoryError()
class ResourceHandle(Traversable):
"""
Handle to a named resource in a ResourceReader.
"""
def __init__(self, parent: ResourceContainer, name: str):
self.parent = parent
self.name = name # type: ignore
def is_file(self):
return True
def is_dir(self):
return False
def open(self, mode='r', *args, **kwargs):
stream = self.parent.reader.open_binary(self.name)
if 'b' not in mode:
stream = io.TextIOWrapper(stream, *args, **kwargs)  # wrap the binary stream for text access
return stream
def joinpath(self, name):
raise RuntimeError("Cannot traverse into a resource")
class TraversableReader(TraversableResources, SimpleReader):
"""
A TraversableResources based on SimpleReader. Resource providers
may derive from this class to provide the TraversableResources
interface by supplying the SimpleReader interface.
"""
def files(self):
return ResourceContainer(self)

View File

@@ -0,0 +1,32 @@
import os
try:
from test.support import import_helper # type: ignore
except ImportError:
# Python 3.9 and earlier
class import_helper: # type: ignore
from test.support import (
modules_setup,
modules_cleanup,
DirsOnSysPath,
CleanImport,
)
try:
from test.support import os_helper # type: ignore
except ImportError:
# Python 3.9 compat
class os_helper: # type:ignore
from test.support import temp_dir
try:
# Python 3.10
from test.support.os_helper import unlink
except ImportError:
from test.support import unlink as _unlink
def unlink(target):
return _unlink(os.fspath(target))

View File

@@ -0,0 +1,50 @@
import pathlib
import functools
####
# from jaraco.path 3.4
def build(spec, prefix=pathlib.Path()):
"""
Build a set of files/directories, as described by the spec.
Each key represents a pathname, and the value represents
the content. Content may be a nested directory.
>>> spec = {
... 'README.txt': "A README file",
... "foo": {
... "__init__.py": "",
... "bar": {
... "__init__.py": "",
... },
... "baz.py": "# Some code",
... }
... }
>>> tmpdir = getfixture('tmpdir')
>>> build(spec, tmpdir)
"""
for name, contents in spec.items():
create(contents, pathlib.Path(prefix) / name)
@functools.singledispatch
def create(content, path):
path.mkdir(exist_ok=True)
build(content, prefix=path) # type: ignore
@create.register
def _(content: bytes, path):
path.write_bytes(content)
@create.register
def _(content: str, path):
path.write_text(content)
# end from jaraco.path
####

View File

@@ -0,0 +1 @@
one resource

View File

@@ -0,0 +1 @@
two resource

View File

@@ -0,0 +1,102 @@
import io
import unittest
import importlib_resources as resources
from importlib_resources._adapters import (
CompatibilityFiles,
wrap_spec,
)
from . import util
class CompatibilityFilesTests(unittest.TestCase):
@property
def package(self):
bytes_data = io.BytesIO(b'Hello, world!')
return util.create_package(
file=bytes_data,
path='some_path',
contents=('a', 'b', 'c'),
)
@property
def files(self):
return resources.files(self.package)
def test_spec_path_iter(self):
self.assertEqual(
sorted(path.name for path in self.files.iterdir()),
['a', 'b', 'c'],
)
def test_child_path_iter(self):
self.assertEqual(list((self.files / 'a').iterdir()), [])
def test_orphan_path_iter(self):
self.assertEqual(list((self.files / 'a' / 'a').iterdir()), [])
self.assertEqual(list((self.files / 'a' / 'a' / 'a').iterdir()), [])
def test_spec_path_is(self):
self.assertFalse(self.files.is_file())
self.assertFalse(self.files.is_dir())
def test_child_path_is(self):
self.assertTrue((self.files / 'a').is_file())
self.assertFalse((self.files / 'a').is_dir())
def test_orphan_path_is(self):
self.assertFalse((self.files / 'a' / 'a').is_file())
self.assertFalse((self.files / 'a' / 'a').is_dir())
self.assertFalse((self.files / 'a' / 'a' / 'a').is_file())
self.assertFalse((self.files / 'a' / 'a' / 'a').is_dir())
def test_spec_path_name(self):
self.assertEqual(self.files.name, 'testingpackage')
def test_child_path_name(self):
self.assertEqual((self.files / 'a').name, 'a')
def test_orphan_path_name(self):
self.assertEqual((self.files / 'a' / 'b').name, 'b')
self.assertEqual((self.files / 'a' / 'b' / 'c').name, 'c')
def test_spec_path_open(self):
self.assertEqual(self.files.read_bytes(), b'Hello, world!')
self.assertEqual(self.files.read_text(), 'Hello, world!')
def test_child_path_open(self):
self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!')
self.assertEqual((self.files / 'a').read_text(), 'Hello, world!')
def test_orphan_path_open(self):
with self.assertRaises(FileNotFoundError):
(self.files / 'a' / 'b').read_bytes()
with self.assertRaises(FileNotFoundError):
(self.files / 'a' / 'b' / 'c').read_bytes()
def test_open_invalid_mode(self):
with self.assertRaises(ValueError):
self.files.open('0')
def test_orphan_path_invalid(self):
with self.assertRaises(ValueError):
CompatibilityFiles.OrphanPath()
def test_wrap_spec(self):
spec = wrap_spec(self.package)
self.assertIsInstance(spec.loader.get_resource_reader(None), CompatibilityFiles)
class CompatibilityFilesNoReaderTests(unittest.TestCase):
@property
def package(self):
return util.create_package_from_loader(None)
@property
def files(self):
return resources.files(self.package)
def test_spec_path_joinpath(self):
self.assertIsInstance(self.files / 'a', CompatibilityFiles.OrphanPath)

View File

@@ -0,0 +1,43 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
class ContentsTests:
expected = {
'__init__.py',
'binary.file',
'subdirectory',
'utf-16.file',
'utf-8.file',
}
def test_contents(self):
contents = {path.name for path in resources.files(self.data).iterdir()}
assert self.expected <= contents
class ContentsDiskTests(ContentsTests, unittest.TestCase):
def setUp(self):
self.data = data01
class ContentsZipTests(ContentsTests, util.ZipSetup, unittest.TestCase):
pass
class ContentsNamespaceTests(ContentsTests, unittest.TestCase):
expected = {
# no __init__ because of namespace design
# no subdirectory as incidental difference in fixture
'binary.file',
'utf-16.file',
'utf-8.file',
}
def setUp(self):
from . import namespacedata01
self.data = namespacedata01

View File

@@ -0,0 +1,112 @@
import typing
import textwrap
import unittest
import warnings
import importlib
import contextlib
import importlib_resources as resources
from ..abc import Traversable
from . import data01
from . import util
from . import _path
from ._compat import os_helper, import_helper
@contextlib.contextmanager
def suppress_known_deprecation():
with warnings.catch_warnings(record=True) as ctx:
warnings.simplefilter('default', category=DeprecationWarning)
yield ctx
class FilesTests:
def test_read_bytes(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_bytes()
assert actual == b'Hello, UTF-8 world!\n'
def test_read_text(self):
files = resources.files(self.data)
actual = files.joinpath('utf-8.file').read_text(encoding='utf-8')
assert actual == 'Hello, UTF-8 world!\n'
@unittest.skipUnless(
hasattr(typing, 'runtime_checkable'),
"Only suitable when typing supports runtime_checkable",
)
def test_traversable(self):
assert isinstance(resources.files(self.data), Traversable)
def test_old_parameter(self):
"""
Files used to take a 'package' parameter. Make sure anyone
passing by name is still supported.
"""
with suppress_known_deprecation():
resources.files(package=self.data)
class OpenDiskTests(FilesTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase):
pass
class OpenNamespaceTests(FilesTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class SiteDir:
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(os_helper.temp_dir())
self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir))
self.fixtures.enter_context(import_helper.CleanImport())
class ModulesFilesTests(SiteDir, unittest.TestCase):
def test_module_resources(self):
"""
A module can have resources found adjacent to the module.
"""
spec = {
'mod.py': '',
'res.txt': 'resources are the best',
}
_path.build(spec, self.site_dir)
import mod
actual = resources.files(mod).joinpath('res.txt').read_text()
assert actual == spec['res.txt']
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
def test_implicit_files(self):
"""
Without any parameter, files() will infer the location as the caller.
"""
spec = {
'somepkg': {
'__init__.py': textwrap.dedent(
"""
import importlib_resources as res
val = res.files().joinpath('res.txt').read_text()
"""
),
'res.txt': 'resources are the best',
},
}
_path.build(spec, self.site_dir)
assert importlib.import_module('somepkg').val == 'resources are the best'
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,81 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
target = resources.files(package).joinpath(path)
with target.open('rb'):
pass
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
target = resources.files(package).joinpath(path)
with target.open():
pass
class OpenTests:
def test_open_binary(self):
target = resources.files(self.data) / 'binary.file'
with target.open('rb') as fp:
result = fp.read()
self.assertEqual(result, b'\x00\x01\x02\x03')
def test_open_text_default_encoding(self):
target = resources.files(self.data) / 'utf-8.file'
with target.open() as fp:
result = fp.read()
self.assertEqual(result, 'Hello, UTF-8 world!\n')
def test_open_text_given_encoding(self):
target = resources.files(self.data) / 'utf-16.file'
with target.open(encoding='utf-16', errors='strict') as fp:
result = fp.read()
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_open_text_with_errors(self):
# Raises UnicodeError without the 'errors' argument.
target = resources.files(self.data) / 'utf-16.file'
with target.open(encoding='utf-8', errors='strict') as fp:
self.assertRaises(UnicodeError, fp.read)
with target.open(encoding='utf-8', errors='ignore') as fp:
result = fp.read()
self.assertEqual(
result,
'H\x00e\x00l\x00l\x00o\x00,\x00 '
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
)
def test_open_binary_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
self.assertRaises(FileNotFoundError, target.open, 'rb')
def test_open_text_FileNotFoundError(self):
target = resources.files(self.data) / 'does-not-exist'
self.assertRaises(FileNotFoundError, target.open)
class OpenDiskTests(OpenTests, unittest.TestCase):
def setUp(self):
self.data = data01
class OpenDiskNamespaceTests(OpenTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase):
pass
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,64 @@
import io
import unittest
import importlib_resources as resources
from . import data01
from . import util
class CommonTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
with resources.as_file(resources.files(package).joinpath(path)):
pass
class PathTests:
def test_reading(self):
# Path should be readable.
# Test also implicitly verifies the returned object is a pathlib.Path
# instance.
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
self.assertTrue(path.name.endswith("utf-8.file"), repr(path))
# pathlib.Path.read_text() was introduced in Python 3.5.
with path.open('r', encoding='utf-8') as file:
text = file.read()
self.assertEqual('Hello, UTF-8 world!\n', text)
class PathDiskTests(PathTests, unittest.TestCase):
data = data01
def test_natural_path(self):
"""
Guarantee the internal implementation detail that
file-system-backed resources do not get the tempdir
treatment.
"""
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
assert 'data' in str(path)
class PathMemoryTests(PathTests, unittest.TestCase):
def setUp(self):
file = io.BytesIO(b'Hello, UTF-8 world!\n')
self.addCleanup(file.close)
self.data = util.create_package(
file=file, path=FileNotFoundError("package exists only in memory")
)
self.data.__spec__.origin = None
self.data.__spec__.has_location = False
class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase):
def test_remove_in_context_manager(self):
# It is not an error if the file that was temporarily stashed on the
# file system is removed inside the `with` stanza.
target = resources.files(self.data) / 'utf-8.file'
with resources.as_file(target) as path:
path.unlink()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,76 @@
import unittest
import importlib_resources as resources
from . import data01
from . import util
from importlib import import_module
class CommonBinaryTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
resources.files(package).joinpath(path).read_bytes()
class CommonTextTests(util.CommonTests, unittest.TestCase):
def execute(self, package, path):
resources.files(package).joinpath(path).read_text()
class ReadTests:
def test_read_bytes(self):
result = resources.files(self.data).joinpath('binary.file').read_bytes()
self.assertEqual(result, b'\0\1\2\3')
def test_read_text_default_encoding(self):
result = resources.files(self.data).joinpath('utf-8.file').read_text()
self.assertEqual(result, 'Hello, UTF-8 world!\n')
def test_read_text_given_encoding(self):
result = (
resources.files(self.data)
.joinpath('utf-16.file')
.read_text(encoding='utf-16')
)
self.assertEqual(result, 'Hello, UTF-16 world!\n')
def test_read_text_with_errors(self):
# Raises UnicodeError without the 'errors' argument.
target = resources.files(self.data) / 'utf-16.file'
self.assertRaises(UnicodeError, target.read_text, encoding='utf-8')
result = target.read_text(encoding='utf-8', errors='ignore')
self.assertEqual(
result,
'H\x00e\x00l\x00l\x00o\x00,\x00 '
'\x00U\x00T\x00F\x00-\x001\x006\x00 '
'\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00',
)
class ReadDiskTests(ReadTests, unittest.TestCase):
data = data01
class ReadZipTests(ReadTests, util.ZipSetup, unittest.TestCase):
def test_read_submodule_resource(self):
submodule = import_module('ziptestdata.subdirectory')
result = resources.files(submodule).joinpath('binary.file').read_bytes()
self.assertEqual(result, b'\0\1\2\3')
def test_read_submodule_resource_by_name(self):
result = (
resources.files('ziptestdata.subdirectory')
.joinpath('binary.file')
.read_bytes()
)
self.assertEqual(result, b'\0\1\2\3')
class ReadNamespaceTests(ReadTests, unittest.TestCase):
def setUp(self):
from . import namespacedata01
self.data = namespacedata01
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,133 @@
import os.path
import sys
import pathlib
import unittest
from importlib import import_module
from importlib_resources.readers import MultiplexedPath, NamespaceReader
class MultiplexedPathTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
path = pathlib.Path(__file__).parent / 'namespacedata01'
cls.folder = str(path)
def test_init_no_paths(self):
with self.assertRaises(FileNotFoundError):
MultiplexedPath()
def test_init_file(self):
with self.assertRaises(NotADirectoryError):
MultiplexedPath(os.path.join(self.folder, 'binary.file'))
def test_iterdir(self):
contents = {path.name for path in MultiplexedPath(self.folder).iterdir()}
try:
contents.remove('__pycache__')
except (KeyError, ValueError):
pass
self.assertEqual(contents, {'binary.file', 'utf-16.file', 'utf-8.file'})
def test_iterdir_duplicate(self):
data01 = os.path.abspath(os.path.join(__file__, '..', 'data01'))
contents = {
path.name for path in MultiplexedPath(self.folder, data01).iterdir()
}
for remove in ('__pycache__', '__init__.pyc'):
try:
contents.remove(remove)
except (KeyError, ValueError):
pass
self.assertEqual(
contents,
{'__init__.py', 'binary.file', 'subdirectory', 'utf-16.file', 'utf-8.file'},
)
def test_is_dir(self):
self.assertEqual(MultiplexedPath(self.folder).is_dir(), True)
def test_is_file(self):
self.assertEqual(MultiplexedPath(self.folder).is_file(), False)
def test_open_file(self):
path = MultiplexedPath(self.folder)
with self.assertRaises(FileNotFoundError):
path.read_bytes()
with self.assertRaises(FileNotFoundError):
path.read_text()
with self.assertRaises(FileNotFoundError):
path.open()
def test_join_path(self):
prefix = os.path.abspath(os.path.join(__file__, '..'))
data01 = os.path.join(prefix, 'data01')
path = MultiplexedPath(self.folder, data01)
self.assertEqual(
str(path.joinpath('binary.file'))[len(prefix) + 1 :],
os.path.join('namespacedata01', 'binary.file'),
)
self.assertEqual(
str(path.joinpath('subdirectory'))[len(prefix) + 1 :],
os.path.join('data01', 'subdirectory'),
)
self.assertEqual(
str(path.joinpath('imaginary'))[len(prefix) + 1 :],
os.path.join('namespacedata01', 'imaginary'),
)
self.assertEqual(path.joinpath(), path)
def test_join_path_compound(self):
path = MultiplexedPath(self.folder)
assert not path.joinpath('imaginary/foo.py').exists()
def test_repr(self):
self.assertEqual(
repr(MultiplexedPath(self.folder)),
f"MultiplexedPath('{self.folder}')",
)
def test_name(self):
self.assertEqual(
MultiplexedPath(self.folder).name,
os.path.basename(self.folder),
)
class NamespaceReaderTest(unittest.TestCase):
site_dir = str(pathlib.Path(__file__).parent)
@classmethod
def setUpClass(cls):
sys.path.append(cls.site_dir)
@classmethod
def tearDownClass(cls):
sys.path.remove(cls.site_dir)
def test_init_error(self):
with self.assertRaises(ValueError):
NamespaceReader(['path1', 'path2'])
def test_resource_path(self):
namespacedata01 = import_module('namespacedata01')
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
self.assertEqual(
reader.resource_path('binary.file'), os.path.join(root, 'binary.file')
)
self.assertEqual(
reader.resource_path('imaginary'), os.path.join(root, 'imaginary')
)
def test_files(self):
namespacedata01 = import_module('namespacedata01')
reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations)
root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01'))
self.assertIsInstance(reader.files(), MultiplexedPath)
self.assertEqual(repr(reader.files()), f"MultiplexedPath('{root}')")
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,260 @@
import sys
import unittest
import importlib_resources as resources
import uuid
import pathlib
from . import data01
from . import zipdata01, zipdata02
from . import util
from importlib import import_module
from ._compat import import_helper, unlink
class ResourceTests:
# Subclasses are expected to set the `data` attribute.
def test_is_file_exists(self):
target = resources.files(self.data) / 'binary.file'
self.assertTrue(target.is_file())
def test_is_file_missing(self):
target = resources.files(self.data) / 'not-a-file'
self.assertFalse(target.is_file())
def test_is_dir(self):
target = resources.files(self.data) / 'subdirectory'
self.assertFalse(target.is_file())
self.assertTrue(target.is_dir())
class ResourceDiskTests(ResourceTests, unittest.TestCase):
def setUp(self):
self.data = data01
class ResourceZipTests(ResourceTests, util.ZipSetup, unittest.TestCase):
pass
def names(traversable):
return {item.name for item in traversable.iterdir()}
class ResourceLoaderTests(unittest.TestCase):
def test_resource_contents(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
)
self.assertEqual(names(resources.files(package)), {'A', 'B', 'C'})
def test_is_file(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertTrue(resources.files(package).joinpath('B').is_file())
def test_is_dir(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertTrue(resources.files(package).joinpath('D').is_dir())
def test_resource_missing(self):
package = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
)
self.assertFalse(resources.files(package).joinpath('Z').is_file())
class ResourceCornerCaseTests(unittest.TestCase):
def test_package_has_no_reader_fallback(self):
# Test odd ball packages which:
# 1. Do not have a ResourceReader as a loader
# 2. Are not on the file system
# 3. Are not in a zip file
module = util.create_package(
file=data01, path=data01.__file__, contents=['A', 'B', 'C']
)
# Give the module a dummy loader.
module.__loader__ = object()
# Give the module a dummy origin.
module.__file__ = '/path/which/shall/not/be/named'
module.__spec__.loader = module.__loader__
module.__spec__.origin = module.__file__
self.assertFalse(resources.files(module).joinpath('A').is_file())
class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase):
ZIP_MODULE = zipdata01 # type: ignore
def test_is_submodule_resource(self):
submodule = import_module('ziptestdata.subdirectory')
self.assertTrue(resources.files(submodule).joinpath('binary.file').is_file())
def test_read_submodule_resource_by_name(self):
self.assertTrue(
resources.files('ziptestdata.subdirectory')
.joinpath('binary.file')
.is_file()
)
def test_submodule_contents(self):
submodule = import_module('ziptestdata.subdirectory')
self.assertEqual(
names(resources.files(submodule)), {'__init__.py', 'binary.file'}
)
def test_submodule_contents_by_name(self):
self.assertEqual(
names(resources.files('ziptestdata.subdirectory')),
{'__init__.py', 'binary.file'},
)
def test_as_file_directory(self):
with resources.as_file(resources.files('ziptestdata')) as data:
assert data.name == 'ziptestdata'
assert data.is_dir()
assert data.joinpath('subdirectory').is_dir()
assert len(list(data.iterdir()))
assert not data.parent.exists()
class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
ZIP_MODULE = zipdata02 # type: ignore
def test_unrelated_contents(self):
"""
Test that a zip with two unrelated subpackages returns
distinct resources. Ref python/importlib_resources#44.
"""
self.assertEqual(
names(resources.files('ziptestdata.one')),
{'__init__.py', 'resource1.txt'},
)
self.assertEqual(
names(resources.files('ziptestdata.two')),
{'__init__.py', 'resource2.txt'},
)
class DeletingZipsTest(unittest.TestCase):
"""Having accessed resources in a zip file should not keep an open
reference to the zip.
"""
ZIP_MODULE = zipdata01
def setUp(self):
modules = import_helper.modules_setup()
self.addCleanup(import_helper.modules_cleanup, *modules)
data_path = pathlib.Path(self.ZIP_MODULE.__file__)
data_dir = data_path.parent
self.source_zip_path = data_dir / 'ziptestdata.zip'
self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute()
self.zip_path.write_bytes(self.source_zip_path.read_bytes())
sys.path.append(str(self.zip_path))
self.data = import_module('ziptestdata')
def tearDown(self):
try:
sys.path.remove(str(self.zip_path))
except ValueError:
pass
try:
del sys.path_importer_cache[str(self.zip_path)]
del sys.modules[self.data.__name__]
except KeyError:
pass
try:
unlink(self.zip_path)
except OSError:
# If the test fails, this will probably fail too
pass
def test_iterdir_does_not_keep_open(self):
c = [item.name for item in resources.files('ziptestdata').iterdir()]
self.zip_path.unlink()
del c
def test_is_file_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('binary.file').is_file()
self.zip_path.unlink()
del c
def test_is_file_failure_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('not-present').is_file()
self.zip_path.unlink()
del c
@unittest.skip("Desired but not supported.")
def test_as_file_does_not_keep_open(self): # pragma: no cover
c = resources.as_file(resources.files('ziptestdata') / 'binary.file')
self.zip_path.unlink()
del c
def test_entered_path_does_not_keep_open(self):
# This is what certifi does on import to make its bundle
# available for the process duration.
c = resources.as_file(
resources.files('ziptestdata') / 'binary.file'
).__enter__()
self.zip_path.unlink()
del c
def test_read_binary_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('binary.file').read_bytes()
self.zip_path.unlink()
del c
def test_read_text_does_not_keep_open(self):
c = resources.files('ziptestdata').joinpath('utf-8.file').read_text()
self.zip_path.unlink()
del c
class ResourceFromNamespaceTest01(unittest.TestCase):
site_dir = str(pathlib.Path(__file__).parent)
@classmethod
def setUpClass(cls):
sys.path.append(cls.site_dir)
@classmethod
def tearDownClass(cls):
sys.path.remove(cls.site_dir)
def test_is_submodule_resource(self):
self.assertTrue(
resources.files(import_module('namespacedata01'))
.joinpath('binary.file')
.is_file()
)
def test_read_submodule_resource_by_name(self):
self.assertTrue(
resources.files('namespacedata01').joinpath('binary.file').is_file()
)
def test_submodule_contents(self):
contents = names(resources.files(import_module('namespacedata01')))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})
def test_submodule_contents_by_name(self):
contents = names(resources.files('namespacedata01'))
try:
contents.remove('__pycache__')
except KeyError:
pass
self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,53 @@
"""
Generate the zip test data files.
Run to build the tests/zipdataNN/ziptestdata.zip files from
files in tests/dataNN.
Replaces the file with the working copy, but does not commit anything
to the source repo.
"""
import contextlib
import os
import pathlib
import zipfile
def main():
"""
>>> from unittest import mock
>>> monkeypatch = getfixture('monkeypatch')
>>> monkeypatch.setattr(zipfile, 'ZipFile', mock.MagicMock())
>>> print(); main() # print workaround for bpo-32509
<BLANKLINE>
...data01... -> ziptestdata/...
...
...data02... -> ziptestdata/...
...
"""
suffixes = '01', '02'
tuple(map(generate, suffixes))
def generate(suffix):
root = pathlib.Path(__file__).parent.relative_to(os.getcwd())
zfpath = root / f'zipdata{suffix}/ziptestdata.zip'
with zipfile.ZipFile(zfpath, 'w') as zf:
for src, rel in walk(root / f'data{suffix}'):
dst = 'ziptestdata' / pathlib.PurePosixPath(rel.as_posix())
print(src, '->', dst)
zf.write(src, dst)
def walk(datapath):
for dirpath, dirnames, filenames in os.walk(datapath):
with contextlib.suppress(ValueError):
dirnames.remove('__pycache__')
for filename in filenames:
res = pathlib.Path(dirpath) / filename
rel = res.relative_to(datapath)
yield res, rel
__name__ == '__main__' and main()

View File

@@ -0,0 +1,167 @@
import abc
import importlib
import io
import sys
import types
import pathlib
from . import data01
from . import zipdata01
from ..abc import ResourceReader
from ._compat import import_helper
from importlib.machinery import ModuleSpec
class Reader(ResourceReader):
def __init__(self, **kwargs):
vars(self).update(kwargs)
def get_resource_reader(self, package):
return self
def open_resource(self, path):
self._path = path
if isinstance(self.file, Exception):
raise self.file
return self.file
def resource_path(self, path_):
self._path = path_
if isinstance(self.path, Exception):
raise self.path
return self.path
def is_resource(self, path_):
self._path = path_
if isinstance(self.path, Exception):
raise self.path
def part(entry):
return entry.split('/')
return any(
len(parts) == 1 and parts[0] == path_ for parts in map(part, self._contents)
)
def contents(self):
if isinstance(self.path, Exception):
raise self.path
yield from self._contents
def create_package_from_loader(loader, is_package=True):
name = 'testingpackage'
module = types.ModuleType(name)
spec = ModuleSpec(name, loader, origin='does-not-exist', is_package=is_package)
module.__spec__ = spec
module.__loader__ = loader
return module
def create_package(file=None, path=None, is_package=True, contents=()):
return create_package_from_loader(
Reader(file=file, path=path, _contents=contents),
is_package,
)
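def _demo_create_package():
    # Editor's sketch (not part of the vendored tests): fabricate an
    # in-memory package whose lone resource is served from a BytesIO,
    # mirroring the fallback that test_missing_path below relies on.
    import importlib_resources as resources

    fake = create_package(
        file=io.BytesIO(b'hello'),  # stream handed back by open_resource
        path=FileNotFoundError(),   # resource_path fails: nothing on disk
        contents=['greeting.txt'],
    )
    assert resources.files(fake).joinpath('greeting.txt').read_bytes() == b'hello'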
class CommonTests(metaclass=abc.ABCMeta):
"""
Tests shared by test_open, test_path, and test_read.
"""
@abc.abstractmethod
def execute(self, package, path):
"""
Call the pertinent legacy API function (e.g. open_text, path)
on package and path.
"""
def test_package_name(self):
# Passing in the package name should succeed.
self.execute(data01.__name__, 'utf-8.file')
def test_package_object(self):
# Passing in the package itself should succeed.
self.execute(data01, 'utf-8.file')
def test_string_path(self):
# Passing in a string for the path should succeed.
path = 'utf-8.file'
self.execute(data01, path)
def test_pathlib_path(self):
# Passing in a pathlib.PurePath object for the path should succeed.
path = pathlib.PurePath('utf-8.file')
self.execute(data01, path)
def test_importing_module_as_side_effect(self):
# The anchor package can already be imported.
del sys.modules[data01.__name__]
self.execute(data01.__name__, 'utf-8.file')
def test_missing_path(self):
# Attempting to open, read, or request the path for a
# non-existent path should succeed if open_resource
# can return a viable data stream.
bytes_data = io.BytesIO(b'Hello, world!')
package = create_package(file=bytes_data, path=FileNotFoundError())
self.execute(package, 'utf-8.file')
self.assertEqual(package.__loader__._path, 'utf-8.file')
def test_extant_path(self):
# Attempting to open, read, or request the path when the
# path does exist should still succeed. Does not assert
# anything about the result.
bytes_data = io.BytesIO(b'Hello, world!')
# any path that exists
path = __file__
package = create_package(file=bytes_data, path=path)
self.execute(package, 'utf-8.file')
self.assertEqual(package.__loader__._path, 'utf-8.file')
def test_useless_loader(self):
package = create_package(file=FileNotFoundError(), path=FileNotFoundError())
with self.assertRaises(FileNotFoundError):
self.execute(package, 'utf-8.file')
class ZipSetupBase:
ZIP_MODULE = None
@classmethod
def setUpClass(cls):
data_path = pathlib.Path(cls.ZIP_MODULE.__file__)
data_dir = data_path.parent
cls._zip_path = str(data_dir / 'ziptestdata.zip')
sys.path.append(cls._zip_path)
cls.data = importlib.import_module('ziptestdata')
@classmethod
def tearDownClass(cls):
try:
sys.path.remove(cls._zip_path)
except ValueError:
pass
try:
del sys.path_importer_cache[cls._zip_path]
del sys.modules[cls.data.__name__]
except KeyError:
pass
try:
del cls.data
del cls._zip_path
except AttributeError:
pass
def setUp(self):
modules = import_helper.modules_setup()
self.addCleanup(import_helper.modules_cleanup, *modules)
class ZipSetup(ZipSetupBase):
ZIP_MODULE = zipdata01 # type: ignore

View File

@@ -0,0 +1 @@
jaraco

View File

@@ -0,0 +1 @@
jaraco

View File

@@ -0,0 +1 @@
jaraco

View File

@@ -0,0 +1,288 @@
import os
import subprocess
import contextlib
import functools
import tempfile
import shutil
import operator
import warnings
@contextlib.contextmanager
def pushd(dir):
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
... assert os.getcwd() == os.fspath(tmp_path)
>>> assert os.getcwd() != os.fspath(tmp_path)
"""
orig = os.getcwd()
os.chdir(dir)
try:
yield dir
finally:
os.chdir(orig)
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
"""
Get a tarball, extract it, change to that directory, yield, then
clean up.
`runner` is the function to invoke commands.
`pushd` is a context manager for changing the directory.
"""
if target_dir is None:
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path
# and then use -C to cause the files to be extracted to {target_dir}.
# This ensures that we always know where the files were extracted.
runner('mkdir {target_dir}'.format(**vars()))
try:
getter = 'wget {url} -O -'
extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
cmd = ' | '.join((getter, extract))
runner(cmd.format(compression=infer_compression(url), **vars()))
with pushd(target_dir):
yield target_dir
finally:
runner('rm -Rf {target_dir}'.format(**vars()))
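def _demo_tarball_context():
    # Editor's sketch (illustrative URL; needs network access plus wget,
    # tar, and rm on PATH -- not part of the vendored module):
    url = 'https://example.com/project-1.0.tar.gz'
    with tarball_context(url) as dir:
        # cwd is now the extracted tree; target_dir defaulted to 'project-1.0'
        print(os.listdir(dir))
    # on exit the extracted directory is removed with `rm -Rf`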
def infer_compression(url):
"""
Given a URL or filename, infer the compression code for tar.
>>> infer_compression('http://foo/bar.tar.gz')
'z'
>>> infer_compression('http://foo/bar.tgz')
'z'
>>> infer_compression('file.bz')
'j'
>>> infer_compression('file.xz')
'J'
"""
# cheat and just assume it's the last two characters
compression_indicator = url[-2:]
mapping = dict(gz='z', bz='j', xz='J')
# Assume 'z' (gzip) if no match
return mapping.get(compression_indicator, 'z')
@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
"""
Create a temporary directory context. Pass a custom remover
to override the removal behavior.
>>> import pathlib
>>> with temp_dir() as the_dir:
... assert os.path.isdir(the_dir)
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
>>> assert not os.path.exists(the_dir)
"""
temp_dir = tempfile.mkdtemp()
try:
yield temp_dir
finally:
remover(temp_dir)
@contextlib.contextmanager
def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
"""
Check out the repo indicated by url.
If dest_ctx is supplied, it should be a context manager
to yield the target directory for the check out.
"""
exe = 'git' if 'git' in url else 'hg'
with dest_ctx() as repo_dir:
cmd = [exe, 'clone', url, repo_dir]
if branch:
cmd.extend(['--branch', branch])
devnull = open(os.path.devnull, 'w')
stdout = devnull if quiet else None
subprocess.check_call(cmd, stdout=stdout)
yield repo_dir
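def _demo_repo_context():
    # Editor's sketch (illustrative URL; assumes git on PATH and network
    # access -- not part of the vendored module):
    with repo_context('https://github.com/pypa/setuptools') as checkout:
        print(os.listdir(checkout))
    # the checkout lives in a temp_dir() and is deleted on exit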
@contextlib.contextmanager
def null():
"""
A null context suitable to stand in for a meaningful context.
>>> with null() as value:
... assert value is None
"""
yield
class ExceptionTrap:
"""
A context manager that will catch certain exceptions and provide an
indication they occurred.
>>> with ExceptionTrap() as trap:
... raise Exception()
>>> bool(trap)
True
>>> with ExceptionTrap() as trap:
... pass
>>> bool(trap)
False
>>> with ExceptionTrap(ValueError) as trap:
... raise ValueError("1 + 1 is not 3")
>>> bool(trap)
True
>>> trap.value
ValueError('1 + 1 is not 3')
>>> trap.tb
<traceback object at ...>
>>> with ExceptionTrap(ValueError) as trap:
... raise Exception()
Traceback (most recent call last):
...
Exception
>>> bool(trap)
False
"""
exc_info = None, None, None
def __init__(self, exceptions=(Exception,)):
self.exceptions = exceptions
def __enter__(self):
return self
@property
def type(self):
return self.exc_info[0]
@property
def value(self):
return self.exc_info[1]
@property
def tb(self):
return self.exc_info[2]
def __exit__(self, *exc_info):
type = exc_info[0]
matches = type and issubclass(type, self.exceptions)
if matches:
self.exc_info = exc_info
return matches
def __bool__(self):
return bool(self.type)
def raises(self, func, *, _test=bool):
"""
Wrap func and replace the result with the truth
value of the trap (True if an exception occurred).
First, give the decorator an alias to support Python 3.8
syntax.
>>> raises = ExceptionTrap(ValueError).raises
Now decorate a function that always fails.
>>> @raises
... def fail():
... raise ValueError('failed')
>>> fail()
True
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with ExceptionTrap(self.exceptions) as trap:
func(*args, **kwargs)
return _test(trap)
return wrapper
def passes(self, func):
"""
Wrap func and replace the result with the truth
value of the trap (True if no exception).
First, give the decorator an alias to support Python 3.8
syntax.
>>> passes = ExceptionTrap(ValueError).passes
Now decorate a function that always fails.
>>> @passes
... def fail():
... raise ValueError('failed')
>>> fail()
False
"""
return self.raises(func, _test=operator.not_)
class suppress(contextlib.suppress, contextlib.ContextDecorator):
"""
A version of contextlib.suppress with decorator support.
>>> @suppress(KeyError)
... def key_error():
... {}['']
>>> key_error()
"""
class on_interrupt(contextlib.ContextDecorator):
"""
Replace a KeyboardInterrupt with SystemExit(1)
>>> def do_interrupt():
... raise KeyboardInterrupt()
>>> on_interrupt('error')(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 1
>>> on_interrupt('error', code=255)(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 255
>>> on_interrupt('suppress')(do_interrupt)()
>>> with __import__('pytest').raises(KeyboardInterrupt):
... on_interrupt('ignore')(do_interrupt)()
"""
def __init__(
self,
action='error',
# py3.7 compat
# /,
code=1,
):
self.action = action
self.code = code
def __enter__(self):
return self
def __exit__(self, exctype, excinst, exctb):
if exctype is not KeyboardInterrupt or self.action == 'ignore':
return
elif self.action == 'error':
raise SystemExit(self.code) from excinst
return self.action == 'suppress'

View File

@@ -0,0 +1,556 @@
import functools
import time
import inspect
import collections
import types
import itertools
import warnings
from pkg_resources.extern import more_itertools
from typing import Callable, TypeVar
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def compose(*funcs):
"""
Compose any number of unary functions into a single unary function.
>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
>>> strip_and_dedent(compose.__doc__) == expected
True
Compose also allows the innermost function to take arbitrary arguments.
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def method_caller(method_name, *args, **kwargs):
"""
Return a function that will call a named method on the
target object with optional positional and keyword
arguments.
>>> lower = method_caller('lower')
>>> lower('MyString')
'mystring'
"""
def call_method(target):
func = getattr(target, method_name)
return func(*args, **kwargs)
return call_method
def once(func):
"""
Decorate func so it's only ever called the first time.
This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent.
>>> add_three = once(lambda a: a+3)
>>> add_three(3)
6
>>> add_three(9)
6
>>> add_three('12')
6
To reset the stored value, simply clear the property ``saved_result``.
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
Or invoke 'reset()' on it.
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not hasattr(wrapper, 'saved_result'):
wrapper.saved_result = func(*args, **kwargs)
return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper
def method_cache(
method: CallableT,
cache_wrapper: Callable[
[CallableT], CallableT
] = functools.lru_cache(), # type: ignore[assignment]
) -> CallableT:
"""
Wrap lru_cache to support storing the cache data in the object instances.
Abstracts the common paradigm where the method explicitly saves an
underscore-prefixed protected property on first call and returns that
subsequently.
>>> class MyClass:
... calls = 0
...
... @method_cache
... def method(self, value):
... self.calls += 1
... return value
>>> a = MyClass()
>>> a.method(3)
3
>>> for x in range(75):
... res = a.method(x)
>>> a.calls
75
Note that the apparent behavior will be exactly like that of lru_cache
except that the cache is stored on each instance, so values in one
instance will not flush values from another, and when an instance is
deleted, so are the cached values for that instance.
>>> b = MyClass()
>>> for x in range(35):
... res = b.method(x)
>>> b.calls
35
>>> a.method(0)
0
>>> a.calls
75
Note that if method had been decorated with ``functools.lru_cache()``,
a.calls would have been 76 (due to the cached value of 0 having been
flushed by the 'b' instance).
Clear the cache with ``.cache_clear()``
>>> a.method.cache_clear()
Same for a method that hasn't yet been called.
>>> c = MyClass()
>>> c.method.cache_clear()
Another cache wrapper may be supplied:
>>> cache = functools.lru_cache(maxsize=2)
>>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache)
>>> a = MyClass()
>>> a.method2()
3
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification.
"""
def wrapper(self: object, *args: object, **kwargs: object) -> object:
# it's the first call, replace the method with a cached, bound method
bound_method: CallableT = types.MethodType( # type: ignore[assignment]
method, self
)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
# Support cache clear even before cache has been created.
wrapper.cache_clear = lambda: None # type: ignore[attr-defined]
return ( # type: ignore[return-value]
_special_method_cache(method, cache_wrapper) or wrapper
)
def _special_method_cache(method, cache_wrapper):
"""
Because Python treats special methods differently, it's not
possible to use instance attributes to implement the cached
methods.
Instead, install the wrapper method under a different name
and return a simple proxy to that wrapper.
https://github.com/jaraco/jaraco.functools/issues/5
"""
name = method.__name__
special_names = '__getattr__', '__getitem__'
if name not in special_names:
return
wrapper_name = '__cached' + name
def proxy(self, *args, **kwargs):
if wrapper_name not in vars(self):
bound = types.MethodType(method, self)
cache = cache_wrapper(bound)
setattr(self, wrapper_name, cache)
else:
cache = getattr(self, wrapper_name)
return cache(*args, **kwargs)
return proxy
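def _demo_special_method_cache():
    # Editor's sketch (illustrative class, not part of the vendored
    # module): method_cache routes special methods through the proxy
    # above, so __getitem__ still gets a per-instance cache.
    class Squares:
        @method_cache
        def __getitem__(self, n):
            return n * n

    s = Squares()
    assert s[4] == 16  # first access computes and installs the cache
    assert s[4] == 16  # second access is served by the cached wrapper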
def apply(transform):
"""
Decorate a function with a transform function that is
invoked on results returned from the decorated function.
>>> @apply(reversed)
... def get_numbers(start):
... "doc for get_numbers"
... return range(start, start+3)
>>> list(get_numbers(4))
[6, 5, 4]
>>> get_numbers.__doc__
'doc for get_numbers'
"""
def wrap(func):
return functools.wraps(func)(compose(transform, func))
return wrap
def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side-effect), then return the original
result.
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
>>> x
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
def invoke(f, *args, **kwargs):
"""
Call a function for its side effect after initialization.
The benefit of using the decorator instead of simply invoking a function
after defining it is that it makes explicit the author's intent for the
function to be called immediately. Whereas if one simply calls the
function immediately, it's less obvious if that was intentional or
incidental. It also avoids repeating the name - the two actions, defining
the function and calling it immediately are modeled separately, but linked
by the decorator construct.
The benefit of having a function construct (opposed to just invoking some
behavior inline) is to serve as a scope in which the behavior occurs. It
avoids polluting the global namespace with local variables, provides an
anchor on which to attach documentation (docstring), keeps the behavior
logically separated (instead of conceptually separated or not separated at
all), and provides potential to re-use the behavior for testing or other
purposes.
This function is named as a pithy way to communicate, "call this function
primarily for its side effect", or "while defining this function, also
take it aside and call it". It exists because there's no Python construct
for "define and call" (nor should there be, as decorators serve this need
just fine). The behavior happens immediately and synchronously.
>>> @invoke
... def func(): print("called")
called
>>> func()
called
Use functools.partial to pass parameters to the initial call
>>> @functools.partial(invoke, name='bingo')
... def func(name): print("called with", name)
called with bingo
"""
f(*args, **kwargs)
return f
def call_aside(*args, **kwargs):
"""
Deprecated name for invoke.
"""
warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning)
return invoke(*args, **kwargs)
class Throttler:
"""
Rate-limit a function (or other callable)
"""
def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler):
func = func.func
self.func = func
self.max_rate = max_rate
self.reset()
def reset(self):
self.last_called = 0
def __call__(self, *args, **kwargs):
self._wait()
return self.func(*args, **kwargs)
def _wait(self):
"ensure at least 1/max_rate seconds from last call"
elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait))
self.last_called = time.time()
def __get__(self, obj, type=None):
return first_invoke(self._wait, functools.partial(self.func, obj))
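def _demo_throttler():
    # Editor's sketch (illustrative callable, not part of the vendored
    # module): allow at most two calls per second.
    def save():
        print('saved at', time.time())

    throttled_save = Throttler(save, max_rate=2)
    for _ in range(4):
        throttled_save()  # sleeps as needed so calls are >= 0.5s apart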
def first_invoke(func1, func2):
"""
Return a function that when invoked will invoke func1 without
any parameters (for its side-effect) and then invoke func2
with whatever parameters were passed, returning its result.
"""
def wrapper(*args, **kwargs):
func1()
return func2(*args, **kwargs)
return wrapper
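def _demo_first_invoke():
    # Editor's sketch (names illustrative): run one-time setup before the
    # real call, without altering the call's signature or result.
    upper = first_invoke(lambda: print('warming up...'), str.upper)
    assert upper('hi') == 'HI'  # prints 'warming up...' first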
def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
"""
Given a callable func, trap the indicated exceptions
for up to 'retries' times, invoking cleanup on the
exception. On the final attempt, allow any exceptions
to propagate.
"""
attempts = itertools.count() if retries == float('inf') else range(retries)
for attempt in attempts:
try:
return func()
except trap:
cleanup()
return func()
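def _demo_retry_call():
    # Editor's sketch (illustrative flaky function): trap only the
    # expected exception and let the final attempt propagate.
    attempts = []

    def flaky():
        attempts.append(1)
        if len(attempts) < 3:
            raise OSError('transient')
        return 'ok'

    assert retry_call(flaky, retries=3, trap=OSError) == 'ok'
    assert len(attempts) == 3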
def retry(*r_args, **r_kwargs):
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
Ex:
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.consume, print_all, func)
return functools.wraps(func)(print_results)
def pass_none(func):
"""
Wrap func so it's not called if its first param is None
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return wrapper
def assign_params(func, namespace):
"""
Assign parameters from namespace where func solicits.
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
The usual errors are raised if a function doesn't receive
its required parameters:
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
It even works on methods:
>>> class Handler:
... def meth(self, arg):
... print(arg)
>>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))()
crystal
"""
sig = inspect.signature(func)
params = sig.parameters.keys()
call_ns = {k: namespace[k] for k in params if k in namespace}
return functools.partial(func, **call_ns)
def save_method_args(method):
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
The arguments are stored on the instance, allowing for
different instances to save different args.
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper
def except_(*exceptions, replace=None, use=None):
"""
Replace the indicated exceptions, if raised, with the indicated
literal replacement or evaluated expression (if present).
>>> safe_int = except_(ValueError)(int)
>>> safe_int('five')
>>> safe_int('5')
5
Specify a literal replacement with ``replace``.
>>> safe_int_r = except_(ValueError, replace=0)(int)
>>> safe_int_r('five')
0
Provide an expression to ``use`` to pass through particular parameters.
>>> safe_int_pt = except_(ValueError, use='args[0]')(int)
>>> safe_int_pt('five')
'five'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions:
try:
return eval(use)
except TypeError:
return replace
return wrapper
return decorate

View File

@@ -0,0 +1,2 @@
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst.

View File

@@ -0,0 +1,599 @@
import re
import itertools
import textwrap
import functools
try:
from importlib.resources import files # type: ignore
except ImportError: # pragma: nocover
from pkg_resources.extern.importlib_resources import files # type: ignore
from pkg_resources.extern.jaraco.functools import compose, method_cache
from pkg_resources.extern.jaraco.context import ExceptionTrap
def substitution(old, new):
"""
Return a function that will perform a substitution on a string
"""
return lambda s: s.replace(old, new)
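def _demo_substitution():
    # Editor's sketch (arbitrary values): substitution returns a one-shot
    # replacer, the building block multi_substitution composes just below.
    swap = substitution('major', 'minor')
    assert swap('a major change') == 'a minor change'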
def multi_substitution(*substitutions):
"""
Take a sequence of pairs specifying substitutions, and create
a function that performs those substitutions.
>>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo')
'baz'
"""
substitutions = itertools.starmap(substitution, substitutions)
# compose function applies last function first, so reverse the
# substitutions to get the expected order.
substitutions = reversed(tuple(substitutions))
return compose(*substitutions)
class FoldedCase(str):
"""
A case insensitive string class; behaves just like str
except compares equal when the only variation is case.
>>> s = FoldedCase('hello world')
>>> s == 'Hello World'
True
>>> 'Hello World' == s
True
>>> s != 'Hello World'
False
>>> s.index('O')
4
>>> s.split('O')
['hell', ' w', 'rld']
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
You may test for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use ``in_``:
>>> FoldedCase('hello').in_('Hello World')
True
>>> FoldedCase('hello') > FoldedCase('Hello')
False
"""
def __lt__(self, other):
return self.lower() < other.lower()
def __gt__(self, other):
return self.lower() > other.lower()
def __eq__(self, other):
return self.lower() == other.lower()
def __ne__(self, other):
return self.lower() != other.lower()
def __hash__(self):
return hash(self.lower())
def __contains__(self, other):
return super().lower().__contains__(other.lower())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache lower since it's likely to be called frequently.
@method_cache
def lower(self):
return super().lower()
def index(self, sub):
return self.lower().index(sub.lower())
def split(self, splitter=' ', maxsplit=0):
pattern = re.compile(re.escape(splitter), re.I)
return pattern.split(self, maxsplit)
# Python 3.8 compatibility
_unicode_trap = ExceptionTrap(UnicodeDecodeError)
@_unicode_trap.passes
def is_decodable(value):
r"""
Return True if the supplied value is decodable (using the default
encoding).
>>> is_decodable(b'\xff')
False
>>> is_decodable(b'\x32')
True
"""
value.decode()
def is_binary(value):
r"""
Return True if the value appears to be binary (that is, it's a byte
string and isn't decodable).
>>> is_binary(b'\xff')
True
>>> is_binary('\xff')
False
"""
return isinstance(value, bytes) and not is_decodable(value)
def trim(s):
r"""
Trim something like a docstring to remove the whitespace that
is common due to indentation and formatting.
>>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
'foo = bar\n\tbar = baz'
"""
return textwrap.dedent(s).strip()
def wrap(s):
"""
Wrap lines of text, retaining existing newlines as
paragraph markers.
>>> print(wrap(lorem_ipsum))
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do
eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad
minim veniam, quis nostrud exercitation ullamco laboris nisi ut
aliquip ex ea commodo consequat. Duis aute irure dolor in
reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla
pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
culpa qui officia deserunt mollit anim id est laborum.
<BLANKLINE>
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam
varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus
magna felis sollicitudin mauris. Integer in mauris eu nibh euismod
gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis
risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue,
eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas
fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla
a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis,
neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing
sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque
nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus
quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis,
molestie eu, feugiat in, orci. In hac habitasse platea dictumst.
"""
paragraphs = s.splitlines()
wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs)
return '\n\n'.join(wrapped)
def unwrap(s):
r"""
Given a multi-line string, return an unwrapped version.
>>> wrapped = wrap(lorem_ipsum)
>>> wrapped.count('\n')
20
>>> unwrapped = unwrap(wrapped)
>>> unwrapped.count('\n')
1
>>> print(unwrapped)
Lorem ipsum dolor sit amet, consectetur adipiscing ...
Curabitur pretium tincidunt lacus. Nulla gravida orci ...
"""
paragraphs = re.split(r'\n\n+', s)
cleaned = (para.replace('\n', ' ') for para in paragraphs)
return '\n'.join(cleaned)
class Splitter(object):
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
>>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling']
"""
def __init__(self, *args):
self.args = args
def __call__(self, s):
return s.split(*self.args)
def indent(string, prefix=' ' * 4):
"""
>>> indent('foo')
' foo'
"""
return prefix + string
class WordSet(tuple):
"""
Given an identifier, return the words that identifier represents,
whether in camel case, underscore-separated, etc.
>>> WordSet.parse("camelCase")
('camel', 'Case')
>>> WordSet.parse("under_sep")
('under', 'sep')
Acronyms should be retained
>>> WordSet.parse("firstSNL")
('first', 'SNL')
>>> WordSet.parse("you_and_I")
('you', 'and', 'I')
>>> WordSet.parse("A simple test")
('A', 'simple', 'test')
Multiple caps should not interfere with the first cap of another word.
>>> WordSet.parse("myABCClass")
('my', 'ABC', 'Class')
The result is a WordSet, so you can get the form you need.
>>> WordSet.parse("myABCClass").underscore_separated()
'my_ABC_Class'
>>> WordSet.parse('a-command').camel_case()
'ACommand'
>>> WordSet.parse('someIdentifier').lowered().space_separated()
'some identifier'
Slices of the result should return another WordSet.
>>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
'out_of_context'
>>> WordSet.from_class_name(WordSet()).lowered().space_separated()
'word set'
>>> example = WordSet.parse('figured it out')
>>> example.headless_camel_case()
'figuredItOut'
>>> example.dash_separated()
'figured-it-out'
"""
_pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')
def capitalized(self):
return WordSet(word.capitalize() for word in self)
def lowered(self):
return WordSet(word.lower() for word in self)
def camel_case(self):
return ''.join(self.capitalized())
def headless_camel_case(self):
words = iter(self)
first = next(words).lower()
new_words = itertools.chain((first,), WordSet(words).camel_case())
return ''.join(new_words)
def underscore_separated(self):
return '_'.join(self)
def dash_separated(self):
return '-'.join(self)
def space_separated(self):
return ' '.join(self)
def trim_right(self, item):
"""
Remove the item from the end of the set.
>>> WordSet.parse('foo bar').trim_right('foo')
('foo', 'bar')
>>> WordSet.parse('foo bar').trim_right('bar')
('foo',)
>>> WordSet.parse('').trim_right('bar')
()
"""
return self[:-1] if self and self[-1] == item else self
def trim_left(self, item):
"""
Remove the item from the beginning of the set.
>>> WordSet.parse('foo bar').trim_left('foo')
('bar',)
>>> WordSet.parse('foo bar').trim_left('bar')
('foo', 'bar')
>>> WordSet.parse('').trim_left('bar')
()
"""
return self[1:] if self and self[0] == item else self
def trim(self, item):
"""
>>> WordSet.parse('foo bar').trim('foo')
('bar',)
"""
return self.trim_left(item).trim_right(item)
def __getitem__(self, item):
result = super(WordSet, self).__getitem__(item)
if isinstance(item, slice):
result = WordSet(result)
return result
@classmethod
def parse(cls, identifier):
matches = cls._pattern.finditer(identifier)
return WordSet(match.group(0) for match in matches)
@classmethod
def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__)
# for backward compatibility
words = WordSet.parse
def simple_html_strip(s):
r"""
Remove HTML from the string `s`.
>>> str(simple_html_strip(''))
''
>>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
A stormy day in paradise
>>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
Somebody tell the truth.
>>> print(simple_html_strip('What about<br/>\nmultiple lines?'))
What about
multiple lines?
"""
html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
texts = (match.group(3) or '' for match in html_stripper.finditer(s))
return ''.join(texts)
class SeparatedValues(str):
"""
A string separated by a separator. Overrides __iter__ for getting
the values.
>>> list(SeparatedValues('a,b,c'))
['a', 'b', 'c']
Whitespace is stripped and empty values are discarded.
>>> list(SeparatedValues(' a, b , c, '))
['a', 'b', 'c']
"""
separator = ','
def __iter__(self):
parts = self.split(self.separator)
return filter(None, (part.strip() for part in parts))
class Stripper:
r"""
Given a series of lines, find the common prefix and strip it from them.
>>> lines = [
... 'abcdefg\n',
... 'abc\n',
... 'abcde\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
'abc'
>>> list(res.lines)
['defg\n', '\n', 'de\n']
If no prefix is common, nothing should be stripped.
>>> lines = [
... 'abcd\n',
... '1234\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
''
>>> list(res.lines)
['abcd\n', '1234\n']
"""
def __init__(self, prefix, lines):
self.prefix = prefix
self.lines = map(self, lines)
@classmethod
def strip_prefix(cls, lines):
prefix_lines, lines = itertools.tee(lines)
prefix = functools.reduce(cls.common_prefix, prefix_lines)
return cls(prefix, lines)
def __call__(self, line):
if not self.prefix:
return line
null, prefix, rest = line.partition(self.prefix)
return rest
@staticmethod
def common_prefix(s1, s2):
"""
Return the common prefix of two lines.
"""
index = min(len(s1), len(s2))
while s1[:index] != s2[:index]:
index -= 1
return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest
def normalize_newlines(text):
r"""
Replace alternate newlines with the canonical newline.
>>> normalize_newlines('Lorem Ipsum\u2029')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\r\n')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\x85')
'Lorem Ipsum\n'
"""
newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029']
pattern = '|'.join(newlines)
return re.sub(pattern, '\n', text)
def _nonblank(str):
return str and not str.startswith('#')
@functools.singledispatch
def yield_lines(iterable):
r"""
Yield valid lines of a string or iterable.
>>> list(yield_lines(''))
[]
>>> list(yield_lines(['foo', 'bar']))
['foo', 'bar']
>>> list(yield_lines('foo\nbar'))
['foo', 'bar']
>>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
['foo', 'baz #comment']
>>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n']))
['foo', 'bar', 'baz', 'bing']
"""
return itertools.chain.from_iterable(map(yield_lines, iterable))
@yield_lines.register(str)
def _(text):
return filter(_nonblank, map(str.strip, text.splitlines()))
def drop_comment(line):
"""
Drop comments.
>>> drop_comment('foo # bar')
'foo'
A hash without a space may be in a URL.
>>> drop_comment('http://example.com/foo#bar')
'http://example.com/foo#bar'
"""
return line.partition(' #')[0]
def join_continuation(lines):
r"""
Join lines continued by a trailing backslash.
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
['foobar', 'baz']
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
['foobar', 'baz']
>>> list(join_continuation(['foo \\', 'bar \\', 'baz']))
['foobarbaz']
Not sure why, but...
The character preceding the backslash is also elided.
>>> list(join_continuation(['goo\\', 'dly']))
['godly']
A terrible idea, but...
If no line is available to continue, suppress the lines.
>>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
['foo']
"""
lines = iter(lines)
for item in lines:
while item.endswith('\\'):
try:
item = item[:-2].strip() + next(lines)
except StopIteration:
return
yield item

View File

@@ -0,0 +1,6 @@
"""More routines for operating on iterables, beyond itertools"""
from .more import * # noqa
from .recipes import * # noqa
__version__ = '9.1.0'

View File

@@ -0,0 +1,2 @@
from .more import *
from .recipes import *

File diff suppressed because it is too large

View File

@@ -0,0 +1,666 @@
"""Stubs for more_itertools.more"""
from __future__ import annotations
from types import TracebackType
from typing import (
Any,
Callable,
Container,
ContextManager,
Generic,
Hashable,
Iterable,
Iterator,
overload,
Reversible,
Sequence,
Sized,
Type,
TypeVar,
type_check_only,
)
from typing_extensions import Protocol
# Type and type variable definitions
_T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_U = TypeVar('_U')
_V = TypeVar('_V')
_W = TypeVar('_W')
_T_co = TypeVar('_T_co', covariant=True)
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
_Raisable = BaseException | Type[BaseException]
@type_check_only
class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
@type_check_only
class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
@type_check_only
class _SupportsSlicing(Protocol[_T_co]):
def __getitem__(self, __k: slice) -> _T_co: ...
def chunked(
iterable: Iterable[_T], n: int | None, strict: bool = ...
) -> Iterator[list[_T]]: ...
@overload
def first(iterable: Iterable[_T]) -> _T: ...
@overload
def first(iterable: Iterable[_T], default: _U) -> _T | _U: ...
@overload
def last(iterable: Iterable[_T]) -> _T: ...
@overload
def last(iterable: Iterable[_T], default: _U) -> _T | _U: ...
@overload
def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
@overload
def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
class peekable(Generic[_T], Iterator[_T]):
def __init__(self, iterable: Iterable[_T]) -> None: ...
def __iter__(self) -> peekable[_T]: ...
def __bool__(self) -> bool: ...
@overload
def peek(self) -> _T: ...
@overload
def peek(self, default: _U) -> _T | _U: ...
def prepend(self, *items: _T) -> None: ...
def __next__(self) -> _T: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> list[_T]: ...
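# Editor's note (illustrative, not part of the stub): the paired overloads
# above let a type checker narrow peek() by whether a default is supplied:
#
#     p = peekable([1, 2, 3])            # peekable[int]
#     reveal_type(p.peek())              # int
#     reveal_type(p.peek(default=None))  # int | None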
def consumer(func: _GenFn) -> _GenFn: ...
def ilen(iterable: Iterable[object]) -> int: ...
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
def with_iter(
context_manager: ContextManager[Iterable[_T]],
) -> Iterator[_T]: ...
def one(
iterable: Iterable[_T],
too_short: _Raisable | None = ...,
too_long: _Raisable | None = ...,
) -> _T: ...
def raise_(exception: _Raisable, *args: Any) -> None: ...
def strictly_n(
iterable: Iterable[_T],
n: int,
too_short: _GenFn | None = ...,
too_long: _GenFn | None = ...,
) -> list[_T]: ...
def distinct_permutations(
iterable: Iterable[_T], r: int | None = ...
) -> Iterator[tuple[_T, ...]]: ...
def intersperse(
e: _U, iterable: Iterable[_T], n: int = ...
) -> Iterator[_T | _U]: ...
def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ...
@overload
def windowed(
seq: Iterable[_T], n: int, *, step: int = ...
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def windowed(
seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
) -> Iterator[tuple[_T | _U, ...]]: ...
def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
def substrings_indexes(
seq: Sequence[_T], reverse: bool = ...
) -> Iterator[tuple[Sequence[_T], int, int]]: ...
class bucket(Generic[_T, _U], Container[_U]):
def __init__(
self,
iterable: Iterable[_T],
key: Callable[[_T], _U],
validator: Callable[[object], object] | None = ...,
) -> None: ...
def __contains__(self, value: object) -> bool: ...
def __iter__(self) -> Iterator[_U]: ...
def __getitem__(self, value: object) -> Iterator[_T]: ...
def spy(
iterable: Iterable[_T], n: int = ...
) -> tuple[list[_T], Iterator[_T]]: ...
def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
def interleave_evenly(
iterables: list[Iterable[_T]], lengths: list[int] | None = ...
) -> Iterator[_T]: ...
def collapse(
iterable: Iterable[Any],
base_type: type | None = ...,
levels: int | None = ...,
) -> Iterator[Any]: ...
@overload
def side_effect(
func: Callable[[_T], object],
iterable: Iterable[_T],
chunk_size: None = ...,
before: Callable[[], object] | None = ...,
after: Callable[[], object] | None = ...,
) -> Iterator[_T]: ...
@overload
def side_effect(
func: Callable[[list[_T]], object],
iterable: Iterable[_T],
chunk_size: int,
before: Callable[[], object] | None = ...,
after: Callable[[], object] | None = ...,
) -> Iterator[_T]: ...
def sliced(
seq: _SupportsSlicing[_T], n: int, strict: bool = ...
) -> Iterator[_T]: ...
def split_at(
iterable: Iterable[_T],
pred: Callable[[_T], object],
maxsplit: int = ...,
keep_separator: bool = ...,
) -> Iterator[list[_T]]: ...
def split_before(
iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
) -> Iterator[list[_T]]: ...
def split_after(
iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
) -> Iterator[list[_T]]: ...
def split_when(
iterable: Iterable[_T],
pred: Callable[[_T, _T], object],
maxsplit: int = ...,
) -> Iterator[list[_T]]: ...
def split_into(
iterable: Iterable[_T], sizes: Iterable[int | None]
) -> Iterator[list[_T]]: ...
@overload
def padded(
iterable: Iterable[_T],
*,
n: int | None = ...,
next_multiple: bool = ...,
) -> Iterator[_T | None]: ...
@overload
def padded(
iterable: Iterable[_T],
fillvalue: _U,
n: int | None = ...,
next_multiple: bool = ...,
) -> Iterator[_T | _U]: ...
@overload
def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
@overload
def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ...
def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
@overload
def stagger(
iterable: Iterable[_T],
offsets: _SizedIterable[int] = ...,
longest: bool = ...,
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def stagger(
iterable: Iterable[_T],
offsets: _SizedIterable[int] = ...,
longest: bool = ...,
fillvalue: _U = ...,
) -> Iterator[tuple[_T | _U, ...]]: ...
class UnequalIterablesError(ValueError):
def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...
@overload
def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
@overload
def zip_equal(
__iter1: Iterable[_T1], __iter2: Iterable[_T2]
) -> Iterator[tuple[_T1, _T2]]: ...
@overload
def zip_equal(
__iter1: Iterable[_T],
__iter2: Iterable[_T],
__iter3: Iterable[_T],
*iterables: Iterable[_T],
) -> Iterator[tuple[_T, ...]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T1],
*,
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: None = None,
) -> Iterator[tuple[_T1 | None]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T1],
__iter2: Iterable[_T2],
*,
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: None = None,
) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T],
__iter2: Iterable[_T],
__iter3: Iterable[_T],
*iterables: Iterable[_T],
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: None = None,
) -> Iterator[tuple[_T | None, ...]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T1],
*,
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: _U,
) -> Iterator[tuple[_T1 | _U]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T1],
__iter2: Iterable[_T2],
*,
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: _U,
) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
@overload
def zip_offset(
__iter1: Iterable[_T],
__iter2: Iterable[_T],
__iter3: Iterable[_T],
*iterables: Iterable[_T],
offsets: _SizedIterable[int],
longest: bool = ...,
fillvalue: _U,
) -> Iterator[tuple[_T | _U, ...]]: ...
def sort_together(
iterables: Iterable[Iterable[_T]],
key_list: Iterable[int] = ...,
key: Callable[..., Any] | None = ...,
reverse: bool = ...,
) -> list[tuple[_T, ...]]: ...
def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
def always_iterable(
obj: object,
base_type: type | tuple[type | tuple[Any, ...], ...] | None = ...,
) -> Iterator[Any]: ...
def adjacent(
predicate: Callable[[_T], bool],
iterable: Iterable[_T],
distance: int = ...,
) -> Iterator[tuple[bool, _T]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: None = None,
valuefunc: None = None,
reducefunc: None = None,
) -> Iterator[tuple[_T, Iterator[_T]]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: None,
reducefunc: None,
) -> Iterator[tuple[_U, Iterator[_T]]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: None,
valuefunc: Callable[[_T], _V],
reducefunc: None,
) -> Iterable[tuple[_T, Iterable[_V]]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: Callable[[_T], _V],
reducefunc: None,
) -> Iterable[tuple[_U, Iterator[_V]]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: None,
valuefunc: None,
reducefunc: Callable[[Iterator[_T]], _W],
) -> Iterable[tuple[_T, _W]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: None,
reducefunc: Callable[[Iterator[_T]], _W],
) -> Iterable[tuple[_U, _W]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: None,
valuefunc: Callable[[_T], _V],
reducefunc: Callable[[Iterable[_V]], _W],
) -> Iterable[tuple[_T, _W]]: ...
@overload
def groupby_transform(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: Callable[[_T], _V],
reducefunc: Callable[[Iterable[_V]], _W],
) -> Iterable[tuple[_U, _W]]: ...
class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
@overload
def __init__(self, __stop: _T) -> None: ...
@overload
def __init__(self, __start: _T, __stop: _T) -> None: ...
@overload
def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
def __bool__(self) -> bool: ...
def __contains__(self, elem: object) -> bool: ...
def __eq__(self, other: object) -> bool: ...
@overload
def __getitem__(self, key: int) -> _T: ...
@overload
def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
def __hash__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __len__(self) -> int: ...
def __reduce__(
self,
) -> tuple[Type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ...
def __repr__(self) -> str: ...
def __reversed__(self) -> Iterator[_T]: ...
def count(self, value: _T) -> int: ...
def index(self, value: _T) -> int: ... # type: ignore
def count_cycle(
iterable: Iterable[_T], n: int | None = ...
) -> Iterable[tuple[int, _T]]: ...
def mark_ends(
iterable: Iterable[_T],
) -> Iterable[tuple[bool, bool, _T]]: ...
def locate(
iterable: Iterable[object],
pred: Callable[..., Any] = ...,
window_size: int | None = ...,
) -> Iterator[int]: ...
def lstrip(
iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...
def rstrip(
iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...
def strip(
iterable: Iterable[_T], pred: Callable[[_T], object]
) -> Iterator[_T]: ...
class islice_extended(Generic[_T], Iterator[_T]):
def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ...
def __iter__(self) -> islice_extended[_T]: ...
def __next__(self) -> _T: ...
def __getitem__(self, index: slice) -> islice_extended[_T]: ...
def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
def consecutive_groups(
iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
) -> Iterator[Iterator[_T]]: ...
@overload
def difference(
iterable: Iterable[_T],
func: Callable[[_T, _T], _U] = ...,
*,
initial: None = ...,
) -> Iterator[_T | _U]: ...
@overload
def difference(
iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
) -> Iterator[_U]: ...
class SequenceView(Generic[_T], Sequence[_T]):
def __init__(self, target: Sequence[_T]) -> None: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> Sequence[_T]: ...
def __len__(self) -> int: ...
class seekable(Generic[_T], Iterator[_T]):
def __init__(
self, iterable: Iterable[_T], maxlen: int | None = ...
) -> None: ...
def __iter__(self) -> seekable[_T]: ...
def __next__(self) -> _T: ...
def __bool__(self) -> bool: ...
@overload
def peek(self) -> _T: ...
@overload
def peek(self, default: _U) -> _T | _U: ...
def elements(self) -> SequenceView[_T]: ...
def seek(self, index: int) -> None: ...
class run_length:
@staticmethod
def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ...
@staticmethod
def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ...
def exactly_n(
iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
) -> bool: ...
def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ...
def make_decorator(
wrapping_func: Callable[..., _U], result_index: int = ...
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
@overload
def map_reduce(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: None = ...,
reducefunc: None = ...,
) -> dict[_U, list[_T]]: ...
@overload
def map_reduce(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: Callable[[_T], _V],
reducefunc: None = ...,
) -> dict[_U, list[_V]]: ...
@overload
def map_reduce(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: None = ...,
reducefunc: Callable[[list[_T]], _W] = ...,
) -> dict[_U, _W]: ...
@overload
def map_reduce(
iterable: Iterable[_T],
keyfunc: Callable[[_T], _U],
valuefunc: Callable[[_T], _V],
reducefunc: Callable[[list[_V]], _W],
) -> dict[_U, _W]: ...
def rlocate(
iterable: Iterable[_T],
pred: Callable[..., object] = ...,
window_size: int | None = ...,
) -> Iterator[int]: ...
def replace(
iterable: Iterable[_T],
pred: Callable[..., object],
substitutes: Iterable[_U],
count: int | None = ...,
window_size: int = ...,
) -> Iterator[_T | _U]: ...
def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
def set_partitions(
iterable: Iterable[_T], k: int | None = ...
) -> Iterator[list[list[_T]]]: ...
class time_limited(Generic[_T], Iterator[_T]):
def __init__(
self, limit_seconds: float, iterable: Iterable[_T]
) -> None: ...
def __iter__(self) -> islice_extended[_T]: ...
def __next__(self) -> _T: ...
@overload
def only(
iterable: Iterable[_T], *, too_long: _Raisable | None = ...
) -> _T | None: ...
@overload
def only(
iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ...
) -> _T | _U: ...
def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
def distinct_combinations(
iterable: Iterable[_T], r: int
) -> Iterator[tuple[_T, ...]]: ...
def filter_except(
validator: Callable[[Any], object],
iterable: Iterable[_T],
*exceptions: Type[BaseException],
) -> Iterator[_T]: ...
def map_except(
function: Callable[[Any], _U],
iterable: Iterable[_T],
*exceptions: Type[BaseException],
) -> Iterator[_U]: ...
def map_if(
iterable: Iterable[Any],
pred: Callable[[Any], bool],
func: Callable[[Any], Any],
func_else: Callable[[Any], Any] | None = ...,
) -> Iterator[Any]: ...
def sample(
iterable: Iterable[_T],
k: int,
weights: Iterable[float] | None = ...,
) -> list[_T]: ...
def is_sorted(
iterable: Iterable[_T],
key: Callable[[_T], _U] | None = ...,
reverse: bool = False,
strict: bool = False,
) -> bool: ...
class AbortThread(BaseException):
pass
class callback_iter(Generic[_T], Iterator[_T]):
def __init__(
self,
func: Callable[..., Any],
callback_kwd: str = ...,
wait_seconds: float = ...,
) -> None: ...
def __enter__(self) -> callback_iter[_T]: ...
def __exit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None: ...
def __iter__(self) -> callback_iter[_T]: ...
def __next__(self) -> _T: ...
def _reader(self) -> Iterator[_T]: ...
@property
def done(self) -> bool: ...
@property
def result(self) -> Any: ...
def windowed_complete(
iterable: Iterable[_T], n: int
) -> Iterator[tuple[_T, ...]]: ...
def all_unique(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> bool: ...
def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
def nth_permutation(
iterable: Iterable[_T], r: int, index: int
) -> tuple[_T, ...]: ...
def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ...
def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
def combination_index(
element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
def permutation_index(
element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...
class countable(Generic[_T], Iterator[_T]):
def __init__(self, iterable: Iterable[_T]) -> None: ...
def __iter__(self) -> countable[_T]: ...
def __next__(self) -> _T: ...
def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
def zip_broadcast(
*objects: _T | Iterable[_T],
scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
def unique_in_window(
iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_everseen(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def duplicates_justseen(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
class _SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
@overload
def minmax(
iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
) -> tuple[_T, _T]: ...
@overload
def minmax(
iterable_or_value: Iterable[_SupportsLessThanT],
*,
key: None = None,
default: _U,
) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
iterable_or_value: Iterable[_T],
*,
key: Callable[[_T], _SupportsLessThan],
default: _U,
) -> _U | tuple[_T, _T]: ...
@overload
def minmax(
iterable_or_value: _SupportsLessThanT,
__other: _SupportsLessThanT,
*others: _SupportsLessThanT,
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload
def minmax(
iterable_or_value: _T,
__other: _T,
*others: _T,
key: Callable[[_T], _SupportsLessThan],
) -> tuple[_T, _T]: ...
def longest_common_prefix(
iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[object]) -> bool: ...
def constrained_batches(
iterable: Iterable[object],
max_size: int,
max_count: int | None = ...,
get_len: Callable[[_T], object] = ...,
strict: bool = ...,
) -> Iterator[tuple[_T]]: ...
def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
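# Illustrative only, not part of the stub file: a few calls that exercise
# the overloads typed above, assuming the vendored more_itertools package
# is importable. The asserted values follow the library's documented
# behavior.
from more_itertools import minmax, padded, zip_offset

assert minmax([3, 1, 2]) == (1, 3)  # checker infers tuple[int, int]
assert list(padded([1, 2], fillvalue=0, n=4)) == [1, 2, 0, 0]  # Iterator[int]
assert list(zip_offset('abc', 'bcd', offsets=(0, 1))) == [('a', 'c'), ('b', 'd')]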

View File

@@ -0,0 +1,930 @@
"""Imported from the recipes section of the itertools documentation.
All functions taken from the recipes section of the itertools library docs
[1]_.
Some backward-compatible usability improvements have been made.
.. [1] http://docs.python.org/library/itertools.html#recipes
"""
import math
import operator
import warnings
from collections import deque
from collections.abc import Sized
from functools import reduce
from itertools import (
chain,
combinations,
compress,
count,
cycle,
groupby,
islice,
product,
repeat,
starmap,
tee,
zip_longest,
)
from random import randrange, sample, choice
from sys import hexversion
__all__ = [
'all_equal',
'batched',
'before_and_after',
'consume',
'convolve',
'dotproduct',
'first_true',
'factor',
'flatten',
'grouper',
'iter_except',
'iter_index',
'matmul',
'ncycles',
'nth',
'nth_combination',
'padnone',
'pad_none',
'pairwise',
'partition',
'polynomial_from_roots',
'powerset',
'prepend',
'quantify',
'random_combination_with_replacement',
'random_combination',
'random_permutation',
'random_product',
'repeatfunc',
'roundrobin',
'sieve',
'sliding_window',
'subslices',
'tabulate',
'tail',
'take',
'transpose',
'triplewise',
'unique_everseen',
'unique_justseen',
]
_marker = object()
def take(n, iterable):
"""Return first *n* items of the iterable as a list.
>>> take(3, range(10))
[0, 1, 2]
If there are fewer than *n* items in the iterable, all of them are
returned.
>>> take(10, range(3))
[0, 1, 2]
"""
return list(islice(iterable, n))
def tabulate(function, start=0):
"""Return an iterator over the results of ``func(start)``,
``func(start + 1)``, ``func(start + 2)``...
*func* should be a function that accepts one integer argument.
If *start* is not specified it defaults to 0. It will be incremented each
time the iterator is advanced.
>>> square = lambda x: x ** 2
>>> iterator = tabulate(square, -3)
>>> take(4, iterator)
[9, 4, 1, 0]
"""
return map(function, count(start))
def tail(n, iterable):
"""Return an iterator over the last *n* items of *iterable*.
>>> t = tail(3, 'ABCDEFG')
>>> list(t)
['E', 'F', 'G']
"""
# If the given iterable has a length, then we can use islice to get its
# final elements. Note that if the iterable is not actually Iterable,
# either islice or deque will throw a TypeError. This is why we don't
# check if it is Iterable.
if isinstance(iterable, Sized):
yield from islice(iterable, max(0, len(iterable) - n), None)
else:
yield from iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"""Advance *iterable* by *n* steps. If *n* is ``None``, consume it
entirely.
Efficiently exhausts an iterator without returning values. Defaults to
consuming the whole iterator, but an optional second argument may be
provided to limit consumption.
>>> i = (x for x in range(10))
>>> next(i)
0
>>> consume(i, 3)
>>> next(i)
4
>>> consume(i)
>>> next(i)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration
If the iterator has fewer items remaining than the provided limit, the
whole iterator will be consumed.
>>> i = (x for x in range(3))
>>> consume(i, 5)
>>> next(i)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration
"""
# Use functions that consume iterators at C speed.
if n is None:
# feed the entire iterator into a zero-length deque
deque(iterator, maxlen=0)
else:
# advance to the empty slice starting at position n
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"""Returns the nth item or a default value.
>>> l = range(10)
>>> nth(l, 3)
3
>>> nth(l, 20, "zebra")
'zebra'
"""
return next(islice(iterable, n, None), default)
def all_equal(iterable):
"""
Returns ``True`` if all the elements are equal to each other.
>>> all_equal('aaaa')
True
>>> all_equal('aaab')
False
"""
g = groupby(iterable)
return next(g, True) and not next(g, False)
def quantify(iterable, pred=bool):
"""Return the how many times the predicate is true.
>>> quantify([True, False, True])
2
"""
return sum(map(pred, iterable))
def pad_none(iterable):
"""Returns the sequence of elements and then returns ``None`` indefinitely.
>>> take(5, pad_none(range(3)))
[0, 1, 2, None, None]
Useful for emulating the behavior of the built-in :func:`map` function.
See also :func:`padded`.
"""
return chain(iterable, repeat(None))
padnone = pad_none
def ncycles(iterable, n):
"""Returns the sequence elements *n* times
>>> list(ncycles(["a", "b"], 3))
['a', 'b', 'a', 'b', 'a', 'b']
"""
return chain.from_iterable(repeat(tuple(iterable), n))
def dotproduct(vec1, vec2):
"""Returns the dot product of the two iterables.
>>> dotproduct([10, 10], [20, 20])
400
"""
return sum(map(operator.mul, vec1, vec2))
def flatten(listOfLists):
"""Return an iterator flattening one level of nesting in a list of lists.
>>> list(flatten([[0, 1], [2, 3]]))
[0, 1, 2, 3]
See also :func:`collapse`, which can flatten multiple levels of nesting.
"""
return chain.from_iterable(listOfLists)
def repeatfunc(func, times=None, *args):
"""Call *func* with *args* repeatedly, returning an iterable over the
results.
If *times* is specified, the iterable will terminate after that many
repetitions:
>>> from operator import add
>>> times = 4
>>> args = 3, 5
>>> list(repeatfunc(add, times, *args))
[8, 8, 8, 8]
If *times* is ``None`` the iterable will not terminate:
>>> from random import randrange
>>> times = None
>>> args = 1, 11
>>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
[2, 4, 8, 1, 8, 4]
"""
if times is None:
return starmap(func, repeat(args))
return starmap(func, repeat(args, times))
def _pairwise(iterable):
"""Returns an iterator of paired items, overlapping, from the original
>>> take(4, pairwise(count()))
[(0, 1), (1, 2), (2, 3), (3, 4)]
On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
"""
a, b = tee(iterable)
next(b, None)
yield from zip(a, b)
try:
from itertools import pairwise as itertools_pairwise
except ImportError:
pairwise = _pairwise
else:
def pairwise(iterable):
yield from itertools_pairwise(iterable)
pairwise.__doc__ = _pairwise.__doc__
class UnequalIterablesError(ValueError):
def __init__(self, details=None):
msg = 'Iterables have different lengths'
if details is not None:
msg += (': index 0 has length {}; index {} has length {}').format(
*details
)
super().__init__(msg)
def _zip_equal_generator(iterables):
for combo in zip_longest(*iterables, fillvalue=_marker):
for val in combo:
if val is _marker:
raise UnequalIterablesError()
yield combo
def _zip_equal(*iterables):
# Check whether the iterables are all the same size.
try:
first_size = len(iterables[0])
for i, it in enumerate(iterables[1:], 1):
size = len(it)
if size != first_size:
break
else:
# If we didn't break out, we can use the built-in zip.
return zip(*iterables)
# If we did break out, there was a mismatch.
raise UnequalIterablesError(details=(first_size, i, size))
# If any one of the iterables didn't have a length, start reading
# them until one runs out.
except TypeError:
return _zip_equal_generator(iterables)
def grouper(iterable, n, incomplete='fill', fillvalue=None):
"""Group elements from *iterable* into fixed-length groups of length *n*.
>>> list(grouper('ABCDEF', 3))
[('A', 'B', 'C'), ('D', 'E', 'F')]
The keyword arguments *incomplete* and *fillvalue* control what happens for
iterables whose length is not a multiple of *n*.
When *incomplete* is `'fill'`, the last group will contain instances of
*fillvalue*.
>>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
When *incomplete* is `'ignore'`, the last group will not be emitted.
>>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F')]
When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
>>> it = grouper('ABCDEFG', 3, incomplete='strict')
>>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
UnequalIterablesError
"""
args = [iter(iterable)] * n
if incomplete == 'fill':
return zip_longest(*args, fillvalue=fillvalue)
if incomplete == 'strict':
return _zip_equal(*args)
if incomplete == 'ignore':
return zip(*args)
else:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"""Yields an item from each iterable, alternating between them.
>>> list(roundrobin('ABC', 'D', 'EF'))
['A', 'D', 'E', 'B', 'F', 'C']
This function produces the same output as :func:`interleave_longest`, but
may perform better for some inputs (in particular when the number of
iterables is small).
"""
# Recipe credited to George Sakkis
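# How the cycle shrinks: `nexts` holds one bound __next__ method per
# iterable. When one of them raises StopIteration, the cycle is positioned
# just past the exhausted iterator, so re-cycling islice(nexts, pending)
# rebuilds the rotation with only the iterators that are still live.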
pending = len(iterables)
nexts = cycle(iter(it).__next__ for it in iterables)
while pending:
try:
for next in nexts:
yield next()
except StopIteration:
pending -= 1
nexts = cycle(islice(nexts, pending))
def partition(pred, iterable):
"""
Returns a 2-tuple of iterables derived from the input iterable.
The first yields the items that have ``pred(item) == False``.
The second yields the items that have ``pred(item) == True``.
>>> is_odd = lambda x: x % 2 != 0
>>> iterable = range(10)
>>> even_items, odd_items = partition(is_odd, iterable)
>>> list(even_items), list(odd_items)
([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
If *pred* is None, :func:`bool` is used.
>>> iterable = [0, 1, False, True, '', ' ']
>>> false_items, true_items = partition(None, iterable)
>>> list(false_items), list(true_items)
([0, False, ''], [1, True, ' '])
"""
if pred is None:
pred = bool
evaluations = ((pred(x), x) for x in iterable)
t1, t2 = tee(evaluations)
return (
(x for (cond, x) in t1 if not cond),
(x for (cond, x) in t2 if cond),
)
def powerset(iterable):
"""Yields all possible subsets of the iterable.
>>> list(powerset([1, 2, 3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
:func:`powerset` will operate on iterables that aren't :class:`set`
instances, so repeated elements in the input will produce repeated elements
in the output. Use :func:`unique_everseen` on the input to avoid generating
duplicates:
>>> seq = [1, 1, 0]
>>> list(powerset(seq))
[(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
>>> from more_itertools import unique_everseen
>>> list(powerset(unique_everseen(seq)))
[(), (1,), (0,), (1, 0)]
"""
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
def unique_everseen(iterable, key=None):
"""
Yield unique elements, preserving order.
>>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D']
>>> list(unique_everseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'D']
Sequences with a mix of hashable and unhashable items can be used.
The function will be slower (i.e., `O(n^2)`) for unhashable items.
Remember that ``list`` objects are unhashable - you can use the *key*
parameter to transform the list to a tuple (which is hashable) to
avoid a slowdown.
>>> iterable = ([1, 2], [2, 3], [1, 2])
>>> list(unique_everseen(iterable)) # Slow
[[1, 2], [2, 3]]
>>> list(unique_everseen(iterable, key=tuple)) # Faster
[[1, 2], [2, 3]]
Similarly, you may want to convert unhashable ``set`` objects with
``key=frozenset``. For ``dict`` objects,
``key=lambda x: frozenset(x.items())`` can be used.
"""
seenset = set()
seenset_add = seenset.add
seenlist = []
seenlist_add = seenlist.append
use_key = key is not None
for element in iterable:
k = key(element) if use_key else element
try:
if k not in seenset:
seenset_add(k)
yield element
except TypeError:
if k not in seenlist:
seenlist_add(k)
yield element
def unique_justseen(iterable, key=None):
"""Yields elements in order, ignoring serial duplicates
>>> list(unique_justseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D', 'A', 'B']
>>> list(unique_justseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'A', 'D']
"""
return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
def iter_except(func, exception, first=None):
"""Yields results from a function repeatedly until an exception is raised.
Converts a call-until-exception interface to an iterator interface.
Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
to end the loop.
>>> l = [0, 1, 2]
>>> list(iter_except(l.pop, IndexError))
[2, 1, 0]
Multiple exceptions can be specified as a stopping condition:
>>> l = [1, 2, 3, '...', 4, 5, 6]
>>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
[7, 6, 5]
>>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
[4, 3, 2]
>>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
[]
"""
try:
if first is not None:
yield first()
while 1:
yield func()
except exception:
pass
def first_true(iterable, default=None, pred=None):
"""
Returns the first true value in the iterable.
If no true value is found, returns *default*
If *pred* is not None, returns the first item for which
``pred(item) == True`` .
>>> first_true(range(10))
1
>>> first_true(range(10), pred=lambda x: x > 5)
6
>>> first_true(range(10), default='missing', pred=lambda x: x > 9)
'missing'
"""
return next(filter(pred, iterable), default)
def random_product(*args, repeat=1):
"""Draw an item at random from each of the input iterables.
>>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
('c', 3, 'Z')
If *repeat* is provided as a keyword argument, that many items will be
drawn from each iterable.
>>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
('a', 2, 'd', 3)
This is equivalent to taking a random selection from
``itertools.product(*args, repeat=repeat)``.
"""
pools = [tuple(pool) for pool in args] * repeat
return tuple(choice(pool) for pool in pools)
def random_permutation(iterable, r=None):
"""Return a random *r* length permutation of the elements in *iterable*.
If *r* is not specified or is ``None``, then *r* defaults to the length of
*iterable*.
>>> random_permutation(range(5)) # doctest:+SKIP
(3, 4, 0, 1, 2)
This is equivalent to taking a random selection from
``itertools.permutations(iterable, r)``.
"""
pool = tuple(iterable)
r = len(pool) if r is None else r
return tuple(sample(pool, r))
def random_combination(iterable, r):
"""Return a random *r* length subsequence of the elements in *iterable*.
>>> random_combination(range(5), 3) # doctest:+SKIP
(2, 3, 4)
This is equivalent to taking a random selection from
``itertools.combinations(iterable, r)``.
"""
pool = tuple(iterable)
n = len(pool)
indices = sorted(sample(range(n), r))
return tuple(pool[i] for i in indices)
def random_combination_with_replacement(iterable, r):
"""Return a random *r* length subsequence of elements in *iterable*,
allowing individual elements to be repeated.
>>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
(0, 0, 1, 2, 2)
This is equivalent to taking a random selection from
``itertools.combinations_with_replacement(iterable, r)``.
"""
pool = tuple(iterable)
n = len(pool)
indices = sorted(randrange(n) for i in range(r))
return tuple(pool[i] for i in indices)
def nth_combination(iterable, r, index):
"""Equivalent to ``list(combinations(iterable, r))[index]``.
The subsequences of *iterable* that are of length *r* can be ordered
lexicographically. :func:`nth_combination` computes the subsequence at
sort position *index* directly, without computing the previous
subsequences.
>>> nth_combination(range(5), 3, 5)
(0, 3, 4)
``ValueError`` will be raised if *r* is negative or greater than the length
of *iterable*.
``IndexError`` will be raised if the given *index* is invalid.
"""
pool = tuple(iterable)
n = len(pool)
if (r < 0) or (r > n):
raise ValueError
c = 1
k = min(r, n - r)
for i in range(1, k + 1):
c = c * (n - k + i) // i
if index < 0:
index += c
if (index < 0) or (index >= c):
raise IndexError
result = []
while r:
c, n, r = c * r // n, n - 1, r - 1
while index >= c:
index -= c
c, n = c * (n - r) // n, n - 1
result.append(pool[-1 - n])
return tuple(result)
def prepend(value, iterator):
"""Yield *value*, followed by the elements in *iterator*.
>>> value = '0'
>>> iterator = ['1', '2', '3']
>>> list(prepend(value, iterator))
['0', '1', '2', '3']
To prepend multiple values, see :func:`itertools.chain`
or :func:`value_chain`.
"""
return chain([value], iterator)
def convolve(signal, kernel):
"""Convolve the iterable *signal* with the iterable *kernel*.
>>> signal = (1, 2, 3, 4, 5)
>>> kernel = [3, 2, 1]
>>> list(convolve(signal, kernel))
[3, 8, 14, 20, 26, 14, 5]
Note: the input arguments are not interchangeable, as the *kernel*
is immediately consumed and stored.
"""
kernel = tuple(kernel)[::-1]
n = len(kernel)
window = deque([0], maxlen=n) * n
for x in chain(signal, repeat(0, n - 1)):
window.append(x)
yield sum(map(operator.mul, kernel, window))
def before_and_after(predicate, it):
"""A variant of :func:`takewhile` that allows complete access to the
remainder of the iterator.
>>> it = iter('ABCdEfGhI')
>>> all_upper, remainder = before_and_after(str.isupper, it)
>>> ''.join(all_upper)
'ABC'
>>> ''.join(remainder) # takewhile() would lose the 'd'
'dEfGhI'
Note that the first iterator must be fully consumed before the second
iterator can generate valid results.
"""
it = iter(it)
transition = []
def true_iterator():
for elem in it:
if predicate(elem):
yield elem
else:
transition.append(elem)
return
# Note: this is different from itertools recipes to allow nesting
# before_and_after remainders into before_and_after again. See tests
# for an example.
remainder_iterator = chain(transition, it)
return true_iterator(), remainder_iterator
def triplewise(iterable):
"""Return overlapping triplets from *iterable*.
>>> list(triplewise('ABCDE'))
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
"""
for (a, _), (b, c) in pairwise(pairwise(iterable)):
yield a, b, c
def sliding_window(iterable, n):
"""Return a sliding window of width *n* over *iterable*.
>>> list(sliding_window(range(6), 4))
[(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
If *iterable* has fewer than *n* items, then nothing is yielded:
>>> list(sliding_window(range(3), 4))
[]
For a variant with more features, see :func:`windowed`.
"""
it = iter(iterable)
window = deque(islice(it, n), maxlen=n)
if len(window) == n:
yield tuple(window)
for x in it:
window.append(x)
yield tuple(window)
def subslices(iterable):
"""Return all contiguous non-empty subslices of *iterable*.
>>> list(subslices('ABC'))
[['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
This is similar to :func:`substrings`, but emits items in a different
order.
"""
seq = list(iterable)
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(operator.getitem, repeat(seq), slices)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
>>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
>>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
[1, -4, -17, 60]
"""
# Use math.prod on Python 3.8+; fall back to a reduce-based product otherwise.
prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1))
roots = list(map(operator.neg, roots))
return [
sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1)
]
def iter_index(iterable, value, start=0):
"""Yield the index of each place in *iterable* that *value* occurs,
beginning with index *start*.
See :func:`locate` for a more general means of finding the indexes
associated with particular values.
>>> list(iter_index('AABCADEAF', 'A'))
[0, 1, 4, 7]
"""
try:
seq_index = iterable.index
except AttributeError:
# Slow path for general iterables
it = islice(iterable, start, None)
for i, element in enumerate(it, start):
if element is value or element == value:
yield i
else:
# Fast path for sequences
i = start - 1
try:
while True:
i = seq_index(value, i + 1)
yield i
except ValueError:
pass
def sieve(n):
"""Yield the primes less than n.
>>> list(sieve(30))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
"""
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
data = bytearray((0, 1)) * (n // 2)
data[:3] = 0, 0, 0
limit = isqrt(n) + 1
for p in compress(range(limit), data):
data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
data[2] = 1
return iter_index(data, 1) if n > 2 else iter([])
def batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
>>> list(batched('ABCDEFG', 3))
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
This recipe is from the ``itertools`` docs. This library also provides
:func:`chunked`, which has a different implementation.
"""
if hexversion >= 0x30C00A0: # Python 3.12.0a0
warnings.warn(
(
'batched will be removed in a future version of '
'more-itertools. Use the standard library '
'itertools.batched function instead'
),
DeprecationWarning,
)
it = iter(iterable)
while True:
batch = list(islice(it, n))
if not batch:
break
yield batch
def transpose(it):
"""Swap the rows and columns of the input.
>>> list(transpose([(1, 2, 3), (11, 22, 33)]))
[(1, 11), (2, 22), (3, 33)]
The caller should ensure that the dimensions of the input are compatible.
"""
# TODO: when Python 3.9 goes end-of-life, add strict=True to this zip call.
return zip(*it)
def matmul(m1, m2):
"""Multiply two matrices.
>>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
[[49, 80], [41, 60]]
The caller should ensure that the dimensions of the input matrices are
compatible with each other.
"""
n = len(m2[0])
return batched(starmap(dotproduct, product(m1, transpose(m2))), n)
def factor(n):
"""Yield the prime factors of n.
>>> list(factor(360))
[2, 2, 2, 3, 3, 5]
"""
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
for prime in sieve(isqrt(n) + 1):
while True:
quotient, remainder = divmod(n, prime)
if remainder:
break
yield prime
n = quotient
if n == 1:
return
if n >= 2:
yield n
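# A quick, illustrative sanity check (not part of the module): the product
# of factor(n) reconstructs n, and sieve() agrees with factor() on primes.
if __name__ == '__main__':
assert list(sieve(10)) == [2, 3, 5, 7]
assert list(factor(97)) == [97]
assert reduce(operator.mul, factor(360), 1) == 360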

Some files were not shown because too many files have changed in this diff.